├── README.md ├── VOCdevkit └── VOC2007 │ ├── ImageSets │ └── Segmentation │ │ ├── train.txt │ │ └── val.txt │ ├── JPEGImages │ ├── 0.png │ └── 20000.png │ └── SegmentationClass │ ├── 0.png │ └── 20000.png ├── __pycache__ └── segformer.cpython-38.pyc ├── datasets ├── JPEGImages │ └── 1.jpg ├── SegmentationClass │ └── 1.png └── before │ ├── 1.jpg │ └── 1.json ├── get_miou.py ├── json_to_dataset.py ├── model_data └── README.md ├── nets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── backbone.cpython-37.pyc │ ├── backbone.cpython-38.pyc │ ├── conv_.cpython-37.pyc │ ├── conv_.cpython-38.pyc │ ├── segformer.cpython-38.pyc │ └── segformer_training.cpython-38.pyc ├── backbone.py ├── conv_.py ├── mf_head.py ├── segformer.py └── segformer_training.py ├── predict.py ├── requirements.txt ├── segformer.py ├── summary.py ├── train.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── callbacks.cpython-38.pyc │ ├── dataloader.cpython-38.pyc │ ├── utils.cpython-38.pyc │ ├── utils_fit.cpython-38.pyc │ └── utils_metrics.cpython-38.pyc ├── callbacks.py ├── dataloader.py ├── utils.py ├── utils_fit.py └── utils_metrics.py └── voc_annotation.py /README.md: -------------------------------------------------------------------------------- 1 | # Article link:https://www.mdpi.com/2072-4292/15/19/4697/htm 2 | Citation: Zhang, T.; Qin, C.; Li, W.; Mao, X.; Zhao, L.; Hou, B.; Jiao, L. Water Body Extraction of the Weihe River Basin Based on MF-SegFormer Applied to Landsat8 OLI Data. Remote Sens. 2023, 15, 4697. 
3 | 4 | -------------------------------------------------------------------------------- /VOCdevkit/VOC2007/ImageSets/Segmentation/train.txt: -------------------------------------------------------------------------------- 1 | 0 2 | -------------------------------------------------------------------------------- /VOCdevkit/VOC2007/ImageSets/Segmentation/val.txt: -------------------------------------------------------------------------------- 1 | 20000 2 | -------------------------------------------------------------------------------- /VOCdevkit/VOC2007/JPEGImages/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/VOCdevkit/VOC2007/JPEGImages/0.png -------------------------------------------------------------------------------- /VOCdevkit/VOC2007/JPEGImages/20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/VOCdevkit/VOC2007/JPEGImages/20000.png -------------------------------------------------------------------------------- /VOCdevkit/VOC2007/SegmentationClass/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/VOCdevkit/VOC2007/SegmentationClass/0.png -------------------------------------------------------------------------------- /VOCdevkit/VOC2007/SegmentationClass/20000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/VOCdevkit/VOC2007/SegmentationClass/20000.png -------------------------------------------------------------------------------- /__pycache__/segformer.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/__pycache__/segformer.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/JPEGImages/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/datasets/JPEGImages/1.jpg -------------------------------------------------------------------------------- /datasets/SegmentationClass/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/datasets/SegmentationClass/1.png -------------------------------------------------------------------------------- /datasets/before/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/datasets/before/1.jpg -------------------------------------------------------------------------------- /datasets/before/1.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "3.16.7", 3 | "flags": {}, 4 | "shapes": [ 5 | { 6 | "label": "cat", 7 | "line_color": null, 8 | "fill_color": null, 9 | "points": [ 10 | [ 11 | 202.77358490566036, 12 | 626.0943396226414 13 | ], 14 | [ 15 | 178.24528301886792, 16 | 552.5094339622641 17 | ], 18 | [ 19 | 195.22641509433961, 20 | 444.9622641509434 21 | ], 22 | [ 23 | 177.30188679245282, 24 | 340.2452830188679 25 | ], 26 | [ 27 | 173.52830188679243, 28 | 201.56603773584905 29 | ], 30 | [ 31 | 211.2641509433962, 32 | 158.16981132075472 33 | ], 34 | [ 35 | 226.35849056603772, 36 | 87.41509433962264 37 | ], 38 | [ 39 | 208.43396226415092, 40 | 
6.283018867924525 41 | ], 42 | [ 43 | 277.3018867924528, 44 | 57.226415094339615 45 | ], 46 | [ 47 | 416.92452830188677, 48 | 80.81132075471697 49 | ], 50 | [ 51 | 497.1132075471698, 52 | 64.77358490566037 53 | ], 54 | [ 55 | 578.2452830188679, 56 | 6.283018867924525 57 | ], 58 | [ 59 | 599.0, 60 | 35.52830188679245 61 | ], 62 | [ 63 | 589.566037735849, 64 | 96.84905660377359 65 | ], 66 | [ 67 | 592.3962264150944, 68 | 133.64150943396226 69 | ], 70 | [ 71 | 679.188679245283, 72 | 174.2075471698113 73 | ], 74 | [ 75 | 723.5283018867924, 76 | 165.71698113207546 77 | ], 78 | [ 79 | 726.3584905660377, 80 | 222.32075471698113 81 | ], 82 | [ 83 | 759.377358490566, 84 | 262.88679245283015 85 | ], 86 | [ 87 | 782.9622641509434, 88 | 350.62264150943395 89 | ], 90 | [ 91 | 766.9245283018868, 92 | 428.92452830188677 93 | ], 94 | [ 95 | 712.2075471698113, 96 | 465.71698113207543 97 | ], 98 | [ 99 | 695.2264150943396, 100 | 538.3584905660377 101 | ], 102 | [ 103 | 657.4905660377358, 104 | 601.566037735849 105 | ], 106 | [ 107 | 606, 108 | 633 109 | ], 110 | [ 111 | 213, 112 | 633 113 | ] 114 | ], 115 | "shape_type": "polygon", 116 | "flags": {} 117 | } 118 | ], 119 | "lineColor": [ 120 | 0, 121 | 255, 122 | 0, 123 | 128 124 | ], 125 | "fillColor": [ 126 | 255, 127 | 0, 128 | 0, 129 | 128 130 | ], 131 | "imagePath": "1.jpg", 132 | "imageData": "", 133 | "imageHeight": 634, 134 | "imageWidth": 950 135 | } -------------------------------------------------------------------------------- /get_miou.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | from segformer import SegFormer_Segmentation 7 | from utils.utils_metrics import compute_mIoU, show_results 8 | 9 | ''' 10 | 进行指标评估需要注意以下几点: 11 | 1、该文件生成的图为灰度图,因为值比较小,按照PNG形式的图看是没有显示效果的,所以看到近似全黑的图是正常的。 12 | 2、该文件计算的是验证集的miou,当前该库将测试集当作验证集使用,不单独划分测试集 13 | ''' 14 | if __name__ == "__main__": 15 | 
#---------------------------------------------------------------------------# 16 | # miou_mode用于指定该文件运行时计算的内容 17 | # miou_mode为0代表整个miou计算流程,包括获得预测结果、计算miou。 18 | # miou_mode为1代表仅仅获得预测结果。 19 | # miou_mode为2代表仅仅计算miou。 20 | #---------------------------------------------------------------------------# 21 | miou_mode = 0 22 | #------------------------------# 23 | # 分类个数+1、如2+1 24 | #------------------------------# 25 | num_classes = 2+1 26 | #--------------------------------------------# 27 | # 区分的种类,和json_to_dataset里面的一样 28 | #--------------------------------------------# 29 | name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 30 | # name_classes = ["_background_","cat","dog"] 31 | #-------------------------------------------------------# 32 | # 指向VOC数据集所在的文件夹 33 | # 默认指向根目录下的VOC数据集 34 | #-------------------------------------------------------# 35 | VOCdevkit_path = 'VOCdevkit' 36 | 37 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines() 38 | gt_dir = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/") 39 | miou_out_path = "miou_out" 40 | pred_dir = os.path.join(miou_out_path, 'detection-results') 41 | 42 | if miou_mode == 0 or miou_mode == 1: 43 | if not os.path.exists(pred_dir): 44 | os.makedirs(pred_dir) 45 | 46 | print("Load model.") 47 | segformer = SegFormer_Segmentation() 48 | print("Load model done.") 49 | 50 | print("Get predict result.") 51 | for image_id in tqdm(image_ids): 52 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".png") 53 | image = Image.open(image_path) 54 | image = segformer.get_miou_png(image) 55 | image.save(os.path.join(pred_dir, image_id + ".png")) 56 | print("Get predict result done.") 57 | 58 | if miou_mode == 0 or miou_mode == 2: 59 | print("Get miou.") 60 | hist, IoUs, 
import base64
import json
import os
import os.path as osp

import numpy as np
import PIL.Image
from labelme import utils

'''
Notes for building your own semantic-segmentation dataset:
1. The author used labelme 3.16.7 and recommends that version; some other
   labelme versions fail with: Too many dimensions: 3 > 2
   Install it with: pip install labelme==3.16.7
2. The label images written here are 8-bit colour (palette) PNGs, which look
   different from the dataset format shown in the accompanying video.
   Although they display in colour they only hold 8 bits per pixel, and the
   value of each pixel is the index of the class that pixel belongs to.
   That is the same layout as the VOC dataset, so the generated dataset can
   be used as-is.
'''
if __name__ == '__main__':
    jpgs_path = "datasets/JPEGImages"
    pngs_path = "datasets/SegmentationClass"
    before_path = "./datasets/before"
    classes = ["_background_","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    # classes = ["_background_","cat","dog"]

    # The output folders may be missing on a fresh checkout; create them
    # instead of failing on the first save.
    os.makedirs(jpgs_path, exist_ok=True)
    os.makedirs(pngs_path, exist_ok=True)

    for file_name in os.listdir(before_path):
        path = os.path.join(before_path, file_name)
        # Only labelme annotation files are converted.
        if not (os.path.isfile(path) and file_name.endswith('.json')):
            continue

        # splitext keeps dots inside the name ("a.1.json" -> "a.1"), so two
        # annotations that share a prefix no longer overwrite each other.
        stem = osp.splitext(file_name)[0]

        with open(path) as f:
            data = json.load(f)

        # labelme either embeds the image as base64 or stores a path that is
        # relative to the JSON file.
        if data.get('imageData'):
            imageData = data['imageData']
        else:
            imagePath = os.path.join(os.path.dirname(path), data['imagePath'])
            with open(imagePath, 'rb') as f:
                imageData = base64.b64encode(f.read()).decode('utf-8')

        img = utils.img_b64_to_arr(imageData)

        # Assign dense per-file label ids in order of first appearance;
        # background is always 0.
        label_name_to_value = {'_background_': 0}
        for shape in data['shapes']:
            label_name_to_value.setdefault(shape['label'], len(label_name_to_value))

        # Fail early, naming the file and the labels, instead of a bare
        # "x is not in list" from classes.index() further down.
        unknown = [name for name in label_name_to_value if name not in classes]
        if unknown:
            raise ValueError(
                'Labels %s in %s are not listed in `classes`' % (unknown, path)
            )

        lbl = np.array(utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value))

        PIL.Image.fromarray(img).save(osp.join(jpgs_path, stem + '.jpg'))

        # Remap the per-file label ids to their global index in `classes`.
        new = np.zeros(img.shape[:2])
        for name, index_json in label_name_to_value.items():
            new[lbl == index_json] = classes.index(name)

        utils.lblsave(osp.join(pngs_path, stem + '.png'), new)
        print('Saved ' + stem + '.jpg and ' + stem + '.png')
https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/nets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /nets/__pycache__/backbone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/nets/__pycache__/backbone.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/nets/__pycache__/backbone.cpython-38.pyc -------------------------------------------------------------------------------- /nets/__pycache__/conv_.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/nets/__pycache__/conv_.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/conv_.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/nets/__pycache__/conv_.cpython-38.pyc -------------------------------------------------------------------------------- /nets/__pycache__/segformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/nets/__pycache__/segformer.cpython-38.pyc -------------------------------------------------------------------------------- /nets/__pycache__/segformer_training.cpython-38.pyc: 
# File: nets/backbone.py
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import math
import warnings
import numpy as np
from functools import partial

import torch
import torch.nn as nn


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Args:
        tensor: tensor to overwrite.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff.
        b: upper cutoff.

    Returns:
        The same ``tensor`` object, modified in place.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower_cdf = norm_cdf((a - mean) / std)
        upper_cdf = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower_cdf, upper_cdf], then
        # translate to [2 * lower_cdf - 1, 2 * upper_cdf - 1].
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""
    Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def _init_module_weights(m):
    """Initialise one sub-module in place; meant to be passed through ``nn.Module.apply``.

    Shared by the ``_init_weights`` method of every layer in this file, which
    previously each carried an identical copy of this logic.

    * ``nn.Linear``: truncated-normal weights (std 0.02), zero bias.
    * ``nn.LayerNorm``: unit weight, zero bias.
    * ``nn.Conv2d``: zero-mean normal weights with std ``sqrt(2 / fan_out)``,
      where ``fan_out = kernel_h * kernel_w * out_channels // groups``; zero bias.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()


#--------------------------------------#
#   GELU activation function,
#   implemented with the tanh-based
#   approximation formula
#--------------------------------------#
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))


class OverlapPatchEmbed(nn.Module):
    """Strided-convolution patch embedding followed by LayerNorm.

    Args:
        patch_size: side of the square convolution kernel; padding is
            ``patch_size // 2`` on each side.
        stride: convolution stride (the spatial down-sampling factor).
        in_chans: number of input channels.
        embed_dim: number of output channels, i.e. the token dimension.
    """

    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise sub-module ``m`` (see ``_init_module_weights``)."""
        _init_module_weights(m)

    def forward(self, x):
        """Embed a ``(B, C, H, W)`` map; return ``(tokens, H_out, W_out)``.

        ``tokens`` has shape ``(B, H_out * W_out, embed_dim)``.
        """
        x = self.proj(x)
        _, _, H, W = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W

#--------------------------------------------------------------------------------------------------------------------#
#   Attention mechanism
#   The input features are projected into query, key and value vectors.
#   The query is multiplied with the transposed key; informally, the query is used to look up the sequence and
#   obtain an importance score for every position.
#   The score is then multiplied with the value; informally, the importance of every position is applied back
#   onto the values of the sequence.
#
#   To reduce computation, SegFormer first condenses the feature map: every feature level is compressed to 1/32
#   of the input image.
#   For a 512 x 512 input, the Block1 feature map is 128 x 128, and it is first compressed to 16 x 16.
#   Inside Block1's Attention module this amounts to condensing every 8 x 8 group of feature points into one.
#   The 128 x 128 queries then attend over 16 x 16 keys and values. There are fewer keys and values, but because
#   the queries differ, the outputs still differ per position.
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    Args:
        dim: token dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the q and kv projections have a bias.
        qk_scale: attention scale; a falsy value (``None`` or 0) selects
            ``head_dim ** -0.5``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        sr_ratio: when > 1, keys/values are computed from the feature map
            down-sampled by a ``sr_ratio``-strided convolution plus LayerNorm.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise sub-module ``m`` (see ``_init_module_weights``)."""
        _init_module_weights(m)

    def forward(self, x, H, W):
        """Attend over tokens ``x`` of shape ``(B, N, C)`` with ``N == H * W``.

        Returns a tensor of the same ``(B, N, C)`` shape. The shape comments
        below use the Block1 example from the header comment.
        """
        B, N, C = x.shape
        # bs, 16384, 32 => bs, 16384, 32 => bs, 16384, 8, 4 => bs, 8, 16384, 4
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # bs, 16384, 32 => bs, 32, 128, 128
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            # bs, 32, 128, 128 => bs, 32, 16, 16 => bs, 256, 32
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            # bs, 256, 32 => bs, 256, 64 => bs, 256, 2, 8, 4 => 2, bs, 8, 256, 4
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # bs, 8, 16384, 4 @ bs, 8, 4, 256 => bs, 8, 16384, 256
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # bs, 8, 16384, 256 @ bs, 8, 256, 4 => bs, 8, 16384, 4 => bs, 16384, 32
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # bs, 16384, 32 => bs, 16384, 32
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: input tensor; the first dimension is treated as the sample axis.
        drop_prob: probability of zeroing a whole sample. ``None`` is
            accepted and means "never drop".
        training: dropping only happens when this is true.
        scale_by_keep: divide surviving samples by ``1 - drop_prob``.

    Returns:
        ``x`` itself when nothing is dropped, otherwise a new tensor.
    """
    # ``DropPath`` defaults to drop_prob=None; without this guard the
    # ``1 - drop_prob`` below raised TypeError in training mode.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Module wrapper around ``drop_path``; active only in training mode.

    Args:
        drop_prob: probability of dropping a sample; ``None`` disables dropping.
        scale_by_keep: forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class DWConv(nn.Module):
    """3x3 depthwise convolution applied to a token sequence.

    Args:
        dim: number of channels (also the number of convolution groups).
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Map tokens ``(B, N, C)`` with ``N == H * W`` to tokens of the same shape."""
        B, N, C = x.shape
        # Tokens -> feature map, convolve, feature map -> tokens.
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


# class MFF(nn.Module):
#     '''
#     Multi-scale feature fusion (MFF)
#     '''
#
#     def __init__(self, channels=640, r=4):
#         super(MFF, self).__init__()
#         inter_channels = int(channels // r)
#         kernel_size1 = 1
#         self.conv1 = nn.Conv2d(2, 1, kernel_size1, stride=1, padding=0, bias=False)
#         self.local_att = nn.Sequential(
#             nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
#             nn.BatchNorm2d(inter_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
#             nn.BatchNorm2d(channels),
#         )
#
#         self.global_att = nn.Sequential(
#             nn.AdaptiveAvgPool2d(1),
#             nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
#             nn.BatchNorm2d(inter_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
#             nn.BatchNorm2d(channels),
#         )
#
#         self.sigmoid = nn.Sigmoid()
#
#     def forward(self, x, residual):
#         xa = x + residual
#         xl = self.local_att(xa)
#         xg = self.global_att(xa)
#         xlg = xl + xg
#         avg_out = torch.mean(xlg, dim=1, keepdim=True)
#         max_out, _ = torch.max(xlg, dim=1, keepdim=True)
#         xlg = torch.cat([avg_out, max_out], dim=1)
#         xlg = self.conv1(xlg)
#         wei = self.sigmoid(xlg)*xlg
#         xo = x * wei + residual * (1 - wei)
#         return xo

class Mlp(nn.Module):
    """Feed-forward network: Linear -> depthwise conv -> activation -> Linear.

    Args:
        in_features: input token dimension.
        hidden_features: hidden dimension; defaults to ``in_features``.
        out_features: output dimension; defaults to ``in_features``.
        act_layer: activation module class, instantiated without arguments.
        drop: dropout probability, applied after the activation and after
            the second linear layer (the same ``nn.Dropout`` instance).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()

        self.fc2 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise sub-module ``m`` (see ``_init_module_weights``)."""
        _init_module_weights(m)

    def forward(self, x, H, W):
        """Transform tokens ``(B, N, C)`` with ``N == H * W``."""
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    Args:
        dim: token dimension.
        num_heads: number of attention heads.
        mlp_ratio: hidden MLP width as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop, sr_ratio: forwarded to ``Attention``.
        drop: dropout probability for the attention projection and the MLP.
        drop_path: stochastic-depth probability; values <= 0 use ``nn.Identity``.
        act_layer: activation module class for the MLP.
        norm_layer: normalisation layer class, called as ``norm_layer(dim)``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)

        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise sub-module ``m`` (see ``_init_module_weights``)."""
        _init_module_weights(m)

    def forward(self, x, H, W):
        """Apply the block to tokens ``(B, N, C)`` with ``N == H * W``."""
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


# NOTE: only the opening of this class falls inside this block; the
# constructor continues in the lines that follow.
class MixVisionTransformer(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000, embed_dims=[32, 64, 160, 256],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        #----------------------------------#
        #   Transformer modules, four stages in total
        #----------------------------------#
        # One drop-path rate per Block, rising linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

#----------------------------------# 371 | # block1 372 | #----------------------------------# 373 | #-----------------------------------------------# 374 | # 对输入图像进行分区,并下采样 375 | # 512, 512, 3 => 128, 128, 32 => 16384, 32 376 | #-----------------------------------------------# 377 | ## patch_embed,通过定义卷积操作的步长/stride,时相下采样## 378 | # # stage1, 大卷积核7*7 4倍下采样## 379 | self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0]) 380 | #-----------------------------------------------# 381 | # 利用transformer模块进行特征提取 382 | # 16384, 32 => 16384, 32 383 | #-----------------------------------------------# 384 | cur = 0 385 | self.block1 = nn.ModuleList( 386 | [ 387 | Block( 388 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 389 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[0] 390 | ) 391 | for i in range(depths[0]) 392 | ] 393 | ) 394 | self.norm1 = norm_layer(embed_dims[0]) 395 | 396 | #----------------------------------# 397 | # block2 398 | #----------------------------------# 399 | #-----------------------------------------------# 400 | # 对输入图像进行分区,并下采样 401 | # 128, 128, 32 => 64, 64, 64 => 4096, 64 402 | #-----------------------------------------------# 403 | self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) 404 | #-----------------------------------------------# 405 | # 利用transformer模块进行特征提取 406 | # 4096, 64 => 4096, 64 407 | #-----------------------------------------------# 408 | cur += depths[0] 409 | self.block2 = nn.ModuleList( 410 | [ 411 | Block( 412 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 413 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[1] 414 | ) 415 | for i in range(depths[1]) 416 | ] 417 | ) 418 | 
self.norm2 = norm_layer(embed_dims[1]) 419 | 420 | #----------------------------------# 421 | # block3 422 | #----------------------------------# 423 | #-----------------------------------------------# 424 | # 对输入图像进行分区,并下采样 425 | # 64, 64, 64 => 32, 32, 160 => 1024, 160 426 | #-----------------------------------------------# 427 | self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) 428 | #-----------------------------------------------# 429 | # 利用transformer模块进行特征提取 430 | # 1024, 160 => 1024, 160 431 | #-----------------------------------------------# 432 | cur += depths[1] 433 | self.block3 = nn.ModuleList( 434 | [ 435 | Block( 436 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 437 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[2] 438 | ) 439 | for i in range(depths[2]) 440 | ] 441 | ) 442 | self.norm3 = norm_layer(embed_dims[2]) 443 | 444 | #----------------------------------# 445 | # block4 446 | #----------------------------------# 447 | #-----------------------------------------------# 448 | # 对输入图像进行分区,并下采样 449 | # 32, 32, 160 => 16, 16, 256 => 256, 256 450 | #-----------------------------------------------# 451 | self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) 452 | #-----------------------------------------------# 453 | # 利用transformer模块进行特征提取 454 | # 256, 256 => 256, 256 455 | #-----------------------------------------------# 456 | cur += depths[2] 457 | self.block4 = nn.ModuleList( 458 | [ 459 | Block( 460 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 461 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[3] 462 | ) 463 | for i in range(depths[3]) 464 | ] 465 | ) 466 | self.norm4 = 
norm_layer(embed_dims[3]) 467 | 468 | self.apply(self._init_weights) 469 | 470 | def _init_weights(self, m): 471 | if isinstance(m, nn.Linear): 472 | trunc_normal_(m.weight, std=.02) 473 | if isinstance(m, nn.Linear) and m.bias is not None: 474 | nn.init.constant_(m.bias, 0) 475 | elif isinstance(m, nn.LayerNorm): 476 | nn.init.constant_(m.bias, 0) 477 | nn.init.constant_(m.weight, 1.0) 478 | elif isinstance(m, nn.Conv2d): 479 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 480 | fan_out //= m.groups 481 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 482 | if m.bias is not None: 483 | m.bias.data.zero_() 484 | 485 | def forward(self, x): 486 | B = x.shape[0] 487 | outs = [] 488 | 489 | #----------------------------------# 490 | # block1 491 | #----------------------------------# 492 | x, H, W = self.patch_embed1.forward(x) 493 | for i, blk in enumerate(self.block1): 494 | x = blk.forward(x, H, W) 495 | x = self.norm1(x) 496 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 497 | outs.append(x) 498 | 499 | #----------------------------------# 500 | # block2 501 | #----------------------------------# 502 | x, H, W = self.patch_embed2.forward(x) 503 | for i, blk in enumerate(self.block2): 504 | x = blk.forward(x, H, W) 505 | x = self.norm2(x) 506 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 507 | outs.append(x) 508 | 509 | #----------------------------------# 510 | # block3 511 | #----------------------------------# 512 | x, H, W = self.patch_embed3.forward(x) 513 | for i, blk in enumerate(self.block3): 514 | x = blk.forward(x, H, W) 515 | x = self.norm3(x) 516 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 517 | outs.append(x) 518 | 519 | #----------------------------------# 520 | # block4 521 | #----------------------------------# 522 | x, H, W = self.patch_embed4.forward(x) 523 | for i, blk in enumerate(self.block4): 524 | x = blk.forward(x, H, W) 525 | x = self.norm4(x) 526 | x = x.reshape(B, H, W, 
-1).permute(0, 3, 1, 2).contiguous() 527 | outs.append(x) 528 | 529 | return outs 530 | 531 | class mit_b0(MixVisionTransformer): 532 | def __init__(self, pretrained = False): 533 | super(mit_b0, self).__init__( 534 | embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 535 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 536 | drop_rate=0.0, drop_path_rate=0.1) 537 | if pretrained: 538 | print("Load backbone weights") 539 | self.load_state_dict(torch.load("model_data/segformer_b0_backbone_weights.pth"), strict=False) 540 | 541 | class mit_b1(MixVisionTransformer): 542 | def __init__(self, pretrained = False): 543 | super(mit_b1, self).__init__( 544 | embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 545 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 546 | drop_rate=0.0, drop_path_rate=0.1) 547 | if pretrained: 548 | print("Load backbone weights") 549 | self.load_state_dict(torch.load("model_data/segformer_b1_backbone_weights.pth"), strict=False) 550 | 551 | class mit_b2(MixVisionTransformer): 552 | def __init__(self, pretrained = False): 553 | super(mit_b2, self).__init__( 554 | embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 555 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 556 | drop_rate=0.0, drop_path_rate=0.1) 557 | if pretrained: 558 | print("Load backbone weights") 559 | self.load_state_dict(torch.load("model_data/segformer_b2_backbone_weights.pth"), strict=False) 560 | 561 | class mit_b3(MixVisionTransformer): 562 | def __init__(self, pretrained = False): 563 | super(mit_b3, self).__init__( 564 | embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 565 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 566 | 
drop_rate=0.0, drop_path_rate=0.1) 567 | if pretrained: 568 | print("Load backbone weights") 569 | self.load_state_dict(torch.load("model_data/segformer_b3_backbone_weights.pth"), strict=False) 570 | 571 | class mit_b4(MixVisionTransformer): 572 | def __init__(self, pretrained = False): 573 | super(mit_b4, self).__init__( 574 | embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 575 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 576 | drop_rate=0.0, drop_path_rate=0.1) 577 | if pretrained: 578 | print("Load backbone weights") 579 | self.load_state_dict(torch.load("model_data/segformer_b4_backbone_weights.pth"), strict=False) 580 | 581 | class mit_b5(MixVisionTransformer): 582 | def __init__(self, pretrained = False): 583 | super(mit_b5, self).__init__( 584 | embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 585 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 586 | drop_rate=0.0, drop_path_rate=0.1) 587 | if pretrained: 588 | print("Load backbone weights") 589 | self.load_state_dict(torch.load("model_data/segformer_b5_backbone_weights.pth"), strict=False) 590 | -------------------------------------------------------------------------------- /nets/conv_.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class DilationConv(nn.Sequential): 4 | def __init__(self, in_channels, out_channels,k_size=3, dilation=1,padding=0): 5 | modules = [ 6 | nn.Conv1d(in_channels, out_channels, k_size, padding=padding, dilation=dilation, bias=False), 7 | ] 8 | super(DilationConv, self).__init__(*modules) 9 | 10 | class DSConv(nn.Module): 11 | def __init__(self,in_channels,out_channels,kernel_size,stride=1,dilation=1,padding=0,bias=True): 12 | super(DSConv,self).__init__() 13 | self.body = nn.Sequential( 14 | nn.Conv2d(in_channels = in_channels, 
out_channels = in_channels, 15 | kernel_size = (kernel_size, 1), 16 | stride = stride, 17 | padding = (padding,0), dilation = dilation, groups = in_channels, bias = bias), 18 | # 1x3 19 | nn.Conv2d(in_channels = in_channels, out_channels = in_channels, 20 | kernel_size = (1, kernel_size), 21 | stride = stride, 22 | padding = (0,padding) , dilation = dilation, groups = in_channels, bias = bias), 23 | # PointWise Conv 24 | nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0, bias = bias) 25 | ) 26 | 27 | def forward(self,x): 28 | return self.body(x) -------------------------------------------------------------------------------- /nets/mf_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from nets.conv_ import DSConv, DilationConv 4 | import torch.nn.functional as F 5 | 6 | 7 | class FeatureFusion(nn.Module): 8 | def __init__(self, low_channels, high_channels, out_channels): 9 | super(FeatureFusion, self).__init__() 10 | self.conv_low = nn.Sequential( 11 | DSConv(low_channels,out_channels,3,dilation = 2,padding = 2,bias = False), 12 | nn.BatchNorm2d(out_channels), 13 | nn.ReLU() 14 | ) 15 | self.conv_high = nn.Sequential( 16 | DSConv(high_channels,out_channels,1,bias = False), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU() 19 | ) 20 | 21 | self.cat_conv = nn.Sequential( 22 | DSConv(out_channels*2,out_channels,5,dilation = 2,padding = 4,bias = False), 23 | nn.BatchNorm2d(out_channels), 24 | nn.ReLU(), 25 | ) 26 | 27 | 28 | self.sig = nn.Sequential( 29 | DSConv(out_channels,out_channels,3,padding = 1,bias = False), 30 | nn.BatchNorm2d(out_channels), 31 | nn.Sigmoid() 32 | ) 33 | 34 | self.ff_out = DSConv(out_channels, out_channels, 1, bias = False) 35 | 36 | def forward(self, x_low, x_high): 37 | x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True) 38 | x_low = self.conv_low(x_low) 39 | x_high = self.conv_high(x_high) 40 
| 41 | x_cat = torch.cat([x_low,x_high],dim = 1) 42 | ga = self.cat_conv(x_cat) 43 | # ga = self.sig(x) 44 | 45 | x_low = torch.mul(ga, x_low) 46 | x_high = torch.mul((1 - ga), x_high) 47 | 48 | x = self.ff_out(x_low + x_high) 49 | 50 | return x 51 | 52 | class ASPP(nn.Module): 53 | def __init__(self, in_channel=768, depth=768): 54 | super(ASPP, self).__init__() 55 | self.pool2d = nn.Sequential( 56 | nn.AdaptiveAvgPool2d((1, 1)), 57 | nn.Conv2d(in_channel, depth, 1, 1), 58 | nn.BatchNorm2d(depth), 59 | nn.ReLU() 60 | ) 61 | 62 | self.block1 = nn.Sequential( 63 | nn.Conv2d(in_channel, depth, 1, 1), 64 | nn.BatchNorm2d(depth), 65 | nn.ReLU() 66 | ) 67 | 68 | self.block2 = nn.Sequential( 69 | nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6), 70 | nn.BatchNorm2d(depth), 71 | nn.ReLU() 72 | ) 73 | 74 | self.block3 = nn.Sequential( 75 | nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12), 76 | nn.BatchNorm2d(depth), 77 | nn.ReLU() 78 | ) 79 | 80 | self.block4 = nn.Sequential( 81 | nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18), 82 | nn.BatchNorm2d(depth), 83 | nn.ReLU() 84 | ) 85 | 86 | self.block5 = nn.Sequential( 87 | nn.Conv2d(depth * 5, depth, 1, 1), 88 | nn.BatchNorm2d(depth), 89 | nn.ReLU() 90 | ) 91 | 92 | self.dropout = nn.Dropout2d(0.1) 93 | 94 | 95 | def forward(self, x): 96 | size = x.shape[2:] 97 | 98 | image_features = self.pool2d(x) 99 | image_features = F.upsample(image_features, size=size, mode='bilinear') 100 | 101 | atrous_block1 = self.block1(x) 102 | atrous_block6 = self.block2(x) 103 | atrous_block12 = self.block3(x) 104 | atrous_block18 = self.block4(x) 105 | concat = torch.cat([image_features, atrous_block1, atrous_block6, atrous_block12, atrous_block18], dim=1) 106 | block5 = self.block5(concat) 107 | net = self.dropout(block5) 108 | return net 109 | 110 | class MF_Head(nn.Module): 111 | ''' 112 | Multiscale feature fusion 113 | ''' 114 | def __init__(self,in_channels=[32, 64, 160, 256],num_classes=2): 115 | 
super(MF_Head,self).__init__() 116 | 117 | self.ff0 = FeatureFusion(in_channels[3],in_channels[2],256) 118 | self.ff1 = FeatureFusion(256,in_channels[1],256) 119 | self.ff2 = FeatureFusion(256,in_channels[0],256) 120 | self.aspp = ASPP(in_channel=768, depth=768) 121 | self.seg = nn.Conv2d(768,num_classes,kernel_size = 1) 122 | 123 | def forward(self,inputs): 124 | c1, c2, c3, c4 = inputs 125 | ff0_out = self.ff0(c4,c3) 126 | ff1_out = self.ff1(ff0_out,c2) 127 | ff2_out = self.ff2(ff1_out,c1) 128 | 129 | ff0_out = F.interpolate(ff0_out,ff2_out.shape[2:],mode='bilinear', align_corners=True) 130 | ff1_out = F.interpolate(ff1_out, ff2_out.shape[2:], mode = 'bilinear', align_corners = True) 131 | x = torch.cat([ff0_out,ff1_out,ff2_out],dim = 1) 132 | 133 | aspp_out = self.aspp(x) 134 | out = self.seg(aspp_out) 135 | 136 | return out 137 | 138 | -------------------------------------------------------------------------------- /nets/segformer.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # --------------------------------------------------------------- 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from nets.mf_head import MF_Head 11 | from nets.backbone import mit_b0, mit_b1, mit_b2, mit_b3, mit_b4, mit_b5 12 | 13 | 14 | class MLP(nn.Module): 15 | """ 16 | Linear Embedding 17 | """ 18 | def __init__(self, input_dim=2048, embed_dim=768): 19 | super().__init__() 20 | self.proj = nn.Linear(input_dim, embed_dim) 21 | 22 | def forward(self, x): 23 | x = x.flatten(2).transpose(1, 2) 24 | x = self.proj(x) 25 | return x 26 | 27 | class ConvModule(nn.Module): 28 | def __init__(self, c1, c2, k=1, s=1, p=0, g=1, act=True): 29 | super(ConvModule, self).__init__() 30 | self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False) 31 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) 32 | self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 33 | 34 | def forward(self, x): 35 | return self.act(self.bn(self.conv(x))) 36 | 37 | def fuseforward(self, x): 38 | return self.act(self.conv(x)) 39 | 40 | class SegFormerHead(nn.Module): 41 | """ 42 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 43 | """ 44 | def __init__(self, num_classes=20, in_channels=[32, 64, 160, 256], embedding_dim=768, dropout_ratio=0.1): 45 | super(SegFormerHead, self).__init__() 46 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = in_channels 47 | 48 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 49 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 50 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 51 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 52 | 53 | self.linear_fuse = ConvModule( 54 | c1=embedding_dim*4, 55 | c2=embedding_dim, 56 | k=1, 57 | ) 58 | 59 | self.linear_pred = 
nn.Conv2d(embedding_dim, num_classes, kernel_size=1) 60 | self.dropout = nn.Dropout2d(dropout_ratio) 61 | 62 | def forward(self, inputs): 63 | c1, c2, c3, c4 = inputs 64 | 65 | ############## MLP decoder on C1-C4 ########### 66 | n, _, h, w = c4.shape 67 | 68 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) 69 | _c4 = F.interpolate(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) 70 | 71 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) 72 | _c3 = F.interpolate(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) 73 | 74 | _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) 75 | _c2 = F.interpolate(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) 76 | 77 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) 78 | 79 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 80 | 81 | x = self.dropout(_c) 82 | x = self.linear_pred(x) 83 | 84 | return x 85 | 86 | class SegFormer(nn.Module): 87 | def __init__(self, num_classes = 21, phi = 'b0', pretrained = False): 88 | super(SegFormer, self).__init__() 89 | self.in_channels = { 90 | 'b0': [32, 64, 160, 256], 'b1': [64, 128, 320, 512], 'b2': [64, 128, 320, 512], 91 | 'b3': [64, 128, 320, 512], 'b4': [64, 128, 320, 512], 'b5': [64, 128, 320, 512], 92 | }[phi] 93 | self.backbone = { 94 | 'b0': mit_b0, 'b1': mit_b1, 'b2': mit_b2, 95 | 'b3': mit_b3, 'b4': mit_b4, 'b5': mit_b5, 96 | }[phi](pretrained) 97 | self.embedding_dim = { 98 | 'b0': 256, 'b1': 256, 'b2': 768, 99 | 'b3': 768, 'b4': 768, 'b5': 768, 100 | }[phi] 101 | 102 | #MF decoder 103 | self.decode_head = MF_Head(in_channels = self.in_channels, num_classes = num_classes) 104 | 105 | #Segformer decoder 106 | # self.decode_head = SegFormerHead(num_classes, self.in_channels, self.embedding_dim) 107 | 108 | def forward(self, inputs): 109 | H, W = inputs.size(2), inputs.size(3) 110 | 111 | x = 
self.backbone.forward(inputs) 112 | x = self.decode_head.forward(x) 113 | 114 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) 115 | return x -------------------------------------------------------------------------------- /nets/segformer_training.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def CE_Loss(inputs, target, cls_weights, num_classes=21): 10 | n, c, h, w = inputs.size() 11 | nt, ht, wt = target.size() 12 | if h != ht and w != wt: 13 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 14 | 15 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 16 | temp_target = target.view(-1) 17 | 18 | CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target) 19 | return CE_loss 20 | 21 | def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2): 22 | n, c, h, w = inputs.size() 23 | nt, ht, wt = target.size() 24 | if h != ht and w != wt: 25 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 26 | 27 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 28 | temp_target = target.view(-1) 29 | 30 | logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target) 31 | pt = torch.exp(logpt) 32 | if alpha is not None: 33 | logpt *= alpha 34 | loss = -((1 - pt) ** gamma) * logpt 35 | loss = loss.mean() 36 | return loss 37 | 38 | def Dice_loss(inputs, target, beta=1, smooth = 1e-5): 39 | n, c, h, w = inputs.size() 40 | nt, ht, wt, ct = target.size() 41 | if h != ht and w != wt: 42 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 43 | 44 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 
3).contiguous().view(n, -1, c),-1) 45 | temp_target = target.view(n, -1, ct) 46 | 47 | #--------------------------------------------# 48 | # 计算dice loss 49 | #--------------------------------------------# 50 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1]) 51 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp 52 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp 53 | 54 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 55 | dice_loss = 1 - torch.mean(score) 56 | return dice_loss 57 | 58 | def weights_init(net, init_type='normal', init_gain=0.02): 59 | def init_func(m): 60 | classname = m.__class__.__name__ 61 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 62 | if init_type == 'normal': 63 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 64 | elif init_type == 'xavier': 65 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) 66 | elif init_type == 'kaiming': 67 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 68 | elif init_type == 'orthogonal': 69 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) 70 | else: 71 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 72 | elif classname.find('BatchNorm2d') != -1: 73 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 74 | torch.nn.init.constant_(m.bias.data, 0.0) 75 | print('initialize network with %s type' % init_type) 76 | net.apply(init_func) 77 | 78 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.1, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.3, step_num = 10): 79 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 80 | if iters <= warmup_total_iters: 81 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 82 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start 83 | elif iters >= total_iters - 
no_aug_iter: 84 | lr = min_lr 85 | else: 86 | lr = min_lr + 0.5 * (lr - min_lr) * ( 87 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) 88 | ) 89 | return lr 90 | 91 | def step_lr(lr, decay_rate, step_size, iters): 92 | if step_size < 1: 93 | raise ValueError("step_size must above 1.") 94 | n = iters // step_size 95 | out_lr = lr * decay_rate ** n 96 | return out_lr 97 | 98 | if lr_decay_type == "cos": 99 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 100 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 101 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 102 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 103 | else: 104 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 105 | step_size = total_iters / step_num 106 | func = partial(step_lr, lr, decay_rate, step_size) 107 | 108 | return func 109 | 110 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): 111 | lr = lr_scheduler_func(epoch) 112 | for param_group in optimizer.param_groups: 113 | param_group['lr'] = lr 114 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #----------------------------------------------------# 2 | # 将单张图片预测、摄像头检测和FPS测试功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #----------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from segformer import SegFormer_Segmentation 12 | 13 | if __name__ == "__main__": 14 | #-------------------------------------------------------------------------# 15 | # 如果想要修改对应种类的颜色,到generate函数里修改self.colors即可 16 | #-------------------------------------------------------------------------# 17 | segformer = SegFormer_Segmentation() 18 | 
#----------------------------------------------------------------------------------------------------------# 19 | # mode用于指定测试的模式: 20 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 21 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 22 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 23 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 24 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 25 | #----------------------------------------------------------------------------------------------------------# 26 | mode = "dir_predict" 27 | #-------------------------------------------------------------------------# 28 | # count 指定了是否进行目标的像素点计数(即面积)与比例计算 29 | # name_classes 区分的种类,和json_to_dataset里面的一样,用于打印种类和数量 30 | # 31 | # count、name_classes仅在mode='predict'时有效 32 | #-------------------------------------------------------------------------# 33 | count = False 34 | name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 35 | # name_classes = ["background","cat","dog"] 36 | #----------------------------------------------------------------------------------------------------------# 37 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 38 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 39 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 40 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 41 | # video_fps 用于保存的视频的fps 42 | # 43 | # video_path、video_save_path和video_fps仅在mode='video'时有效 44 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 45 | #----------------------------------------------------------------------------------------------------------# 46 | video_path = 0 47 | video_save_path = "" 48 | video_fps = 25.0 49 | #----------------------------------------------------------------------------------------------------------# 50 | # test_interval 
用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 51 | # fps_image_path 用于指定测试的fps图片 52 | # 53 | # test_interval和fps_image_path仅在mode='fps'有效 54 | #----------------------------------------------------------------------------------------------------------# 55 | test_interval = 100 56 | fps_image_path = "img/street.jpg" 57 | #-------------------------------------------------------------------------# 58 | # dir_origin_path 指定了用于检测的图片的文件夹路径 59 | # dir_save_path 指定了检测完图片的保存路径 60 | # 61 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 62 | #-------------------------------------------------------------------------# 63 | dir_origin_path = "img/" 64 | dir_save_path = "img_out/" 65 | #-------------------------------------------------------------------------# 66 | # simplify 使用Simplify onnx 67 | # onnx_save_path 指定了onnx的保存路径 68 | #-------------------------------------------------------------------------# 69 | simplify = True 70 | onnx_save_path = "model_data/models.onnx" 71 | 72 | if mode == "predict": 73 | ''' 74 | predict.py有几个注意点 75 | 1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。 76 | 具体流程可以参考get_miou_prediction.py,在get_miou_prediction.py即实现了遍历。 77 | 2、如果想要保存,利用r_image.save("img.jpg")即可保存。 78 | 3、如果想要原图和分割图不混合,可以把blend参数设置成False。 79 | 4、如果想根据mask获取对应的区域,可以参考detect_image函数中,利用预测结果绘图的部分,判断每一个像素点的种类,然后根据种类获取对应的部分。 80 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3)) 81 | for c in range(self.num_classes): 82 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8') 83 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8') 84 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8') 85 | ''' 86 | while True: 87 | img = input('Input image filename:') 88 | try: 89 | image = Image.open(img) 90 | except: 91 | print('Open Error! 
Try again!') 92 | continue 93 | else: 94 | r_image = segformer.detect_image(image, count=count, name_classes=name_classes) 95 | r_image.show() 96 | 97 | elif mode == "video": 98 | capture=cv2.VideoCapture(video_path) 99 | if video_save_path!="": 100 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 101 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 102 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 103 | 104 | ref, frame = capture.read() 105 | if not ref: 106 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 107 | 108 | fps = 0.0 109 | while(True): 110 | t1 = time.time() 111 | # 读取某一帧 112 | ref, frame = capture.read() 113 | if not ref: 114 | break 115 | # 格式转变,BGRtoRGB 116 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 117 | # 转变成Image 118 | frame = Image.fromarray(np.uint8(frame)) 119 | # 进行检测 120 | frame = np.array(segformer.detect_image(frame)) 121 | # RGBtoBGR满足opencv显示格式 122 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 123 | 124 | fps = ( fps + (1./(time.time()-t1)) ) / 2 125 | print("fps= %.2f"%(fps)) 126 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 127 | 128 | cv2.imshow("video",frame) 129 | c= cv2.waitKey(1) & 0xff 130 | if video_save_path!="": 131 | out.write(frame) 132 | 133 | if c==27: 134 | capture.release() 135 | break 136 | print("Video Detection Done!") 137 | capture.release() 138 | if video_save_path!="": 139 | print("Save processed video to the path :" + video_save_path) 140 | out.release() 141 | cv2.destroyAllWindows() 142 | 143 | elif mode == "fps": 144 | img = Image.open(fps_image_path) 145 | tact_time = segformer.get_FPS(img, test_interval) 146 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 147 | 148 | elif mode == "dir_predict": 149 | import os 150 | from tqdm import tqdm 151 | 152 | img_names = os.listdir(dir_origin_path) 153 | for img_name in tqdm(img_names): 154 | if 
img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 155 | image_path = os.path.join(dir_origin_path, img_name) 156 | image = Image.open(image_path) 157 | r_image = segformer.detect_image(image) 158 | if not os.path.exists(dir_save_path): 159 | os.makedirs(dir_save_path) 160 | r_image.save(os.path.join(dir_save_path, img_name)) 161 | 162 | elif mode == "export_onnx": 163 | segformer.convert_to_onnx(simplify, onnx_save_path) 164 | 165 | else: 166 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.") 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | -------------------------------------------------------------------------------- /segformer.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import copy 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from torch import nn 11 | 12 | from nets.segformer import SegFormer 13 | from utils.utils import cvtColor, preprocess_input, resize_image, show_config 14 | 15 | 16 | #-----------------------------------------------------------------------------------# 17 | # 使用自己训练好的模型预测需要修改3个参数 18 | # model_path、backbone和num_classes都需要修改! 
#   If a shape mismatch is reported, check that model_path, the backbone (phi)
#   and num_classes match the values that were used for training.
#-----------------------------------------------------------------------------------#
class SegFormer_Segmentation(object):
    """Inference wrapper around a trained SegFormer segmentation network.

    Every key of ``_defaults`` becomes an instance attribute and can be
    overridden through keyword arguments of the constructor.
    """

    _defaults = {
        #-------------------------------------------------------------------#
        #   model_path points to a weight file in the logs folder.
        #   Several weight files exist after training; pick one with a low
        #   validation loss. A low validation loss does not imply a high
        #   mIoU, it only means the weights generalise well on that set.
        #-------------------------------------------------------------------#
        "model_path"        : "logs/best_epoch_weights.pth",
        #----------------------------------------#
        #   Number of target classes + 1 (background).
        #----------------------------------------#
        "num_classes"       : 2,
        #----------------------------------------#
        #   Backbone variant: b0, b1, b2, b3, b4, b5
        #----------------------------------------#
        "phi"               : "b5",
        #----------------------------------------#
        #   Network input size as [height, width].
        #----------------------------------------#
        "input_shape"       : [256, 256],
        #-------------------------------------------------#
        #   mix_type controls how the result is visualised:
        #   0 -> blend the original image with the colour mask
        #   1 -> keep only the colour mask
        #   2 -> remove the background, keep original target pixels
        #-------------------------------------------------#
        "mix_type"          : 1,
        #-------------------------------#
        #   Whether to use CUDA; set to False without a GPU.
        #-------------------------------#
        "cuda"              : True,
    }

    #---------------------------------------------------#
    #   Initialise SegFormer
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults and overrides, build the palette and load the model."""
        # Deep copy so that mutating e.g. ``self.input_shape`` on one instance
        # cannot alter the class-level defaults shared by all instances.
        self.__dict__.update(copy.deepcopy(self._defaults))
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.colors = self._build_palette(self.num_classes)
        self.generate()

        # Report the configuration that is actually in effect, not just the
        # class defaults.
        show_config(**{**self._defaults, **kwargs})

    @staticmethod
    def _build_palette(num_classes):
        """Return one RGB tuple per class index."""
        if num_classes <= 21:
            return [(0, 0, 0), (255, 255, 255), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                    (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                    (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                    (128, 64, 12)]
        # More classes than fixed colours: spread the hues evenly.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        rgb_floats = [colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples]
        return [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in rgb_floats]

    #---------------------------------------------------#
    #   Build the network and load the weights
    #---------------------------------------------------#
    def generate(self, onnx=False):
        """Create the network, load ``model_path`` and switch to eval mode.

        When ``onnx`` is true the network is left unwrapped on the CPU;
        otherwise it is wrapped in ``DataParallel`` and moved to the GPU if
        ``self.cuda`` is set.
        """
        self.net = SegFormer(num_classes=self.num_classes, phi=self.phi, pretrained=False)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()

    def _prepare_input(self, image):
        """Letterbox ``image`` to the input size and return (batch tensor, nw, nh)."""
        # resize_image pads with gray bars so the aspect ratio is preserved.
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        # HWC -> CHW, then add the batch dimension.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
        images = torch.from_numpy(image_data)
        if self.cuda:
            images = images.cuda()
        return images, nw, nh

    def _crop_padding(self, pr, nw, nh):
        """Cut the gray letterbox bars off the first two axes of ``pr``."""
        top = (self.input_shape[0] - nh) // 2
        left = (self.input_shape[1] - nw) // 2
        return pr[int(top): int(top + nh), int(left): int(left + nw)]

    def _predict_labels(self, image):
        """Return the per-pixel class index map at the size of ``image`` (RGB)."""
        orininal_h, orininal_w = np.array(image).shape[:2]
        with torch.no_grad():
            images, nw, nh = self._prepare_input(image)
            pr = self.net(images)[0]
            # Class probabilities, channels last.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
        pr = self._crop_padding(pr, nw, nh)
        # Resize the probabilities (not the labels) back to the source size,
        # then take the most likely class per pixel.
        pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
        return pr.argmax(axis=-1)

    def _print_class_counts(self, pr, name_classes):
        """Print a table with the pixel count and ratio of every class present."""
        classes_nums = np.zeros([self.num_classes])
        total_points_num = pr.shape[0] * pr.shape[1]
        print('-' * 63)
        print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio"))
        print('-' * 63)
        for i in range(self.num_classes):
            num = np.sum(pr == i)
            ratio = num / total_points_num * 100
            if num > 0:
                # Fall back to the class index when no usable name is given.
                if name_classes is not None and i < len(name_classes):
                    label = name_classes[i]
                else:
                    label = i
                print("|%25s | %15s | %14.2f%%|"%(str(label), str(num), ratio))
                print('-' * 63)
            classes_nums[i] = num
        print("classes_nums:", classes_nums)

    #---------------------------------------------------#
    #   Segment one image
    #---------------------------------------------------#
    def detect_image(self, image, count=False, name_classes=None):
        """Segment ``image`` and return a PIL image rendered per ``mix_type``.

        ``count`` additionally prints per-class pixel statistics;
        ``name_classes`` supplies the row labels (class indices are used when
        it is omitted). For a ``mix_type`` other than 0, 1 or 2 the RGB
        converted input is returned unchanged.
        """
        # Only RGB prediction is supported; every other mode is converted.
        image = cvtColor(image)
        # Keep a copy of the input for blending / masking below.
        old_img = copy.deepcopy(image)

        pr = self._predict_labels(image)

        if count:
            self._print_class_counts(pr, name_classes)

        if self.mix_type in (0, 1):
            # Look up the colour of every pixel's class: (H, W) -> (H, W, 3).
            seg_img = np.array(self.colors, np.uint8)[pr]
            image = Image.fromarray(np.uint8(seg_img))
            if self.mix_type == 0:
                image = Image.blend(old_img, image, 0.7)
        elif self.mix_type == 2:
            # Zero out background pixels (class 0) of the original image.
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            image = Image.fromarray(np.uint8(seg_img))

        return image

    def get_FPS(self, image, test_interval):
        """Return the mean seconds per inference over ``test_interval`` runs.

        Raises:
            ValueError: if ``test_interval`` is not positive.
        """
        if test_interval <= 0:
            raise ValueError("test_interval must be a positive integer")

        image = cvtColor(image)

        def run_once(images, nw, nh):
            pr = self.net(images)[0]
            # .cpu() forces the GPU work to finish, so the timing is honest.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
            return self._crop_padding(pr, nw, nh)

        with torch.no_grad():
            images, nw, nh = self._prepare_input(image)
            # One untimed warm-up pass.
            run_once(images, nw, nh)

            t1 = time.perf_counter()
            for _ in range(test_interval):
                run_once(images, nw, nh)
            t2 = time.perf_counter()

        tact_time = (t2 - t1) / test_interval
        return tact_time

    def convert_to_onnx(self, simplify, model_path):
        """Export the network to ONNX at ``model_path``; optionally simplify it.

        NOTE: this reloads ``self.net`` through ``generate(onnx=True)``, i.e.
        unwrapped and not moved to the GPU.
        """
        import onnx
        self.generate(onnx=True)

        im                  = torch.zeros(1, 3, *self.input_shape).to('cpu')  # dummy input, BCHW
        input_layer_names   = ["images"]
        output_layer_names  = ["output"]

        # Export the model
        print(f'Starting export with onnx {onnx.__version__}.')
        torch.onnx.export(self.net,
                          im,
                          f               = model_path,
                          verbose         = False,
                          opset_version   = 12,
                          training        = torch.onnx.TrainingMode.EVAL,
                          do_constant_folding = True,
                          input_names     = input_layer_names,
                          output_names    = output_layer_names,
                          dynamic_axes    = None)

        # Checks
        model_onnx = onnx.load(model_path)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model

        # Simplify onnx
        if simplify:
            import onnxsim
            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
            model_onnx, check = onnxsim.simplify(
                model_onnx,
                dynamic_input_shape=False,
                input_shapes=None)
            assert check, 'assert check failed'
            onnx.save(model_onnx, model_path)

        print('Onnx model save as {}'.format(model_path))

    def get_miou_png(self, image):
        """Return the predicted class-index map of ``image`` as a PIL image."""
        image = cvtColor(image)
        pr = self._predict_labels(image)
        image = Image.fromarray(np.uint8(pr))
        return image

# --------------------------------------------------------------------------------
# /summary.py:
# --------------------------------------------------------------------------------
#--------------------------------------------#
#   Prints the network structure, FLOPs and parameter count.
#--------------------------------------------#
if __name__ == "__main__":
    # Imported here because they are only needed when run as a script.
    import torch
    from thop import clever_format, profile
    from torchsummary import summary

    from nets.segformer import SegFormer

    input_shape     = [256, 256]
    num_classes     = 2
    phi             = 'b5'

    device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model   = SegFormer(num_classes = num_classes, phi = phi, pretrained=False).to(device)
    summary(model, (3, input_shape[0], input_shape[1]))

    dummy_input     = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    flops, params   = profile(model.to(device), (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops * 2 because profile does not count a convolution
    #   as two operations. Some papers count the multiply and
    #   the add separately (multiply by 2), others count only
    #   the multiplications (do not multiply by 2).
    #   This code multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops           = flops * 2
    flops, params   = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
import os
import datetime

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim as optim
from torch.utils.data import DataLoader

from nets.segformer import SegFormer
from nets.segformer_training import (get_lr_scheduler, set_optimizer_lr,
                                     weights_init)
from utils.callbacks import LossHistory, EvalCallback
from utils.dataloader import SegmentationDataset, seg_dataset_collate
from utils.utils import download_weights, show_config
from utils.utils_fit import fit_one_epoch


def _fit_learning_rates(batch_size, init_lr, min_lr, optimizer_type, nbs=16):
    """Scale the learning rates with the batch size and clamp them.

    ``nbs`` is the nominal batch size the configured rates refer to.
    Returns ``(init_lr_fit, min_lr_fit)``.
    """
    adaptive        = optimizer_type in ['adam', 'adamw']
    lr_limit_max    = 1e-4 if adaptive else 5e-2
    lr_limit_min    = 3e-5 if adaptive else 5e-4
    init_lr_fit     = min(max(batch_size / nbs * init_lr, lr_limit_min), lr_limit_max)
    min_lr_fit      = min(max(batch_size / nbs * min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
    return init_lr_fit, min_lr_fit


def _build_optimizer(params, optimizer_type, lr, momentum, weight_decay):
    """Create only the optimizer selected by ``optimizer_type``."""
    if optimizer_type == 'adam':
        return optim.Adam(params, lr, betas = (momentum, 0.999), weight_decay = weight_decay)
    if optimizer_type == 'adamw':
        return optim.AdamW(params, lr, betas = (momentum, 0.999), weight_decay = weight_decay)
    if optimizer_type == 'sgd':
        return optim.SGD(params, lr, momentum = momentum, nesterov=True, weight_decay = weight_decay)
    raise ValueError("optimizer_type must be 'adam', 'adamw' or 'sgd', got %r" % (optimizer_type,))


def _build_loaders(train_dataset, val_dataset, batch_size, num_workers, shuffle, train_sampler, val_sampler):
    """Return the (train, val) DataLoaders sharing the same loader options."""
    common = dict(shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
                  drop_last = True, collate_fn = seg_dataset_collate)
    gen     = DataLoader(train_dataset, sampler=train_sampler, **common)
    gen_val = DataLoader(val_dataset, sampler=val_sampler, **common)
    return gen, gen_val


if __name__ == "__main__":
    #---------------------------------#
    #   Cuda    whether to use CUDA; set to False without a GPU
    #---------------------------------#
    Cuda            = True
    #---------------------------------------------------------------------#
    #   distributed     single-machine multi-GPU distributed training (DDP).
    #                   The terminal commands only support Ubuntu, where
    #                   CUDA_VISIBLE_DEVICES selects the GPUs.
    #                   On Windows all GPUs are used in DP mode; DDP is not
    #                   supported there.
    #---------------------------------------------------------------------#
    distributed     = False
    #---------------------------------------------------------------------#
    #   sync_bn     use synchronised batch norm (multi-GPU DDP only)
    #---------------------------------------------------------------------#
    sync_bn         = False
    #---------------------------------------------------------------------#
    #   fp16        mixed precision training; roughly halves GPU memory,
    #               requires pytorch 1.7.1 or newer
    #---------------------------------------------------------------------#
    fp16            = False
    #-----------------------------------------------------#
    #   num_classes     number of target classes + 1, e.g. 2 + 1
    #-----------------------------------------------------#
    num_classes     = 2
    #-------------------------------------------------------------------#
    #   Backbone variant: b0, b1, b2, b3, b4, b5
    #-------------------------------------------------------------------#
    phi             = "b5"
    #-------------------------------------------------------------------#
    #   pretrained      use the pretrained weights of the backbone
    #-------------------------------------------------------------------#
    pretrained      = True
    #-------------------------------------------------------------------#
    #   model_path = '' means no whole-model weights are loaded
    #-------------------------------------------------------------------#
    model_path      = './model_data/segformer_b5_backbone_weights.pth'
    #------------------------------#
    #   Input image size
    #------------------------------#
    input_shape     = [256, 256]
    #------------------------------------------------------------------#
    #   Freeze stage: the backbone is frozen, so the feature extractor
    #   does not change; memory use is low and only the rest is tuned.
    #   Init_Epoch          epoch to start from; may exceed Freeze_Epoch,
    #                       e.g. Init_Epoch = 60, Freeze_Epoch = 50,
    #                       UnFreeze_Epoch = 100 skips the freeze stage and
    #                       resumes at epoch 60 with the matching learning
    #                       rate (used when resuming training).
    #   Freeze_Epoch        number of frozen epochs
    #                       (ignored when Freeze_Train=False)
    #   Freeze_batch_size   batch size while frozen
    #                       (ignored when Freeze_Train=False)
    #------------------------------------------------------------------#
    Init_Epoch          = 0
    Freeze_Epoch        = 0
    Freeze_batch_size   = 30
    #------------------------------------------------------------------#
    #   Unfreeze stage: the backbone is trainable, all parameters
    #   change and memory use is higher.
    #   UnFreeze_Epoch          total number of training epochs
    #   Unfreeze_batch_size     batch size after unfreezing
    #------------------------------------------------------------------#
    UnFreeze_Epoch      = 100
    Unfreeze_batch_size = 16
    #------------------------------------------------------------------#
    #   Freeze_Train    train with a frozen backbone first, then unfreeze
    #------------------------------------------------------------------#
    Freeze_Train        = False

    #------------------------------------------------------------------#
    #   Init_lr     maximum learning rate; 1e-4 is suggested for both
    #               Adam and AdamW. SGD is not recommended for
    #               Transformer models.
    #   Min_lr      minimum learning rate, 0.01 x the maximum by default
    #------------------------------------------------------------------#
    Init_lr             = 1e-4
    Min_lr              = Init_lr * 0.01
    #------------------------------------------------------------------#
    #   optimizer_type  adam, adamw or sgd
    #   momentum        momentum parameter used inside the optimizer
    #   weight_decay    weight decay against over-fitting; adam handles
    #                   weight_decay incorrectly, set it to 0 with adam
    #------------------------------------------------------------------#
    optimizer_type      = "adamw"
    momentum            = 0.9
    weight_decay        = 1e-2
    #------------------------------------------------------------------#
    #   lr_decay_type   learning-rate decay schedule: 'step' or 'cos'
    #------------------------------------------------------------------#
    lr_decay_type       = 'cos'
    #------------------------------------------------------------------#
    #   save_period     save the weights every save_period epochs
    #------------------------------------------------------------------#
    save_period         = 100
    #------------------------------------------------------------------#
    #   save_dir        folder for weights and log files
    #------------------------------------------------------------------#
    save_dir            = 'logs'
    #------------------------------------------------------------------#
    #   eval_flag       evaluate on the validation set during training
    #   eval_period     evaluate every eval_period epochs; evaluation is
    #                   slow, so frequent evaluation slows training a lot.
    #   The metric obtained here can differ from the one reported by
    #   get_miou.py: it is computed on the validation set, and the
    #   evaluation settings here are conservative to keep it fast.
    #------------------------------------------------------------------#
    eval_flag           = False
    eval_period         = 10

    #------------------------------------------------------------------#
    #   VOCdevkit_path  dataset root
    #------------------------------------------------------------------#
    VOCdevkit_path  = './VOCdevkit'
    #------------------------------------------------------------------#
    #   dice_loss suggestion:
    #   few classes (a handful)                     -> True
    #   many classes (10+), batch_size above 10     -> True
    #   many classes (10+), batch_size below 10     -> False
    #------------------------------------------------------------------#
    dice_loss       = True
    #------------------------------------------------------------------#
    #   focal_loss      counter positive/negative sample imbalance
    #------------------------------------------------------------------#
    focal_loss      = True
    #------------------------------------------------------------------#
    #   Per-class loss weights, balanced by default. Must be a numpy
    #   array of length num_classes, e.g. for num_classes = 3:
    #       cls_weights = np.array([1, 2, 3], np.float32)
    #------------------------------------------------------------------#
    cls_weights     = np.ones([num_classes], np.float32)
    #------------------------------------------------------------------#
    #   num_workers     data-loading workers; more workers read faster
    #                   but use more memory. Enable it when IO is the
    #                   bottleneck, i.e. the GPU is much faster than
    #                   image reading.
    #------------------------------------------------------------------#
    num_workers     = 4

    #------------------------------------------------------#
    #   Select the devices
    #------------------------------------------------------#
    ngpus_per_node  = torch.cuda.device_count()
    if distributed:
        dist.init_process_group(backend="nccl")
        local_rank  = int(os.environ["LOCAL_RANK"])
        rank        = int(os.environ["RANK"])
        device      = torch.device("cuda", local_rank)
        if local_rank == 0:
            print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
            print("Gpu Device Count : ", ngpus_per_node)
    else:
        device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank      = 0

    #----------------------------------------------------#
    #   Download the pretrained backbone weights (rank 0 only under
    #   DDP; the other ranks wait at the barrier)
    #----------------------------------------------------#
    if pretrained:
        if distributed:
            if local_rank == 0:
                download_weights(phi)
            dist.barrier()
        else:
            download_weights(phi)

    model = SegFormer(num_classes=num_classes, phi=phi, pretrained=pretrained)
    if not pretrained:
        weights_init(model)
    if model_path != '':
        if local_rank == 0:
            print('Load weights {}.'.format(model_path))

        #------------------------------------------------------#
        #   Load only the entries whose key and shape match the model
        #------------------------------------------------------#
        model_dict      = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location = device)
        load_key, no_load_key, temp_dict = [], [], {}
        for k, v in pretrained_dict.items():
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                load_key.append(k)
            else:
                no_load_key.append(k)
        model_dict.update(temp_dict)
        model.load_state_dict(model_dict)
        #------------------------------------------------------#
        #   Report the keys that did not match
        #------------------------------------------------------#
        if local_rank == 0:
            print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
            print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
            print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m")

    #----------------------#
    #   Loss logging (rank 0 only)
    #----------------------#
    if local_rank == 0:
        time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
        log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
        loss_history    = LossHistory(log_dir, model, input_shape=input_shape)
    else:
        loss_history    = None

    #------------------------------------------------------------------#
    #   torch 1.2 has no amp; use torch 1.7.1+ for fp16
    #------------------------------------------------------------------#
    if fp16:
        from torch.cuda.amp import GradScaler as GradScaler
        scaler = GradScaler()
    else:
        scaler = None

    model_train     = model.train()
    #----------------------------#
    #   Multi-GPU synchronised BN
    #----------------------------#
    if sync_bn and ngpus_per_node > 1 and distributed:
        model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
    elif sync_bn:
        print("Sync_bn is not support in one gpu or not distributed.")

    if Cuda:
        if distributed:
            #----------------------------#
            #   Multi-GPU parallel run
            #----------------------------#
            model_train = model_train.cuda(local_rank)
            model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True)
        else:
            model_train = torch.nn.DataParallel(model)
            cudnn.benchmark = True
            model_train = model_train.cuda()

    #---------------------------#
    #   Read the dataset split files
    #---------------------------#
    with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/train.txt"),"r") as f:
        train_lines = f.readlines()
    with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),"r") as f:
        val_lines = f.readlines()
    num_train   = len(train_lines)
    num_val     = len(val_lines)

    if local_rank == 0:
        show_config(
            num_classes = num_classes, phi = phi, model_path = model_path, input_shape = input_shape, \
            Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
            Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
            save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
        )
        #---------------------------------------------------------#
        #   Total epochs = number of passes over the whole dataset.
        #   Total steps  = number of gradient-descent updates.
        #   This only suggests a minimum number of epochs; only the
        #   unfrozen stage is taken into account.
        #----------------------------------------------------------#
        wanted_step = 1.5e4 if optimizer_type == "adamw" else 0.5e4
        total_step  = num_train // Unfreeze_batch_size * UnFreeze_Epoch
        if total_step <= wanted_step:
            if num_train // Unfreeze_batch_size == 0:
                raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
            wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
            print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
            print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
            print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch))

    #------------------------------------------------------#
    #   Backbone features are generic, so training with a frozen
    #   backbone first is faster and protects the weights early on.
    #   Reduce the batch size on OOM / insufficient GPU memory.
    #------------------------------------------------------#
    UnFreeze_flag = False
    if Freeze_Train:
        for param in model.backbone.parameters():
            param.requires_grad = False

    # Without freeze training, start directly with Unfreeze_batch_size.
    batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size

    # Adapt the learning rate to the current batch size.
    Init_lr_fit, Min_lr_fit = _fit_learning_rates(batch_size, Init_lr, Min_lr, optimizer_type)

    optimizer = _build_optimizer(model.parameters(), optimizer_type, Init_lr_fit, momentum, weight_decay)

    # Learning-rate schedule.
    lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

    # Steps per epoch, based on the global batch size.
    epoch_step      = num_train // batch_size
    epoch_step_val  = num_val // batch_size

    if epoch_step == 0 or epoch_step_val == 0:
        raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

    train_dataset   = SegmentationDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
    val_dataset     = SegmentationDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path)

    if distributed:
        train_sampler   = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,)
        val_sampler     = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,)
        # Each process loads its share of the global batch.
        batch_size      = batch_size // ngpus_per_node
        shuffle         = False
    else:
        train_sampler   = None
        val_sampler     = None
        shuffle         = True

    gen, gen_val = _build_loaders(train_dataset, val_dataset, batch_size, num_workers, shuffle,
                                  train_sampler, val_sampler)

    #----------------------#
    #   Evaluation curve logging (rank 0 only)
    #----------------------#
    if local_rank == 0:
        eval_callback   = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \
                                        eval_flag=eval_flag, period=eval_period)
    else:
        eval_callback   = None

    #---------------------------------------#
    #   Training loop
    #---------------------------------------#
    for epoch in range(Init_Epoch, UnFreeze_Epoch):
        # Leave the freeze stage once: unfreeze the backbone and rebuild
        # everything that depends on the batch size.
        if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
            batch_size = Unfreeze_batch_size

            Init_lr_fit, Min_lr_fit = _fit_learning_rates(batch_size, Init_lr, Min_lr, optimizer_type)
            lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

            for param in model.backbone.parameters():
                param.requires_grad = True

            epoch_step      = num_train // batch_size
            epoch_step_val  = num_val // batch_size

            if epoch_step == 0 or epoch_step_val == 0:
                raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

            # Same per-process split as at start-up; without it every DDP
            # process would load the full global batch after unfreezing.
            if distributed:
                batch_size = batch_size // ngpus_per_node

            gen, gen_val = _build_loaders(train_dataset, val_dataset, batch_size, num_workers, shuffle,
                                          train_sampler, val_sampler)

            UnFreeze_flag = True

        if distributed:
            # Reshuffle differently every epoch.
            train_sampler.set_epoch(epoch)

        set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

        fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, \
            dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)

        if distributed:
            dist.barrier()

    if local_rank == 0:
        loss_history.writer.close()

# --------------------------------------------------------------------------------
# /utils/__init__.py:
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/utils/__pycache__/callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/utils/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiany-zhang/MF-SegFormer/009901d87c3911b51d3c3b6b9b7f976c6f8f3d72/utils/__pycache__/utils_fit.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_metrics.cpython-38.pyc: 
class LossHistory():
    """Record per-epoch train/validation losses.

    Every call to :meth:`append_loss` appends the values to
    ``epoch_loss.txt`` / ``epoch_val_loss.txt`` inside ``log_dir``, forwards
    them to a TensorBoard ``SummaryWriter`` and redraws ``epoch_loss.png``.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Directory receiving the text logs, plot and event files.
            model: Network whose graph is (best-effort) exported to TensorBoard.
            input_shape: Indexable ``(height, width)`` used for the dummy input.
        """
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: re-using an existing log directory must not abort the
        # run; append_loss() already tolerates that situation.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is optional; a failure here must not stop training.
            # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Store one epoch's losses and refresh every log artefact.

        Args:
            epoch: Step index passed to TensorBoard.
            loss: Training loss of the epoch.
            val_loss: Validation loss of the epoch.
        """
        # The directory may have been removed while training was running.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Draw raw (and, when possible, smoothed) loss curves to ``epoch_loss.png``."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')
        try:
            # Smoothing window: 5 points for short histories, 15 otherwise.
            num = 5 if len(self.losses) < 25 else 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green',
                     linestyle='--', linewidth=2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513',
                     linestyle='--', linewidth=2, label='smooth val loss')
        except Exception:
            # The smoothed curves are a nicety; skip them when filtering
            # fails (e.g. the history is still too short for the window).
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
| # 也可以直接resize进行识别 116 | #---------------------------------------------------------# 117 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) 118 | #---------------------------------------------------------# 119 | # 添加上batch_size维度 120 | #---------------------------------------------------------# 121 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) 122 | 123 | with torch.no_grad(): 124 | images = torch.from_numpy(image_data) 125 | if self.cuda: 126 | images = images.cuda() 127 | 128 | #---------------------------------------------------# 129 | # 图片传入网络进行预测 130 | #---------------------------------------------------# 131 | pr = self.net(images)[0] 132 | #---------------------------------------------------# 133 | # 取出每一个像素点的种类 134 | #---------------------------------------------------# 135 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy() 136 | #--------------------------------------# 137 | # 将灰条部分截取掉 138 | #--------------------------------------# 139 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ 140 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] 141 | #---------------------------------------------------# 142 | # 进行图片的resize 143 | #---------------------------------------------------# 144 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR) 145 | #---------------------------------------------------# 146 | # 取出每一个像素点的种类 147 | #---------------------------------------------------# 148 | pr = pr.argmax(axis=-1) 149 | 150 | image = Image.fromarray(np.uint8(pr)) 151 | return image 152 | 153 | def on_epoch_end(self, epoch, model_eval): 154 | if epoch % self.period == 0 and self.eval_flag: 155 | self.net = model_eval 156 | gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/") 157 | pred_dir = os.path.join(self.miou_out_path, 'detection-results') 
class SegmentationDataset(Dataset):
    """Dataset yielding (image, class-index mask, one-hot mask) triples.

    Images are read from ``VOC2007/JPEGImages`` and masks from
    ``VOC2007/SegmentationClass`` below ``dataset_path``; both are looked up
    as ``<name>.png`` where ``name`` is the first token of an annotation line.
    """

    def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
        super(SegmentationDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.length = len(annotation_lines)
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        name = self.annotation_lines[index].split()[0]

        # Read the image and its mask from disk.
        image_dir = os.path.join(self.dataset_path, "VOC2007/JPEGImages")
        label_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass")
        jpg = Image.open(os.path.join(image_dir, name + ".png"))
        png = Image.open(os.path.join(label_dir, name + ".png"))

        # Augment only when the dataset was created with train=True.
        jpg, png = self.get_random_data(jpg, png, self.input_shape, random=self.train)
        jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2, 0, 1])
        png = np.array(png)

        # Every value outside the valid class range is mapped to the extra
        # index `num_classes`, so the one-hot encoding below has
        # num_classes + 1 channels.
        png[png >= self.num_classes] = self.num_classes
        depth = self.num_classes + 1
        seg_labels = np.eye(depth)[png.reshape([-1])]
        seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), depth))

        return jpg, png, seg_labels

    def rand(self, a=0, b=1):
        """Return one uniform random float in [a, b)."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
        """Letterbox (random=False) or randomly augment (random=True) a sample.

        The mask is taken from channel 0 of ``label`` and floor-divided by
        255 before any geometric transform is applied.
        """
        image = cvtColor(image)
        mask = np.array(label)[:, :, 0] // 255.0
        label = Image.fromarray(np.array(mask))

        # Source size and target size.
        iw, ih = image.size
        h, w = input_shape

        if not random:
            # Aspect-preserving resize centred on a grey canvas.
            ratio = min(w / iw, h / ih)
            nw = int(iw * ratio)
            nh = int(ih * ratio)
            offset = ((w - nw) // 2, (h - nh) // 2)

            new_image = Image.new('RGB', [w, h], (128, 128, 128))
            new_image.paste(image.resize((nw, nh), Image.BICUBIC), offset)

            new_label = Image.new('L', [w, h], (0))
            new_label.paste(label.resize((nw, nh), Image.NEAREST), offset)
            return new_image, new_label

        # Random rescale with aspect-ratio jitter. The two jitter draws are
        # taken in the same order as the factors appear in the ratio.
        jitter_num = self.rand(1 - jitter, 1 + jitter)
        jitter_den = self.rand(1 - jitter, 1 + jitter)
        new_ar = iw / ih * jitter_num / jitter_den
        scale = self.rand(0.5, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)
        label = label.resize((nw, nh), Image.NEAREST)

        # Random horizontal flip.
        if self.rand() < .5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # Paste at a random offset on a grey (image) / zero (mask) canvas.
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        image_canvas = Image.new('RGB', (w, h), (128, 128, 128))
        label_canvas = Image.new('L', (w, h), (0))
        image_canvas.paste(image, (dx, dy))
        label_canvas.paste(label, (dx, dy))
        image = image_canvas
        label = label_canvas

        image_data = np.array(image, np.uint8)

        # Gaussian blur with probability 0.25.
        if self.rand() < 0.25:
            image_data = cv2.GaussianBlur(image_data, (5, 5), 0)

        # Rotation by an integer angle in [-10, 10] with probability 0.25.
        if self.rand() < 0.25:
            center = (w // 2, h // 2)
            rotation = np.random.randint(-10, 11)
            M = cv2.getRotationMatrix2D(center, -rotation, scale=1)
            image_data = cv2.warpAffine(image_data, M, (w, h), flags=cv2.INTER_CUBIC, borderValue=(128, 128, 128))
            label = cv2.warpAffine(np.array(label, np.uint8), M, (w, h), flags=cv2.INTER_NEAREST, borderValue=(0))

        # Colour jitter in HSV space: one random gain per channel, applied
        # through 256-entry lookup tables.
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        hue_ch, sat_ch, val_ch = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype

        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        jittered = cv2.merge((cv2.LUT(hue_ch, lut_hue), cv2.LUT(sat_ch, lut_sat), cv2.LUT(val_ch, lut_val)))
        image_data = cv2.cvtColor(jittered, cv2.COLOR_HSV2RGB)

        return image_data, label


def seg_dataset_collate(batch):
    """Stack a list of (image, mask, one-hot) samples into three tensors.

    Returns float32 images, int64 masks and float32 one-hot labels.
    """
    images = [img for img, _, _ in batch]
    pngs = [png for _, png, _ in batch]
    seg_labels = [labels for _, _, labels in batch]

    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    pngs = torch.from_numpy(np.array(pngs)).long()
    seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
    return images, pngs, seg_labels
def cvtColor(image):
    """Return ``image`` unchanged when it already has three channels, else its RGB conversion.

    The code only supports RGB prediction, so every other image type is
    converted through ``image.convert('RGB')``.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')


def resize_image(image, size):
    """Letterbox ``image`` into ``size`` (w, h) on a grey canvas.

    Returns the padded image together with the resized content width and
    height, which callers need to crop the padding away again.
    """
    src_w, src_h = image.size
    w, h = size

    scale = min(w / src_w, h / src_h)
    nw = int(src_w * scale)
    nh = int(src_h * scale)

    resized = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128, 128, 128))
    new_image.paste(resized, ((w - nw) // 2, (h - nh) // 2))

    return new_image, nw, nh


def get_lr(optimizer):
    """Return the learning rate of the first parameter group (None if there is none)."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']


def preprocess_input(image):
    """Normalise ``image`` in place with fixed per-channel mean/std and return it."""
    mean = np.array([123.675, 116.28, 103.53], np.float32)
    std = np.array([58.395, 57.12, 57.375], np.float32)
    image -= mean
    image /= std
    return image


def show_config(**kwargs):
    """Print the keyword arguments as a two-column table."""
    rule = '-' * 70
    print('Configurations:')
    print(rule)
    print(f"|{'keys':>25} | {'values':>40}|")
    print(rule)
    for key, value in kwargs.items():
        print(f"|{str(key):>25} | {str(value):>40}|")
    print(rule)


def download_weights(phi, model_dir="./model_data"):
    """Download the backbone weights for variant ``phi`` ('b0'..'b5') into ``model_dir``."""
    import os
    from torch.hub import load_state_dict_from_url

    download_urls = {
        'b0' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b0_backbone_weights.pth",
        'b1' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b1_backbone_weights.pth",
        'b2' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b2_backbone_weights.pth",
        'b3' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b3_backbone_weights.pth",
        'b4' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b4_backbone_weights.pth",
        'b5' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b5_backbone_weights.pth",
    }
    # An unknown variant raises KeyError before anything touches the disk.
    weights_url = download_urls[phi]

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    load_state_dict_from_url(weights_url, model_dir)
labels.cuda(local_rank) 33 | weights = weights.cuda(local_rank) 34 | 35 | optimizer.zero_grad() 36 | if not fp16: 37 | #----------------------# 38 | # 前向传播 39 | #----------------------# 40 | outputs = model_train(imgs) 41 | #----------------------# 42 | # 计算损失 43 | #----------------------# 44 | if focal_loss: 45 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 46 | else: 47 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 48 | 49 | if dice_loss: 50 | main_dice = Dice_loss(outputs, labels) 51 | loss = loss + main_dice 52 | 53 | with torch.no_grad(): 54 | #-------------------------------# 55 | # 计算f_score 56 | #-------------------------------# 57 | _f_score = f_score(outputs, labels) 58 | 59 | loss.backward() 60 | optimizer.step() 61 | else: 62 | from torch.cuda.amp import autocast 63 | with autocast(): 64 | #----------------------# 65 | # 前向传播 66 | #----------------------# 67 | outputs = model_train(imgs) 68 | #----------------------# 69 | # 计算损失 70 | #----------------------# 71 | if focal_loss: 72 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 73 | else: 74 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 75 | 76 | if dice_loss: 77 | main_dice = Dice_loss(outputs, labels) 78 | loss = loss + main_dice 79 | 80 | with torch.no_grad(): 81 | #-------------------------------# 82 | # 计算f_score 83 | #-------------------------------# 84 | _f_score = f_score(outputs, labels) 85 | 86 | #----------------------# 87 | # 反向传播 88 | #----------------------# 89 | scaler.scale(loss).backward() 90 | scaler.step(optimizer) 91 | scaler.update() 92 | 93 | total_loss += loss.item() 94 | total_f_score += _f_score.item() 95 | 96 | if local_rank == 0: 97 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), 98 | 'f_score' : total_f_score / (iteration + 1), 99 | 'lr' : get_lr(optimizer)}) 100 | pbar.update(1) 101 | 102 | if local_rank == 0: 103 | pbar.close() 104 | print('Finish Train') 105 | 
print('Start Validation') 106 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 107 | 108 | model_train.eval() 109 | for iteration, batch in enumerate(gen_val): 110 | if iteration >= epoch_step_val: 111 | break 112 | imgs, pngs, labels = batch 113 | with torch.no_grad(): 114 | weights = torch.from_numpy(cls_weights) 115 | if cuda: 116 | imgs = imgs.cuda(local_rank) 117 | pngs = pngs.cuda(local_rank) 118 | labels = labels.cuda(local_rank) 119 | weights = weights.cuda(local_rank) 120 | 121 | #----------------------# 122 | # 前向传播 123 | #----------------------# 124 | outputs = model_train(imgs) 125 | #----------------------# 126 | # 损失计算 127 | #----------------------# 128 | if focal_loss: 129 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 130 | else: 131 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 132 | 133 | if dice_loss: 134 | main_dice = Dice_loss(outputs, labels) 135 | loss = loss + main_dice 136 | #-------------------------------# 137 | # 计算f_score 138 | #-------------------------------# 139 | _f_score = f_score(outputs, labels) 140 | 141 | val_loss += loss.item() 142 | val_f_score += _f_score.item() 143 | 144 | if local_rank == 0: 145 | pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1), 146 | 'f_score' : val_f_score / (iteration + 1), 147 | 'lr' : get_lr(optimizer)}) 148 | pbar.update(1) 149 | 150 | if local_rank == 0: 151 | pbar.close() 152 | print('Finish Validation') 153 | loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val) 154 | eval_callback.on_epoch_end(epoch + 1, model_train) 155 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 156 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val)) 157 | 158 | #-----------------------------------------------# 159 | # 保存权值 160 | #-----------------------------------------------# 161 | if (epoch + 1) % save_period == 0 or epoch + 
def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    """Compute the mean per-class F-beta score of a batch of predictions.

    Args:
        inputs: Logits indexed as (n, c, h, w).
        target: One-hot labels indexed as (n, ht, wt, ct). The last channel
            is dropped before scoring, so only the first ``ct - 1`` channels
            are compared with the ``c`` prediction channels.
        beta: Weight of recall relative to precision.
        smooth: Constant added to numerator and denominator to avoid 0/0.
        threhold: Softmax probability above which a pixel counts as positive.

    Returns:
        Scalar tensor: the score averaged over classes.
    """
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    # Resize when EITHER spatial dimension differs. The former `and` skipped
    # the resize when only one of them mismatched, which then failed on
    # incompatible shapes in the products below.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n, h*w, c), then class probabilities per pixel.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    temp_target = target.reshape(n, -1, ct)

    # Binarise the probabilities and count tp / fp / fn per class over the
    # batch and pixel dimensions.
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    valid_target = temp_target[..., :-1]
    tp = torch.sum(valid_target * temp_inputs, dim=[0, 1])
    fp = torch.sum(temp_inputs, dim=[0, 1]) - tp
    fn = torch.sum(valid_target, dim=[0, 1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score
def per_class_iu(hist):
    """Per-class IoU from a confusion matrix (rows: labels, columns: predictions)."""
    return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)

def per_class_PA_Recall(hist):
    """Per-class recall (pixel accuracy): diagonal divided by row sums."""
    return np.diag(hist) / np.maximum(hist.sum(1), 1)

def per_class_Precision(hist):
    """Per-class precision: diagonal divided by column sums."""
    return np.diag(hist) / np.maximum(hist.sum(0), 1)

def per_Accuracy(hist):
    """Overall pixel accuracy: trace divided by the total count."""
    return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)

def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
    """Accumulate a confusion matrix over image pairs and derive metrics.

    Args:
        gt_dir: Directory holding the ground-truth ``<name>.png`` masks.
        pred_dir: Directory holding the predicted ``<name>.png`` masks.
        png_name_list: Image names (without extension) to evaluate.
        num_classes: Number of classes, i.e. the confusion-matrix size.
        name_classes: Optional class names; when given, progress and
            per-class results are printed.

    Returns:
        Tuple ``(hist, IoUs, PA_Recall, Precision)`` where ``hist`` is the
        integer confusion matrix and the others are per-class arrays.
    """
    print('Num classes', num_classes)
    # Confusion matrix, initially all zeros.
    hist = np.zeros((num_classes, num_classes))

    # Paths of the ground-truth masks and of the matching predictions.
    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]

    for ind, (gt_path, pred_path) in enumerate(zip(gt_imgs, pred_imgs)):
        pred = np.array(Image.open(pred_path))
        label = np.array(Image.open(gt_path))

        # A pair whose sizes disagree cannot be compared pixel by pixel.
        if len(label.flatten()) != len(pred.flatten()):
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label.flatten()), len(pred.flatten()), gt_path,
                    pred_path))
            continue

        # Add this pair's num_classes x num_classes histogram.
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
        # Report the running averages every 10 images.
        if name_classes is not None and ind > 0 and ind % 10 == 0:
            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
                    ind,
                    len(gt_imgs),
                    100 * np.nanmean(per_class_iu(hist)),
                    100 * np.nanmean(per_class_PA_Recall(hist)),
                    100 * per_Accuracy(hist)
                )
            )

    # Per-class metrics over the whole set.
    IoUs = per_class_iu(hist)
    PA_Recall = per_class_PA_Recall(hist)
    Precision = per_class_Precision(hist)

    if name_classes is not None:
        for ind_class in range(num_classes):
            print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
                + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))

    # Class-averaged metrics, ignoring NaN entries.
    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
    # The `np.int` alias was removed from NumPy (1.24); the builtin `int`
    # is the documented replacement and yields the same dtype.
    return np.array(hist, int), IoUs, PA_Recall, Precision
os.path.join(miou_out_path, "mPA.png")) 165 | 166 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \ 167 | os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) 168 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) 169 | 170 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \ 171 | os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) 172 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) 173 | 174 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: 175 | writer = csv.writer(f) 176 | writer_list = [] 177 | writer_list.append([' '] + [str(c) for c in name_classes]) 178 | for i in range(len(hist)): 179 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) 180 | writer.writerows(writer_list) 181 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) 182 | -------------------------------------------------------------------------------- /voc_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | #-------------------------------------------------------# 9 | # 想要增加测试集修改trainval_percent 10 | # 修改train_percent用于改变验证集的比例 9:1 11 | # 12 | # 当前该库将测试集当作验证集使用,不单独划分测试集 13 | #-------------------------------------------------------# 14 | trainval_percent = 1 15 | train_percent = 0.9 16 | #-------------------------------------------------------# 17 | # 指向VOC数据集所在的文件夹 18 | # 默认指向根目录下的VOC数据集 19 | #-------------------------------------------------------# 20 | VOCdevkit_path = './VOCdevkit' 21 | 22 | if __name__ == "__main__": 23 | random.seed(0) 24 | print("Generate txt in ImageSets.") 
# Remainder of the voc_annotation.py script body: split the label files into
# trainval/train/val/test name lists, then sanity-check the label pixel values.
segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass')
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation')

# Only .png files are treated as labels. os.listdir returns entries in
# arbitrary order, so the seeded split below is repeatable only when the
# directory listing order is the same.
total_seg = [seg for seg in os.listdir(segfilepath) if seg.endswith(".png")]

num = len(total_seg)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
# Keep these two sampling calls in this order so the seeded split is unchanged.
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)
# Sets give O(1) membership tests in the write loop below.
trainval_set = set(trainval)
train_set = set(train)

print("train and val size", tv)
print("train size", tr)

# Context managers guarantee the four list files are flushed and closed even
# if a write fails part-way through.
with open(os.path.join(saveBasePath, 'trainval.txt'), 'w') as ftrainval, \
        open(os.path.join(saveBasePath, 'test.txt'), 'w') as ftest, \
        open(os.path.join(saveBasePath, 'train.txt'), 'w') as ftrain, \
        open(os.path.join(saveBasePath, 'val.txt'), 'w') as fval:
    for i in indices:
        # Strip the 4-character ".png" suffix to get the sample name.
        name = total_seg[i][:-4] + '\n'
        if i in trainval_set:
            ftrainval.write(name)
            if i in train_set:
                ftrain.write(name)
            else:
                fval.write(name)
        else:
            ftest.write(name)
print("Generate txt in ImageSets done.")

print("Check datasets format, this may take a while.")
print("检查数据集格式是否符合要求,这可能需要一段时间。")
# np.int was removed in NumPy 1.24; use an explicit 64-bit counter per pixel
# value (0..255).
classes_nums = np.zeros([256], np.int64)
for name in tqdm(total_seg):
    png_file_name = os.path.join(segfilepath, name)
    if not os.path.exists(png_file_name):
        raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name))

    png = np.array(Image.open(png_file_name), np.uint8)
    if len(np.shape(png)) > 2:
        # Labels with more than two dimensions are reported but still counted.
        print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png))))
        # This message has no placeholders, so it must not be %-formatted
        # (doing so raised TypeError in the original).
        print("标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。")

    # Accumulate a histogram of pixel values across all label images.
    classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)

# Print a table of every pixel value that occurs and how often.
print("打印像素点的值与数量。")
print('-' * 37)
print("| %15s | %15s |"%("Key", "Value"))
print('-' * 37)
for value in range(256):
    if classes_nums[value] > 0:
        print("| %15s | %15s |"%(str(value), str(classes_nums[value])))
        print('-' * 37)

if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
    # Labels contain only 0 and 255: warn that the foreground should be 1.
    print("检测到标签中像素点的值仅包含0与255,数据格式有误。")
    print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。")
elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
    # Labels contain only the value 0.
    print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。")

print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")