├── nets ├── __init__.py ├── vgg.py ├── unet.py ├── unet_training.py └── resnet.py ├── utils ├── __init__.py ├── ceshi.py ├── see.py ├── utils.py ├── tile.py ├── fix.py ├── change.py ├── dataloader.py ├── dataloader_medical.py ├── callbacks.py ├── utils_metrics.py └── utils_fit.py ├── requirements.txt ├── .idea ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── UNet_Demo.iml └── workspace.xml ├── README.md ├── LICENSE ├── voc_annotation_medical.py ├── .gitignore ├── get_miou.py ├── voc_annotation.py ├── predictLabel.py ├── predict.py ├── unet.py ├── train_medical.py └── train.py /nets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/ceshi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | print(torch.__version__) 3 | print(torch.cuda.is_available()) 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.2.1 2 | numpy==1.17.0 3 | matplotlib==3.1.2 4 | opencv_python==4.1.2.30 5 | torch 6 | torchvision 7 | tqdm==4.60.0 8 | Pillow==8.2.0 9 | h5py==2.10.0 10 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/UNet_Demo.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 糖尿病眼底病变分割和分类 2 | # 1.数据集准备 3 | ## 1.1 数据集下载 4 | 首先准备糖尿病病变分割的数据集,在这里我们选取南开大学的DDR数据集进行模型训练,该数据集包括分类,分割和目标检测三个部分 5 | 下载地址如下:[OIA-DDR数据集](https://news.nankai.edu.cn/ywsd/system/2019/12/07/030036739.shtml) 6 | ## 1.2 数据集预处理 7 | 在这里预处理包含两部分 8 | 首先对于分割数据集,我们需要将四种标签的图片融到一张,用于之后的多标签分割,在这里执行`utils/change.py`中的 9 | 再将分类数据集按照四种类别分别分开,用于数据集的读取 10 | # 2. 模型训练 11 | ## 2.1 模型选取 12 | 在这里,针对眼底病变小样本的特点,我们选取了unet作为网络架构,用来进行病变分割 13 | 特征提取网络选用restnet50,训练所需的权值可在百度网盘中下载。 14 | 链接: https://pan.baidu.com/s/1A22fC5cPRb74gqrpq7O9-A 15 | 提取码: 6n2c 16 | ## 2.2 模型训练 17 | 模型训练执行train.py即可,注意在train.py内设置对应超参,具体见train.py文件夹 18 | ## 2.3 模型预测 19 | 执行那个get_miou.py文件,即可获取分割结果,并生成对应指标,存储分割结果 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Bubbliiiing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/see.py: -------------------------------------------------------------------------------- 1 | 2 | #showPixelValue.py 3 | 4 | import cv2 5 | import sys 6 | 7 | def getMousePos(imgName): 8 | def onmouse(event, x, y, flags, param): 9 | cv2.imshow("img",img) 10 | #if event==cv2.EVENT_MOUSEMOVE: 11 | #print(img[y,x], " pos: ", x, " x ", y) 12 | 13 | #双击左键,显示鼠标位置 14 | if event == cv2.EVENT_MBUTTONDBLCLK: 15 | strtext = "(%s,%s)"%(x,y) 16 | print(img[y,x]) 17 | 18 | cv2.namedWindow("img", cv2.WINDOW_NORMAL) 19 | img= cv2.imread(imgName) 20 | print(img[img>4]) 21 | print(img.shape) 22 | 23 | cv2.setMouseCallback("img", onmouse) 24 | 25 | if cv2.waitKey() & 0xFF == 27: #按下‘q'键,退出 26 | cv2.destroyAllWindows() 27 | 28 | 29 | def showPixelValue(imgName): 30 | 31 | img= cv2.imread(imgName) 32 | def onmouse(event, x, y, flags, param): 33 | if event==cv2.EVENT_MOUSEMOVE: 34 | print(img[y,x]) 35 | 36 | cv2.namedWindow("img") 37 | cv2.setMouseCallback("img", onmouse) 38 | cv2.imshow("img",img) 39 | if cv2.waitKey() == ord('q'): #按下‘q'键,退出 40 | cv2.destroyAllWindows() 41 | 42 | if __name__ == '__main__': 43 | arg1 = 'E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\num_label\\007-6361-400.png' 44 | #E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\EX_01\\007-1774-100.png 45 | #E:\\setup\\UNet_Demo\\eyedetect\\miou_out\\HE_res\\007-6361-400.png 46 | getMousePos(arg1) 47 | -------------------------------------------------------------------------------- /voc_annotation_medical.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | #----------------------------------------------------------------------# 5 | # 医药数据集的例子没有验证集 6 | #----------------------------------------------------------------------# 7 | trainval_percent = 1 8 | train_percent = 1 9 | #-------------------------------------------------------# 10 | # 指向医药数据集所在的文件夹 11 | # 默认指向根目录下的Medical_Datasets 12 | #-------------------------------------------------------# 13 | VOCdevkit_path = 'Nankai' 14 | 15 | if __name__ == "__main__": 16 | random.seed(0) 17 | print("Generate txt in ImageSets.") 18 | segfilepath = os.path.join(VOCdevkit_path, 'train/label/EX') 19 | saveBasePath = os.path.join(VOCdevkit_path, 'save') 20 | 21 | temp_seg = os.listdir(segfilepath) 22 | total_seg = [] 23 | for seg in temp_seg: 24 | if seg.endswith(".jpg"): 25 | total_seg.append(seg) 26 | 27 | num = len(total_seg) 28 | list = range(num) 29 | tv = int(num*trainval_percent) 30 | tr = int(tv*train_percent) 31 | trainval= random.sample(list,tv) 32 | train = random.sample(trainval,tr) 33 | 34 | print("train and val size",tv) 35 | print("traub suze",tr) 36 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') 37 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') 38 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') 39 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w') 40 | 41 | for i in list: 42 | name=total_seg[i][:-4]+'\n' 43 | if i in trainval: 44 | ftrainval.write(name) 45 | if i in train: 46 | ftrain.write(name) 47 | else: 48 | fval.write(name) 49 | else: 50 | ftest.write(name) 51 | 52 | ftrainval.close() 53 | ftrain.close() 54 | fval.close() 55 | ftest.close() 56 | print("Generate txt in ImageSets done.") 57 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | #---------------------------------------------------------# 5 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 6 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 7 | #---------------------------------------------------------# 8 | def cvtColor(image): 9 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 10 | return image 11 | else: 12 | image = image.convert('RGB') 13 | return image 14 | 15 | #---------------------------------------------------# 16 | # 对输入图像进行resize 17 | #---------------------------------------------------# 18 | def resize_image(image, size): 19 | iw, ih = image.size 20 | w, h = size 21 | 22 | scale = min(w/iw, h/ih) 23 | nw = int(iw*scale) 24 | nh = int(ih*scale) 25 | 26 | image = image.resize((nw,nh), Image.BICUBIC) 27 | new_image = Image.new('RGB', size, (128,128,128)) 28 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 29 | 30 | return new_image, nw, nh 31 | def resize_image1(image, size): 32 | iw, ih = image.size 33 | return image, iw, ih 34 | #---------------------------------------------------# 35 | # 获得学习率 36 | #---------------------------------------------------# 37 | def get_lr(optimizer): 38 | for param_group in optimizer.param_groups: 39 | return param_group['lr'] 40 | 41 | def preprocess_input(image): 42 | image /= 255.0 43 | return image 44 | 45 | def show_config(**kwargs): 46 | print('Configurations:') 47 | print('-' * 70) 48 | print('|%25s | %40s|' % ('keys', 'values')) 49 | print('-' * 70) 50 | for key, value in kwargs.items(): 51 | print('|%25s | %40s|' % (str(key), str(value))) 52 | print('-' * 70) 53 | 54 | def download_weights(backbone, model_dir="./model_data"): 55 | import os 56 | from torch.hub import load_state_dict_from_url 57 | 58 | download_urls = { 59 | 'vgg' : 'https://download.pytorch.org/models/vgg16-397923af.pth', 60 | 'resnet50' : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth' 61 | } 62 | url = download_urls[backbone] 63 | 64 | if not os.path.exists(model_dir): 65 | os.makedirs(model_dir) 66 | load_state_dict_from_url(url, model_dir) -------------------------------------------------------------------------------- /utils/tile.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | import re 6 | # 读入图像 7 | input=r'Nankai/train/image' 8 | input_label=r'Nankai/train/label/num_label' 9 | out=r'Nankai/train/image_tile/' 10 | output_label=r'Nankai/train/label/tile_label/' 11 | if not out: 12 | os.makedirs(out) 13 | if not output_label: 14 | os.makedirs(output_label) 15 | 16 | # 获取输入图像的尺寸 17 | def tile(img,outdir,filename,output_label): 18 | height, width, channels = img.shape 19 | # 指定每个图像块的大小 20 | block_size = 512 21 | 22 | # 计算在宽度和高度上应该分成多少个块 23 | num_blocks_wide = int(np.ceil(width / block_size)) 24 | num_blocks_high = int(np.ceil(height / block_size)) 25 | # 计算宽度和高度上需要填充的像素数 26 | pad_width = ((0, num_blocks_high * block_size - height), 27 | (0, num_blocks_wide * block_size - width), 28 | (0, 0)) 29 | # 对图像进行镜像填充 30 | padded_img = np.pad(img, pad_width, mode='reflect') 31 | # 切割图像并保存到磁盘 32 | for i in range(num_blocks_high): 33 | for j in range(num_blocks_wide): 34 | # 计算当前块在填充后图像中的位置 35 | start_x = j * block_size 36 | start_y = i * block_size 37 | 38 | # 切割出当前块 39 | block = padded_img[start_y:start_y+block_size, start_x:start_x+block_size, :] 40 | if len(block[block!=0])<100*100: 41 | continue 42 | #print(outdir+'output_image_{i*num_blocks_wide+j}.jpg') 43 | # 将当前块保存到磁盘 44 | cv2.imwrite(outdir+f'{filename}_{i*num_blocks_wide+j}.jpg', block) 45 | cv2.imwrite(output_label+f'{filename}_{i*num_blocks_wide+j}.png', block) 46 | win_size = 64 47 | step =64 48 | def Spilt(): 49 | # 打开待切割的图像文件 50 | img = Image.open('result.jpg') 51 | 52 | # 获取图像大小 53 | width, height = img.size 54 | 55 | # 遍历图像并切割为指定大小的图像块 56 | for i in range(0, height-win_size, step): 57 | for j in range(0, width-win_size, step): 58 | # 切割出当前位置的图像块 59 | box = (j, i, j+win_size, i+win_size) 60 | img_block = img.crop(box) 61 | 62 | # 判断当前图像块中黑色像素占比是否超过一半 63 | img_data = np.array(img_block) 64 | if np.mean(img_data == 0) <= 0.1: 65 | # 将当前图像块保存到指定路径中 66 | img_block.save(os.path.join('Nankai/label', f'block_{i}_{j}.jpg')) 67 | Spilt() 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore map, miou, datasets 2 | map_out/ 3 | miou_out/ 4 | miou_train/ 5 | VOCdevkit/ 6 | datasets/ 7 | Medical_Datasets/ 8 | DDR dataset/ 9 | logs/ 10 | model_data/ 11 | .temp_miou_out/ 12 | Nankai/ 13 | Video/ 14 | model_data/ 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | *.pth 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | -------------------------------------------------------------------------------- /get_miou.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | from unet import Unet 7 | from utils.utils_metrics import compute_mIoU, show_results 8 | 9 | ''' 10 | 进行指标评估需要注意以下几点: 11 | 1、该文件生成的图为灰度图,因为值比较小,按照JPG形式的图看是没有显示效果的,所以看到近似全黑的图是正常的。 12 | 2、该文件计算的是验证集的miou,当前该库将测试集当作验证集使用,不单独划分测试集 13 | 3、仅有按照VOC格式数据训练的模型可以利用这个文件进行miou的计算。 14 | ''' 15 | if __name__ == "__main__": 16 | #---------------------------------------------------------------------------# 17 | # miou_mode用于指定该文件运行时计算的内容 18 | # miou_mode为0代表整个miou计算流程,包括获得预测结果、计算miou。 19 | # miou_mode为1代表仅仅获得预测结果。 20 | # miou_mode为2代表仅仅计算miou。 21 | #---------------------------------------------------------------------------# 22 | miou_mode = 2 23 | #------------------------------# 24 | # 分类个数+1、如2+1 25 | #------------------------------# 26 | num_classes = 5 27 | #--------------------------------------------# 28 | # 区分的种类,和json_to_dataset里面的一样 29 | #--------------------------------------------# 30 | # name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 31 | name_classes = ["background","EX", "HE", "MA", "SE"] 32 | #-------------------------------------------------------# 33 | # 指向VOC数据集所在的文件夹 34 | # 默认指向根目录下的VOC数据集 35 | #-------------------------------------------------------# 36 | VOCdevkit_path = 'NanKai' 37 | 38 | image_ids = open(os.path.join(VOCdevkit_path, "save/trainval.txt"),'r').read().splitlines() 39 | gt_dir = os.path.join(VOCdevkit_path, "train/label/num_label/") 40 | miou_out_path = "miou_train" 41 | pred_dir = os.path.join(miou_out_path, 'num_res') 42 | 43 | if miou_mode == 0 or miou_mode == 1: 44 | if not os.path.exists(pred_dir): 45 | os.makedirs(pred_dir) 46 | print("Load model.") 47 | unet = Unet() 48 | print("Load model done.") 49 | 50 | print("Get predict result.") 51 | for image_id in tqdm(image_ids): 52 | image_path = os.path.join(VOCdevkit_path, "train/image/"+image_id+".jpg") 53 | image = Image.open(image_path) 54 | image = unet.get_miou_png(image) 55 | image.save(os.path.join(pred_dir, image_id + ".png")) 56 | print("Get predict result done.") 57 | 58 | if miou_mode == 0 or miou_mode == 2: 59 | print("Get miou.") 60 | hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 执行计算mIoU的函数 61 | print("Get miou done.") 62 | show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes) -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | 4 | 5 | class VGG(nn.Module): 6 | def __init__(self, features, num_classes=1000): 7 | super(VGG, self).__init__() 8 | self.features = features 9 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 10 | self.classifier = nn.Sequential( 11 | nn.Linear(512 * 7 * 7, 4096), 12 | nn.ReLU(True), 13 | nn.Dropout(), 14 | nn.Linear(4096, 4096), 15 | nn.ReLU(True), 16 | nn.Dropout(), 17 | nn.Linear(4096, num_classes), 18 | ) 19 | self._initialize_weights() 20 | 21 | def forward(self, x): 22 | # x = self.features(x) 23 | # x = self.avgpool(x) 24 | # x = torch.flatten(x, 1) 25 | # x = self.classifier(x) 26 | feat1 = self.features[ :4 ](x) 27 | feat2 = self.features[4 :9 ](feat1) 28 | feat3 = self.features[9 :16](feat2) 29 | feat4 = self.features[16:23](feat3) 30 | feat5 = self.features[23:-1](feat4) 31 | return [feat1, feat2, feat3, feat4, feat5] 32 | 33 | def _initialize_weights(self): 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | if m.bias is not None: 38 | nn.init.constant_(m.bias, 0) 39 | elif isinstance(m, nn.BatchNorm2d): 40 | nn.init.constant_(m.weight, 1) 41 | nn.init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.Linear): 43 | nn.init.normal_(m.weight, 0, 0.01) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | 47 | def make_layers(cfg, batch_norm=False, in_channels = 3): 48 | layers = [] 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | # 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 128,128,256 -> 64,64,256 61 | # 64,64,512 -> 32,32,512 -> 32,32,512 62 | cfgs = { 63 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 64 | } 65 | 66 | 67 | def VGG16(pretrained, in_channels = 3, **kwargs): 68 | model = VGG(make_layers(cfgs["D"], batch_norm = False, in_channels = in_channels), **kwargs) 69 | if pretrained: 70 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir="./model_data") 71 | model.load_state_dict(state_dict) 72 | 73 | del model.avgpool 74 | del model.classifier 75 | return model 76 | -------------------------------------------------------------------------------- /utils/fix.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------------------# 2 | # 该文件用于调整标签的格式 3 | #--------------------------------------------------------# 4 | import os 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | #-----------------------------------------------------------------------------------# 11 | # Origin_SegmentationClass_path 原始标签所在的路径 12 | # Out_SegmentationClass_path 输出标签所在的路径 13 | # 处理后的标签为灰度图,如果设置的值太小会看不见具体情况。 14 | #-----------------------------------------------------------------------------------# 15 | Origin_SegmentationClass_path = r"E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE_new" 16 | Out_SegmentationClass_path = r"E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE_01" 17 | 18 | #-----------------------------------------------------------------------------------# 19 | # Origin_Point_Value 原始标签对应的像素点值 20 | # Out_Point_Value 输出标签对应的像素点值 21 | # Origin_Point_Value需要与Out_Point_Value一一对应。 22 | # 举例如下,当: 23 | # Origin_Point_Value = np.array([0, 255]);Out_Point_Value = np.array([0, 1]) 24 | # 代表将原始标签中值为0的像素点,调整为0,将原始标签中值为255的像素点,调整为1。 25 | # 26 | # 示例中仅调整了两个像素点值,实际上可以更多个,如: 27 | # Origin_Point_Value = np.array([0, 128, 255]);Out_Point_Value = np.array([0, 1, 2]) 28 | # 29 | # 也可以是数组(当标签值为RGB像素点时),如 30 | # Origin_Point_Value = np.array([[0, 0, 0], [1, 1, 1]]);Out_Point_Value = np.array([0, 1]) 31 | #-----------------------------------------------------------------------------------# 32 | Origin_Point_Value = np.array([0, 255]) 33 | Out_Point_Value = np.array([0, 1]) 34 | 35 | if __name__ == "__main__": 36 | if not os.path.exists(Out_SegmentationClass_path): 37 | os.makedirs(Out_SegmentationClass_path) 38 | 39 | #---------------------------# 40 | # 遍历标签并赋值 41 | #---------------------------# 42 | png_names = os.listdir(Origin_SegmentationClass_path) 43 | print("正在遍历全部标签。") 44 | for png_name in tqdm(png_names): 45 | png = Image.open(os.path.join(Origin_SegmentationClass_path, png_name)) 46 | w, h = png.size 47 | 48 | png = np.array(png) 49 | out_png = np.zeros([h, w]) 50 | for i in range(len(Origin_Point_Value)): 51 | mask = png[:, :] == Origin_Point_Value[i] 52 | if len(np.shape(mask)) > 2: 53 | mask = mask.all(-1) 54 | out_png[mask] = Out_Point_Value[i] 55 | 56 | out_png = Image.fromarray(np.array(out_png, np.uint8)) 57 | out_png.save(os.path.join(Out_SegmentationClass_path, png_name)) 58 | 59 | #-------------------------------------# 60 | # 统计输出,各个像素点的值得个数 61 | #-------------------------------------# 62 | print("正在统计输出的图片每个像素点的数量。") 63 | classes_nums = np.zeros([256], np.int) 64 | for png_name in tqdm(png_names): 65 | png_file_name = os.path.join(Out_SegmentationClass_path, png_name) 66 | if not os.path.exists(png_file_name): 67 | raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name)) 68 | 69 | png = np.array(Image.open(png_file_name), np.uint8) 70 | classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256) 71 | 72 | print("打印像素点的值与数量。") 73 | print('-' * 37) 74 | print("| %15s | %15s |"%("Key", "Value")) 75 | print('-' * 37) 76 | for i in range(256): 77 | if classes_nums[i] > 0: 78 | print("| %15s | %15s |"%(str(i), str(classes_nums[i]))) 79 | print('-' * 37) -------------------------------------------------------------------------------- /nets/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from nets.resnet import resnet50 5 | from nets.vgg import VGG16 6 | 7 | 8 | class unetUp(nn.Module): 9 | def __init__(self, in_size, out_size): 10 | super(unetUp, self).__init__() 11 | self.conv1 = nn.Conv2d(in_size, out_size, kernel_size = 3, padding = 1) 12 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size = 3, padding = 1) 13 | self.up = nn.UpsamplingBilinear2d(scale_factor = 2) 14 | self.relu = nn.ReLU(inplace = True) 15 | 16 | def forward(self, inputs1, inputs2): 17 | outputs = torch.cat([inputs1, self.up(inputs2)], 1) 18 | outputs = self.conv1(outputs) 19 | outputs = self.relu(outputs) 20 | outputs = self.conv2(outputs) 21 | outputs = self.relu(outputs) 22 | return outputs 23 | 24 | class Unet(nn.Module): 25 | def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'): 26 | super(Unet, self).__init__() 27 | if backbone == 'vgg': 28 | self.vgg = VGG16(pretrained = pretrained) 29 | in_filters = [192, 384, 768, 1024] 30 | elif backbone == "resnet50": 31 | self.resnet = resnet50(pretrained = pretrained) 32 | in_filters = [192, 512, 1024, 3072] 33 | else: 34 | raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone)) 35 | out_filters = [64, 128, 256, 512] 36 | 37 | # upsampling 38 | # 64,64,512 39 | self.up_concat4 = unetUp(in_filters[3], out_filters[3]) 40 | # 128,128,256 41 | self.up_concat3 = unetUp(in_filters[2], out_filters[2]) 42 | # 256,256,128 43 | self.up_concat2 = unetUp(in_filters[1], out_filters[1]) 44 | # 512,512,64 45 | self.up_concat1 = unetUp(in_filters[0], out_filters[0]) 46 | 47 | if backbone == 'resnet50': 48 | self.up_conv = nn.Sequential( 49 | nn.UpsamplingBilinear2d(scale_factor = 2), 50 | nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1), 51 | nn.ReLU(), 52 | nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1), 53 | nn.ReLU(), 54 | ) 55 | else: 56 | self.up_conv = None 57 | 58 | self.final = nn.Conv2d(out_filters[0], num_classes, 1) 59 | 60 | self.backbone = backbone 61 | 62 | def forward(self, inputs): 63 | if self.backbone == "vgg": 64 | [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs) 65 | elif self.backbone == "resnet50": 66 | [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs) 67 | 68 | up4 = self.up_concat4(feat4, feat5) 69 | up3 = self.up_concat3(feat3, up4) 70 | up2 = self.up_concat2(feat2, up3) 71 | up1 = self.up_concat1(feat1, up2) 72 | 73 | if self.up_conv != None: 74 | up1 = self.up_conv(up1) 75 | 76 | final = self.final(up1) 77 | 78 | return final 79 | 80 | def freeze_backbone(self): 81 | if self.backbone == "vgg": 82 | for param in self.vgg.parameters(): 83 | param.requires_grad = False 84 | elif self.backbone == "resnet50": 85 | for param in self.resnet.parameters(): 86 | param.requires_grad = False 87 | 88 | def unfreeze_backbone(self): 89 | if self.backbone == "vgg": 90 | for param in self.vgg.parameters(): 91 | param.requires_grad = True 92 | elif self.backbone == "resnet50": 93 | for param in self.resnet.parameters(): 94 | param.requires_grad = True 95 | -------------------------------------------------------------------------------- /voc_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | #-------------------------------------------------------# 9 | # 想要增加测试集修改trainval_percent 10 | # 修改train_percent用于改变验证集的比例 9:1 11 | # 12 | # 当前该库将测试集当作验证集使用,不单独划分测试集 13 | #-------------------------------------------------------# 14 | trainval_percent = 1 15 | train_percent = 0.9 16 | #-------------------------------------------------------# 17 | # 指向VOC数据集所在的文件夹 18 | # 默认指向根目录下的VOC数据集 19 | #-------------------------------------------------------# 20 | VOCdevkit_path = 'Nankai' 21 | 22 | if __name__ == "__main__": 23 | random.seed(0) 24 | print("Generate txt in ImageSets.") 25 | segfilepath = os.path.join(VOCdevkit_path, 'train/label/EX_01') 26 | saveBasePath = os.path.join(VOCdevkit_path, 'save') 27 | 28 | temp_seg = os.listdir(segfilepath) 29 | total_seg = [] 30 | for seg in temp_seg: 31 | if seg.endswith(".png"): 32 | total_seg.append(seg) 33 | 34 | num = len(total_seg) 35 | list = range(num) 36 | tv = int(num*trainval_percent) 37 | tr = int(tv*train_percent) 38 | trainval= random.sample(list,tv) 39 | train = random.sample(trainval,tr) 40 | 41 | print("train and val size",tv) 42 | print("traub suze",tr) 43 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') 44 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') 45 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') 46 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w') 47 | 48 | for i in list: 49 | name = total_seg[i][:-4]+'\n' 50 | if i in trainval: 51 | ftrainval.write(name) 52 | if i in train: 53 | ftrain.write(name) 54 | else: 55 | fval.write(name) 56 | else: 57 | ftest.write(name) 58 | 59 | ftrainval.close() 60 | ftrain.close() 61 | fval.close() 62 | ftest.close() 63 | print("Generate txt in ImageSets done.") 64 | 65 | print("Check datasets format, this may take a while.") 66 | print("检查数据集格式是否符合要求,这可能需要一段时间。") 67 | classes_nums = np.zeros([256], np.int) 68 | for i in tqdm(list): 69 | name = total_seg[i] 70 | png_file_name = os.path.join(segfilepath, name) 71 | if not os.path.exists(png_file_name): 72 | raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name)) 73 | 74 | png = np.array(Image.open(png_file_name), np.uint8) 75 | if len(np.shape(png)) > 2: 76 | print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png)))) 77 | print("标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。"%(name, str(np.shape(png)))) 78 | 79 | classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256) 80 | 81 | print("打印像素点的值与数量。") 82 | print('-' * 37) 83 | print("| %15s | %15s |"%("Key", "Value")) 84 | print('-' * 37) 85 | for i in range(256): 86 | if classes_nums[i] > 0: 87 | print("| %15s | %15s |"%(str(i), str(classes_nums[i]))) 88 | print('-' * 37) 89 | 90 | if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0: 91 | print("检测到标签中像素点的值仅包含0与255,数据格式有误。") 92 | print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。") 93 | elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0: 94 | print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。") 95 | 96 | print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。") 97 | print("如果格式有误,参考:") 98 | print("https://github.com/bubbliiiing/segmentation-format-fix") -------------------------------------------------------------------------------- /utils/change.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import cv2 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | def tif_to_png(image_path,save_path): 8 | """ 9 | :param image_path: *.tif image path 10 | :param save_path: *.png image path 11 | :return: 12 | """ 13 | img = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE) 14 | # print(img) 15 | # print(img.dtype) 16 | filename = image_path.split('/')[-1].split('_')[0] 17 | filename=filename.split('.')[0] 18 | save_path = save_path + '/' + filename + '.png' 19 | cv2.imwrite(save_path,img) 20 | def changePixel(): 21 | path = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/label/EX_new/' 22 | savedpath = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/label/EX_res/' 23 | filelist = os.listdir(path) 24 | for item in filelist: 25 | im = Image.open(path + item) #打开图片 26 | width = im.size[0] #获取宽度 27 | height = im.size[1] #获取长度 28 | for x in range(width): 29 | for y in range(height): 30 | r,g,b = im.getpixel((x,y)) 31 | if r+g+b>0: 32 | im.putpixel((x,y),(1,1,1)) 33 | im = im.convert('RGB') 34 | im.save(savedpath + item) 35 | def gray(): 36 | path = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/image/' 37 | savepath = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/new_image/' 38 | filelist=os.listdir(path) 39 | for item in filelist: 40 | image = cv2.imread(path+item, cv2.IMREAD_GRAYSCALE) 41 | cv2.imwrite(savepath+item, image) 42 | def hw(strJpgFile, strSaveDir, width=512, height=512): 43 | img_src = Image.open(strJpgFile) 44 | img_dst = img_src.resize((width, height), Image.LANCZOS) # 得到的图像在抗锯齿和保留锐利边缘的效果较好 45 | img_dst.save(os.path.join(strSaveDir, os.path.basename(strJpgFile))) 46 | def mergeData(ex_data,ha_data,se_data,ma_data,pic_name,save_path): 47 | ex=np.array(Image.open(ex_data+pic_name)) 48 | ha=np.array(Image.open(ha_data+pic_name)) 49 | se=np.array(Image.open(se_data+pic_name)) 50 | ma=np.array(Image.open(ma_data+pic_name)) 51 | ex=ex+ha*2+se*3+ma*4 52 | ex[ex>4]=4 53 | ex=Image.fromarray(ex) 54 | ex.save(save_path+pic_name) 55 | def merge(): 56 | ex_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\EX_01\\" 57 | ha_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\HE_01\\" 58 | se_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\SE_01\\" 59 | ma_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\MA_01\\" 60 | save_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\num_label\\" 61 | image_files=os.listdir(ex_path) 62 | for item in image_files: 63 | mergeData(ex_path,ha_path,se_path,ma_path,item,save_path) 64 | def getPix(): 65 | save_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\num_label\\" 66 | image_files=os.listdir(save_path) 67 | res=[0,0,0,0,0] 68 | for item in image_files: 69 | img=Image.open(save_path+item) 70 | img=np.array(img) 71 | res[0]+=len(img[img==0]) 72 | res[1]+=len(img[img==1]) 73 | res[2]+=len(img[img==2]) 74 | res[3]+=len(img[img==3]) 75 | res[4]+=len(img[img==4]) 76 | print(res) 77 | def getonePix(): 78 | save_path="E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\test\\label\\MA\\" 79 | image_files=os.listdir(save_path) 80 | res=[0,0] 81 | for item in image_files: 82 | img=Image.open(save_path+item) 83 | img=np.array(img) 84 | res[0]+=len(img[img==0]) 85 | res[1]+=len(img[img==255]) 86 | print(res) 87 | if __name__ == '__main__': 88 | root_path = r'E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE/' 89 | save_path = r'E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE_new' 90 | image_files = os.listdir(root_path) 91 | getonePix() 92 | #merge() 93 | #for item in image_files: 94 | # tif_to_png(root_path+item,save_path) 95 | #gray() -------------------------------------------------------------------------------- /nets/unet_training.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def CE_Loss(inputs, target, cls_weights, num_classes=21): 10 | n, c, h, w = inputs.size() 11 | nt, ht, wt = target.size() 12 | if h != ht and w != wt: 13 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 14 | 15 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 16 | temp_target = target.view(-1) 17 | 18 | CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target) 19 | return CE_loss 20 | 21 | def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2): 22 | n, c, h, w = inputs.size() 23 | nt, ht, wt = target.size() 24 | if h != ht and w != wt: 25 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 26 | 27 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 28 | temp_target = target.view(-1) 29 | 30 | logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target) 31 | pt = torch.exp(logpt) 32 | if alpha is not None: 33 | logpt *= alpha 34 | loss = -((1 - pt) ** gamma) * logpt 35 | loss = loss.mean() 36 | return loss 37 | 38 | def Dice_loss(inputs, target, beta=1, smooth = 1e-5): 39 | n, c, h, w = inputs.size() 40 | nt, ht, wt, ct = target.size() 41 | if h != ht and w != wt: 42 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 43 | 44 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1) 45 | temp_target = target.view(n, -1, ct) 46 | 47 | #--------------------------------------------# 48 | # 计算dice loss 49 | #--------------------------------------------# 50 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1]) 51 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp 52 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp 53 | 54 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 55 | dice_loss = 1 - torch.mean(score) 56 | return dice_loss 57 | 58 | def weights_init(net, init_type='normal', init_gain=0.02): 59 | def init_func(m): 60 | classname = m.__class__.__name__ 61 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 62 | if init_type == 'normal': 63 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 64 | elif init_type == 'xavier': 65 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) 66 | elif init_type == 'kaiming': 67 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 68 | elif init_type == 'orthogonal': 69 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) 70 | else: 71 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 72 | elif classname.find('BatchNorm2d') != -1: 73 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 74 | torch.nn.init.constant_(m.bias.data, 0.0) 75 | print('initialize network with %s type' % init_type) 76 | net.apply(init_func) 77 | 78 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): 79 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 80 | if iters <= warmup_total_iters: 81 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 82 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start 83 | elif iters >= total_iters - no_aug_iter: 84 | lr = min_lr 85 | else: 86 | lr = min_lr + 0.5 * (lr - min_lr) * ( 87 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) 88 | ) 89 | return lr 90 | 91 | def step_lr(lr, decay_rate, step_size, iters): 92 | if step_size < 1: 93 | raise ValueError("step_size must above 1.") 94 | n = iters // step_size 95 | out_lr = lr * decay_rate ** n 96 | return out_lr 97 | 98 | if lr_decay_type == "cos": 99 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 100 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 101 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 102 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 103 | else: 104 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 105 | step_size = total_iters / step_num 106 | func = partial(step_lr, lr, decay_rate, step_size) 107 | 108 | return func 109 | 110 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): 111 | lr = lr_scheduler_func(epoch) 112 | for param_group in optimizer.param_groups: 113 | param_group['lr'] = lr 114 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data.dataset import Dataset 8 | 9 | from utils.utils import cvtColor, preprocess_input 10 | 11 | 12 | class UnetDataset(Dataset): 13 | def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path): 14 | super(UnetDataset, self).__init__() 15 | self.annotation_lines = annotation_lines 16 | self.length = len(annotation_lines) 17 | self.input_shape = input_shape 18 | self.num_classes = num_classes 19 | self.train = train 20 | self.dataset_path = dataset_path 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index): 26 | annotation_line = self.annotation_lines[index] 27 | name = annotation_line.split()[0] 28 | 29 | #-------------------------------# 30 | # 从文件中读取图像 31 | #-------------------------------# 32 | jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "train/image"), name + ".jpg")) 33 | png = Image.open(os.path.join(os.path.join(self.dataset_path, "train/label/EX_01"), name + ".png")) 34 | #-------------------------------# 35 | # 数据增强 36 | #-------------------------------# 37 | jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train) 38 | 39 | jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1]) 40 | png = np.array(png) 41 | png[png >= self.num_classes] = self.num_classes 42 | #-------------------------------------------------------# 43 | # 转化成one_hot的形式 44 | # 在这里需要+1是因为voc数据集有些标签具有白边部分 45 | # 我们需要将白边部分进行忽略,+1的目的是方便忽略。 46 | #-------------------------------------------------------# 47 | seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])] 48 | seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1)) 49 | 50 | return jpg, png, seg_labels 51 | 52 | def rand(self, a=0, b=1): 53 | return np.random.rand() * (b - a) + a 54 | 55 | def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True): 56 | image = cvtColor(image) 57 | label = Image.fromarray(np.array(label)) 58 | #------------------------------# 59 | # 获得图像的高宽与目标高宽 60 | #------------------------------# 61 | iw, ih = image.size 62 | h, w = input_shape 63 | 64 | if not random: 65 | iw, ih = image.size 66 | scale = min(w/iw, h/ih) 67 | nw = int(iw*scale) 68 | nh = int(ih*scale) 69 | 70 | image = image.resize((nw,nh), Image.BICUBIC) 71 | new_image = Image.new('RGB', [w, h], (128,128,128)) 72 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 73 | 74 | label = label.resize((nw,nh), Image.NEAREST) 75 | new_label = Image.new('L', [w, h], (0)) 76 | new_label.paste(label, ((w-nw)//2, (h-nh)//2)) 77 | return new_image, new_label 78 | 79 | #------------------------------------------# 80 | # 对图像进行缩放并且进行长和宽的扭曲 81 | #------------------------------------------# 82 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) 83 | scale = self.rand(0.25, 2) 84 | if new_ar < 1: 85 | nh = int(scale*h) 86 | nw = int(nh*new_ar) 87 | else: 88 | nw = int(scale*w) 89 | nh = int(nw/new_ar) 90 | image = image.resize((nw,nh), Image.BICUBIC) 91 | label = label.resize((nw,nh), Image.NEAREST) 92 | 93 | #------------------------------------------# 94 | # 翻转图像 95 | #------------------------------------------# 96 | flip = self.rand()<.5 97 | if flip: 98 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 99 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 100 | 101 | #------------------------------------------# 102 | # 将图像多余的部分加上灰条 103 | #------------------------------------------# 104 | dx = int(self.rand(0, w-nw)) 105 | dy = int(self.rand(0, h-nh)) 106 | new_image = Image.new('RGB', (w,h), (128,128,128)) 107 | new_label = Image.new('L', (w,h), (0)) 108 | new_image.paste(image, (dx, dy)) 109 | new_label.paste(label, (dx, dy)) 110 | image = new_image 111 | label = new_label 112 | 113 | image_data = np.array(image, np.uint8) 114 | #---------------------------------# 115 | # 对图像进行色域变换 116 | # 计算色域变换的参数 117 | #---------------------------------# 118 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 119 | #---------------------------------# 120 | # 将图像转到HSV上 121 | #---------------------------------# 122 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) 123 | dtype = image_data.dtype 124 | #---------------------------------# 125 | # 应用变换 126 | #---------------------------------# 127 | x = np.arange(0, 256, dtype=r.dtype) 128 | lut_hue = ((x * r[0]) % 180).astype(dtype) 129 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 130 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 131 | 132 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 133 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) 134 | 135 | return image_data, label 136 | 137 | # DataLoader中collate_fn使用 138 | def unet_dataset_collate(batch): 139 | images = [] 140 | pngs = [] 141 | seg_labels = [] 142 | for img, png, labels in batch: 143 | images.append(img) 144 | pngs.append(png) 145 | seg_labels.append(labels) 146 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) 147 | pngs = torch.from_numpy(np.array(pngs)).long() 148 | seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor) 149 | return images, pngs, seg_labels 150 | -------------------------------------------------------------------------------- /utils/dataloader_medical.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data.dataset import Dataset 8 | import matplotlib.pyplot as plt 9 | from utils.utils import cvtColor, preprocess_input 10 | 11 | 12 | class UnetDataset(Dataset): 13 | def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path): 14 | super(UnetDataset, self).__init__() 15 | self.annotation_lines = annotation_lines 16 | self.length = len(annotation_lines) 17 | self.input_shape = input_shape 18 | self.num_classes = num_classes 19 | self.train = train 20 | self.dataset_path = dataset_path 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index): 26 | annotation_line = self.annotation_lines[index] 27 | name = annotation_line.split()[0] 28 | 29 | #-------------------------------# 30 | # 从文件中读取图像 31 | #-------------------------------# 32 | jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "train/new_image"), name + ".jpg")) 33 | png = Image.open(os.path.join(os.path.join(self.dataset_path, "train/label/EX_01"), name + ".png")) 34 | #-------------------------------# 35 | # 数据增强 36 | #-------------------------------# 37 | jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train) 38 | #plt.imshow(jpg) 39 | jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1]) 40 | png = np.array(png) 41 | #-------------------------------------------------------# 42 | # 这里的标签处理方式和普通voc的处理方式不同 43 | # 将小于127.5的像素点设置为目标像素点。 44 | #-------------------------------------------------------# 45 | modify_png = np.zeros_like(png) 46 | modify_png[png==0] = 1 47 | seg_labels = modify_png 48 | seg_labels = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])] 49 | seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1)) 50 | 51 | return jpg, modify_png, seg_labels 52 | 53 | def rand(self, a=0, b=1): 54 | return np.random.rand() * (b - a) + a 55 | 56 | def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True): 57 | image = cvtColor(image) 58 | label = Image.fromarray(np.array(label)) 59 | #------------------------------# 60 | # 获得图像的高宽与目标高宽 61 | #------------------------------# 62 | iw, ih = image.size 63 | h, w = input_shape 64 | 65 | if not random: 66 | iw, ih = image.size 67 | scale = min(w/iw, h/ih) 68 | nw = int(iw*scale) 69 | nh = int(ih*scale) 70 | 71 | image = image.resize((nw,nh), Image.BICUBIC) 72 | new_image = Image.new('RGB', [w, h], (128,128,128)) 73 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 74 | 75 | label = label.resize((nw,nh), Image.NEAREST) 76 | new_label = Image.new('L', [w, h], (0)) 77 | new_label.paste(label, ((w-nw)//2, (h-nh)//2)) 78 | return new_image, new_label 79 | 80 | #------------------------------------------# 81 | # 对图像进行缩放并且进行长和宽的扭曲 82 | #------------------------------------------# 83 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) 84 | scale = self.rand(0.25, 2) 85 | if new_ar < 1: 86 | nh = int(scale*h) 87 | nw = int(nh*new_ar) 88 | else: 89 | nw = int(scale*w) 90 | nh = int(nw/new_ar) 91 | image = image.resize((nw,nh), Image.BICUBIC) 92 | label = label.resize((nw,nh), Image.NEAREST) 93 | 94 | #------------------------------------------# 95 | # 翻转图像 96 | #------------------------------------------# 97 | flip = self.rand()<.5 98 | if flip: 99 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 100 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 101 | 102 | #------------------------------------------# 103 | # 将图像多余的部分加上灰条 104 | #------------------------------------------# 105 | dx = int(self.rand(0, w-nw)) 106 | dy = int(self.rand(0, h-nh)) 107 | new_image = Image.new('RGB', (w,h), (128,128,128)) 108 | new_label = Image.new('L', (w,h), (0)) 109 | new_image.paste(image, (dx, dy)) 110 | new_label.paste(label, (dx, dy)) 111 | image = new_image 112 | label = new_label 113 | 114 | image_data = np.array(image, np.uint8) 115 | #---------------------------------# 116 | # 对图像进行色域变换 117 | # 计算色域变换的参数 118 | #---------------------------------# 119 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 120 | #---------------------------------# 121 | # 将图像转到HSV上 122 | #---------------------------------# 123 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) 124 | dtype = image_data.dtype 125 | #---------------------------------# 126 | # 应用变换 127 | #---------------------------------# 128 | x = np.arange(0, 256, dtype=r.dtype) 129 | lut_hue = ((x * r[0]) % 180).astype(dtype) 130 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 131 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 132 | 133 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 134 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) 135 | 136 | return image_data, label 137 | 138 | # DataLoader中collate_fn使用 139 | def unet_dataset_collate(batch): 140 | images = [] 141 | pngs = [] 142 | seg_labels = [] 143 | for img, png, labels in batch: 144 | images.append(img) 145 | pngs.append(png) 146 | seg_labels.append(labels) 147 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) 148 | pngs = torch.from_numpy(np.array(pngs)).long() 149 | seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor) 150 | return images, pngs, seg_labels 151 | -------------------------------------------------------------------------------- /nets/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=dilation, groups=groups, bias=False, dilation=dilation) 10 | 11 | 12 | def conv1x1(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = norm_layer(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = norm_layer(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | identity = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | identity = self.downsample(x) 48 | 49 | out += identity 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 58 | base_width=64, dilation=1, norm_layer=None): 59 | super(Bottleneck, self).__init__() 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm2d 62 | width = int(planes * (base_width / 64.)) * groups 63 | # 利用1x1卷积下降通道数 64 | self.conv1 = conv1x1(inplanes, width) 65 | self.bn1 = norm_layer(width) 66 | # 利用3x3卷积进行特征提取 67 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 68 | self.bn2 = norm_layer(width) 69 | # 利用1x1卷积上升通道数 70 | self.conv3 = conv1x1(width, planes * self.expansion) 71 | self.bn3 = norm_layer(planes * self.expansion) 72 | 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out += identity 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | def __init__(self, block, layers, num_classes=1000): 102 | #-----------------------------------------------------------# 103 | # 假设输入图像为600,600,3 104 | # 当我们使用resnet50的时候 105 | #-----------------------------------------------------------# 106 | self.inplanes = 64 107 | super(ResNet, self).__init__() 108 | # 600,600,3 -> 300,300,64 109 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU(inplace=True) 112 | # 300,300,64 -> 150,150,64 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 114 | # 150,150,64 -> 150,150,256 115 | self.layer1 = self._make_layer(block, 64, layers[0]) 116 | # 150,150,256 -> 75,75,512 117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 118 | # 75,75,512 -> 38,38,1024 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 120 | # 38,38,1024 -> 19,19,2048 121 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 122 | 123 | self.avgpool = nn.AvgPool2d(7) 124 | self.fc = nn.Linear(512 * block.expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. / n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | # x = self.conv1(x) 153 | # x = self.bn1(x) 154 | # x = self.relu(x) 155 | # x = self.maxpool(x) 156 | 157 | # x = self.layer1(x) 158 | # x = self.layer2(x) 159 | # x = self.layer3(x) 160 | # x = self.layer4(x) 161 | 162 | # x = self.avgpool(x) 163 | # x = x.view(x.size(0), -1) 164 | # x = self.fc(x) 165 | 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | feat1 = self.relu(x) 169 | 170 | x = self.maxpool(feat1) 171 | feat2 = self.layer1(x) 172 | 173 | feat3 = self.layer2(feat2) 174 | feat4 = self.layer3(feat3) 175 | feat5 = self.layer4(feat4) 176 | return [feat1, feat2, feat3, feat4, feat5] 177 | 178 | def resnet50(pretrained=False, **kwargs): 179 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 180 | if pretrained: 181 | model.load_state_dict(model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', model_dir='model_data'), strict=False) 182 | 183 | del model.avgpool 184 | del model.fc 185 | return model 186 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 11 | 12 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 34 | 35 | 36 | 37 | 38 | 57 | 58 | 59 | 78 | 79 | 80 | 99 | 100 | 101 | 120 | 121 | 122 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 1664104742342 158 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | matplotlib.use('Agg') 8 | from matplotlib import pyplot as plt 9 | import scipy.signal 10 | 11 | import cv2 12 | import shutil 13 | import numpy as np 14 | 15 | from PIL import Image 16 | from tqdm import tqdm 17 | from tensorboardX import SummaryWriter 18 | from .utils import cvtColor, preprocess_input, resize_image 19 | from .utils_metrics import compute_mIoU 20 | 21 | 22 | class LossHistory(): 23 | def __init__(self, log_dir, model, input_shape, val_loss_flag=True): 24 | self.log_dir = log_dir 25 | self.val_loss_flag = val_loss_flag 26 | 27 | self.losses = [] 28 | if self.val_loss_flag: 29 | self.val_loss = [] 30 | 31 | os.makedirs(self.log_dir) 32 | self.writer = SummaryWriter(self.log_dir) 33 | try: 34 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 35 | self.writer.add_graph(model, dummy_input) 36 | except: 37 | pass 38 | 39 | def append_loss(self, epoch, loss, val_loss = None): 40 | if not os.path.exists(self.log_dir): 41 | os.makedirs(self.log_dir) 42 | self.losses.append(loss) 43 | if self.val_loss_flag: 44 | self.val_loss.append(val_loss) 45 | 46 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 47 | f.write(str(loss)) 48 | f.write("\n") 49 | if self.val_loss_flag: 50 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 51 | f.write(str(val_loss)) 52 | f.write("\n") 53 | 54 | self.writer.add_scalar('loss', loss, epoch) 55 | if self.val_loss_flag: 56 | self.writer.add_scalar('val_loss', val_loss, epoch) 57 | 58 | self.loss_plot() 59 | 60 | def loss_plot(self): 61 | iters = range(len(self.losses)) 62 | 63 | plt.figure() 64 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 65 | if self.val_loss_flag: 66 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 67 | 68 | try: 69 | if len(self.losses) < 25: 70 | num = 5 71 | else: 72 | num = 15 73 | 74 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 75 | if self.val_loss_flag: 76 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 77 | except: 78 | pass 79 | 80 | plt.grid(True) 81 | plt.xlabel('Epoch') 82 | plt.ylabel('Loss') 83 | plt.legend(loc="upper right") 84 | 85 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 86 | 87 | plt.cla() 88 | plt.close("all") 89 | 90 | class EvalCallback(): 91 | def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \ 92 | miou_out_path=".temp_miou_out", eval_flag=True, period=1): 93 | super(EvalCallback, self).__init__() 94 | 95 | self.net = net 96 | self.input_shape = input_shape 97 | self.num_classes = num_classes 98 | self.image_ids = image_ids 99 | self.dataset_path = dataset_path 100 | self.log_dir = log_dir 101 | self.cuda = cuda 102 | self.miou_out_path = miou_out_path 103 | self.eval_flag = eval_flag 104 | self.period = period 105 | 106 | self.image_ids = [image_id.split()[0] for image_id in image_ids] 107 | self.mious = [0] 108 | self.epoches = [0] 109 | if self.eval_flag: 110 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: 111 | f.write(str(0)) 112 | f.write("\n") 113 | 114 | def get_miou_png(self, image): 115 | #---------------------------------------------------------# 116 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 117 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 118 | #---------------------------------------------------------# 119 | image = cvtColor(image) 120 | orininal_h = np.array(image).shape[0] 121 | orininal_w = np.array(image).shape[1] 122 | #---------------------------------------------------------# 123 | # 给图像增加灰条,实现不失真的resize 124 | # 也可以直接resize进行识别 125 | #---------------------------------------------------------# 126 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) 127 | #---------------------------------------------------------# 128 | # 添加上batch_size维度 129 | #---------------------------------------------------------# 130 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) 131 | 132 | with torch.no_grad(): 133 | images = torch.from_numpy(image_data) 134 | if self.cuda: 135 | images = images.cuda() 136 | 137 | #---------------------------------------------------# 138 | # 图片传入网络进行预测 139 | #---------------------------------------------------# 140 | pr = self.net(images)[0] 141 | #---------------------------------------------------# 142 | # 取出每一个像素点的种类 143 | #---------------------------------------------------# 144 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy() 145 | #--------------------------------------# 146 | # 将灰条部分截取掉 147 | #--------------------------------------# 148 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ 149 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] 150 | #---------------------------------------------------# 151 | # 进行图片的resize 152 | #---------------------------------------------------# 153 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR) 154 | #---------------------------------------------------# 155 | # 取出每一个像素点的种类 156 | #---------------------------------------------------# 157 | pr = pr.argmax(axis=-1) 158 | 159 | image = Image.fromarray(np.uint8(pr)) 160 | return image 161 | 162 | def on_epoch_end(self, epoch, model_eval): 163 | if epoch % self.period == 0 and self.eval_flag: 164 | self.net = model_eval 165 | gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/") 166 | pred_dir = os.path.join(self.miou_out_path, 'detection-results') 167 | if not os.path.exists(self.miou_out_path): 168 | os.makedirs(self.miou_out_path) 169 | if not os.path.exists(pred_dir): 170 | os.makedirs(pred_dir) 171 | print("Get miou.") 172 | for image_id in tqdm(self.image_ids): 173 | #-------------------------------# 174 | # 从文件中读取图像 175 | #-------------------------------# 176 | image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/"+image_id+".jpg") 177 | image = Image.open(image_path) 178 | #------------------------------# 179 | # 获得预测txt 180 | #------------------------------# 181 | image = self.get_miou_png(image) 182 | image.save(os.path.join(pred_dir, image_id + ".png")) 183 | 184 | print("Calculate miou.") 185 | _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None) # 执行计算mIoU的函数 186 | temp_miou = np.nanmean(IoUs) * 100 187 | 188 | self.mious.append(temp_miou) 189 | self.epoches.append(epoch) 190 | 191 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: 192 | f.write(str(temp_miou)) 193 | f.write("\n") 194 | 195 | plt.figure() 196 | plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou') 197 | 198 | plt.grid(True) 199 | plt.xlabel('Epoch') 200 | plt.ylabel('Miou') 201 | plt.title('A Miou Curve') 202 | plt.legend(loc="upper right") 203 | 204 | plt.savefig(os.path.join(self.log_dir, "epoch_miou.png")) 205 | plt.cla() 206 | plt.close("all") 207 | 208 | print("Get miou done.") 209 | shutil.rmtree(self.miou_out_path) 210 | -------------------------------------------------------------------------------- /utils/utils_metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from os.path import join 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | 11 | 12 | def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5): 13 | n, c, h, w = inputs.size() 14 | nt, ht, wt, ct = target.size() 15 | if h != ht and w != wt: 16 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 17 | 18 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1) 19 | temp_target = target.view(n, -1, ct) 20 | 21 | #--------------------------------------------# 22 | # 计算dice系数 23 | #--------------------------------------------# 24 | temp_inputs = torch.gt(temp_inputs, threhold).float() 25 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1]) 26 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp 27 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp 28 | 29 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 30 | score = torch.mean(score) 31 | return score 32 | 33 | # 设标签宽W,长H 34 | def fast_hist(a, b, n): 35 | #--------------------------------------------------------------------------------# 36 | # a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,) 37 | #--------------------------------------------------------------------------------# 38 | k = (a >= 0) & (a < n) 39 | #--------------------------------------------------------------------------------# 40 | # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n) 41 | # 返回中,写对角线上的为分类正确的像素点 42 | #--------------------------------------------------------------------------------# 43 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 44 | 45 | def per_class_iu(hist): 46 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) 47 | 48 | def per_class_PA_Recall(hist): 49 | return np.diag(hist) / np.maximum(hist.sum(1), 1) 50 | 51 | def per_class_Precision(hist): 52 | return np.diag(hist) / np.maximum(hist.sum(0), 1) 53 | 54 | def per_Accuracy(hist): 55 | return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) 56 | 57 | def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None): 58 | print('Num classes', num_classes) 59 | #-----------------------------------------# 60 | # 创建一个全是0的矩阵,是一个混淆矩阵 61 | #-----------------------------------------# 62 | hist = np.zeros((num_classes, num_classes)) 63 | 64 | #------------------------------------------------# 65 | # 获得验证集标签路径列表,方便直接读取 66 | # 获得验证集图像分割结果路径列表,方便直接读取 67 | #------------------------------------------------# 68 | gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list] 69 | pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list] 70 | 71 | #------------------------------------------------# 72 | # 读取每一个(图片-标签)对 73 | #------------------------------------------------# 74 | for ind in range(len(gt_imgs)): 75 | #------------------------------------------------# 76 | # 读取一张图像分割结果,转化成numpy数组 77 | #------------------------------------------------# 78 | pred = np.array(Image.open(pred_imgs[ind])) 79 | #------------------------------------------------# 80 | # 读取一张对应的标签,转化成numpy数组 81 | #------------------------------------------------# 82 | label = np.array(Image.open(gt_imgs[ind])) 83 | 84 | # 如果图像分割结果与标签的大小不一样,这张图片就不计算 85 | if len(label.flatten()) != len(pred.flatten()): 86 | print( 87 | 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format( 88 | len(label.flatten()), len(pred.flatten()), gt_imgs[ind], 89 | pred_imgs[ind])) 90 | continue 91 | 92 | #------------------------------------------------# 93 | # 对一张图片计算21×21的hist矩阵,并累加 94 | #------------------------------------------------# 95 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes) 96 | # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值 97 | if name_classes is not None and ind > 0 and ind % 10 == 0: 98 | print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format( 99 | ind, 100 | len(gt_imgs), 101 | 100 * np.nanmean(per_class_iu(hist)), 102 | 100 * np.nanmean(per_class_PA_Recall(hist)), 103 | 100 * per_Accuracy(hist) 104 | ) 105 | ) 106 | #------------------------------------------------# 107 | # 计算所有验证集图片的逐类别mIoU值 108 | #------------------------------------------------# 109 | IoUs = per_class_iu(hist) 110 | PA_Recall = per_class_PA_Recall(hist) 111 | Precision = per_class_Precision(hist) 112 | #------------------------------------------------# 113 | # 逐类别输出一下mIoU值 114 | #------------------------------------------------# 115 | if name_classes is not None: 116 | for ind_class in range(num_classes): 117 | print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \ 118 | + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2))) 119 | 120 | #-----------------------------------------------------------------# 121 | # 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值 122 | #-----------------------------------------------------------------# 123 | print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2))) 124 | return np.array(hist, np.int), IoUs, PA_Recall, Precision 125 | 126 | def adjust_axes(r, t, fig, axes): 127 | bb = t.get_window_extent(renderer=r) 128 | text_width_inches = bb.width / fig.dpi 129 | current_fig_width = fig.get_figwidth() 130 | new_fig_width = current_fig_width + text_width_inches 131 | propotion = new_fig_width / current_fig_width 132 | x_lim = axes.get_xlim() 133 | axes.set_xlim([x_lim[0], x_lim[1] * propotion]) 134 | 135 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True): 136 | fig = plt.gcf() 137 | axes = plt.gca() 138 | print(values) 139 | plt.barh(range(len(values)), values, color='royalblue') 140 | plt.title("All data", fontsize=tick_font_size + 2) 141 | plt.xlabel(x_label, fontsize=tick_font_size) 142 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) 143 | r = fig.canvas.get_renderer() 144 | for i, val in enumerate(values): 145 | str_val = " " + str(val) 146 | if val < 1.0: 147 | str_val = " {0:.2f}".format(val) 148 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') 149 | if i == (len(values)-1): 150 | adjust_axes(r, t, fig, axes) 151 | 152 | fig.tight_layout() 153 | fig.savefig(output_path) 154 | if plt_show: 155 | plt.show() 156 | plt.close() 157 | 158 | def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12): 159 | print(IoUs) 160 | classes=['IoUs','PA_Recall','Precision'] 161 | value=[] 162 | value.append(IoUs[1]) 163 | value.append(PA_Recall[1]) 164 | value.append(Precision[1]) 165 | draw_plot_func(value, classes, "mIoU = {0:.3f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \ 166 | os.path.join(miou_out_path, "All.png"), tick_font_size = tick_font_size, plt_show = True) 167 | draw_plot_func(IoUs, name_classes, "mIoU = {0:.3f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \ 168 | os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True) 169 | print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) 170 | 171 | draw_plot_func(PA_Recall, name_classes, "mPA = {0:.3f}%".format(np.nanmean(PA_Recall)*100.0), "Pixel Accuracy", \ 172 | os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False) 173 | print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) 174 | 175 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.3f}%".format(np.nanmean(PA_Recall)*100.0), "Recall", \ 176 | os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) 177 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) 178 | 179 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.3f}%".format(np.nanmean(Precision)*100.0), "Precision", \ 180 | os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) 181 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) 182 | 183 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: 184 | writer = csv.writer(f) 185 | writer_list = [] 186 | writer_list.append([' '] + [str(c) for c in name_classes]) 187 | for i in range(len(hist)): 188 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) 189 | writer.writerows(writer_list) 190 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) 191 | -------------------------------------------------------------------------------- /predictLabel.py: -------------------------------------------------------------------------------- 1 | #----------------------------------------------------# 2 | # 将单张图片预测、摄像头检测和FPS测试功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #----------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from unet import Unet 12 | 13 | def predict(file_name): 14 | #-------------------------------------------------------------------------# 15 | # 如果想要修改对应种类的颜色,到__init__函数里修改self.colors即可 16 | #-------------------------------------------------------------------------# 17 | unet = Unet() 18 | #----------------------------------------------------------------------------------------------------------# 19 | # mode用于指定测试的模式: 20 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 21 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 22 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 23 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 24 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 25 | #----------------------------------------------------------------------------------------------------------# 26 | mode = "predict" ### 检测图片的 27 | # mode = "video" ## 检测视频的 28 | # mode = "video" 29 | #-------------------------------------------------------------------------# 30 | # count 指定了是否进行目标的像素点计数(即面积)与比例计算 31 | # name_classes 区分的种类,和json_to_dataset里面的一样,用于打印种类和数量 32 | # 33 | # count、name_classes仅在mode='predict'时有效 34 | #-------------------------------------------------------------------------# 35 | count = False 36 | # name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 37 | #name_classes = ["_background_", "person", "car", "motorbike", "dustbin", "chair", "fire_hydrant", "tricycle", "bicycle","stone"] 38 | name_classes = ["background","EX", "HE", "MA", "SE"] 39 | #----------------------------------------------------------------------------------------------------------# 40 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 41 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 42 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 43 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 44 | # video_fps 用于保存的视频的fps 45 | # 46 | # video_path、video_save_path和video_fps仅在mode='video'时有效 47 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 48 | #----------------------------------------------------------------------------------------------------------# 49 | video_path = 0 50 | video_save_path = "" 51 | video_fps = 25.0 52 | #----------------------------------------------------------------------------------------------------------# 53 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 54 | # fps_image_path 用于指定测试的fps图片 55 | # 56 | # test_interval和fps_image_path仅在mode='fps'有效 57 | #----------------------------------------------------------------------------------------------------------# 58 | test_interval = 100 59 | fps_image_path = "E:/setup/UNet_Demo/UNet_Demo/NanKai/train/image/007-3939-200.jpg" 60 | #-------------------------------------------------------------------------# 61 | # dir_origin_path 指定了用于检测的图片的文件夹路径 62 | # dir_save_path 指定了检测完图片的保存路径 63 | # 64 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 65 | #-------------------------------------------------------------------------# 66 | dir_origin_path = "img/" 67 | dir_save_path = "img_out/" 68 | #-------------------------------------------------------------------------# 69 | # simplify 使用Simplify onnx 70 | # onnx_save_path 指定了onnx的保存路径 71 | #-------------------------------------------------------------------------# 72 | simplify = True 73 | onnx_save_path = "model_data/models.onnx" 74 | 75 | if mode == "predict": 76 | ''' 77 | predict.py有几个注意点 78 | 1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。 79 | 具体流程可以参考get_miou_prediction.py,在get_miou_prediction.py即实现了遍历。 80 | 2、如果想要保存,利用r_image.save("img.jpg")即可保存。 81 | 3、如果想要原图和分割图不混合,可以把blend参数设置成False。 82 | 4、如果想根据mask获取对应的区域,可以参考detect_image函数中,利用预测结果绘图的部分,判断每一个像素点的种类,然后根据种类获取对应的部分。 83 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3)) 84 | for c in range(self.num_classes): 85 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8') 86 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8') 87 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8') 88 | ''' 89 | img = 'static/result/'+file_name 90 | try: 91 | image = Image.open(img) 92 | except: 93 | print('Open Error! Try again!') 94 | else: 95 | r_image = unet.detect_image(image, count=count, name_classes=name_classes) 96 | r_image.save('static/label/'+file_name) 97 | 98 | 99 | ## 下面是检测电脑自带摄像头的 100 | elif mode == "video_camera": 101 | capture=cv2.VideoCapture(video_path) 102 | if video_save_path!="": 103 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 104 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 105 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 106 | 107 | ref, frame = capture.read() 108 | if not ref: 109 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 110 | 111 | fps = 0.0 112 | while(True): 113 | t1 = time.time() 114 | # 读取某一帧 115 | ref, frame = capture.read() 116 | if not ref: 117 | break 118 | # 格式转变,BGRtoRGB 119 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 120 | # 转变成Image 121 | frame = Image.fromarray(np.uint8(frame)) 122 | # 进行检测 123 | frame = np.array(unet.detect_image(frame)) 124 | # RGBtoBGR满足opencv显示格式 125 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 126 | 127 | fps = ( fps + (1./(time.time()-t1)) ) / 2 128 | print("fps= %.2f"%(fps)) 129 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 130 | 131 | cv2.imshow("video",frame) 132 | c= cv2.waitKey(1) & 0xff 133 | if video_save_path!="": 134 | out.write(frame) 135 | 136 | if c==27: 137 | capture.release() 138 | break 139 | print("Video Detection Done!") 140 | capture.release() 141 | if video_save_path!="": 142 | print("Save processed video to the path :" + video_save_path) 143 | out.release() 144 | cv2.destroyAllWindows() 145 | 146 | elif mode == "fps": 147 | img = Image.open('img/street.jpg') 148 | tact_time = unet.get_FPS(img, test_interval) 149 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 150 | 151 | elif mode == "dir_predict": 152 | import os 153 | from tqdm import tqdm 154 | 155 | img_names = os.listdir(dir_origin_path) 156 | for img_name in tqdm(img_names): 157 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 158 | image_path = os.path.join(dir_origin_path, img_name) 159 | image = Image.open(image_path) 160 | r_image = unet.detect_image(image) 161 | if not os.path.exists(dir_save_path): 162 | os.makedirs(dir_save_path) 163 | r_image.save(os.path.join(dir_save_path, img_name)) 164 | elif mode == "export_onnx": 165 | unet.convert_to_onnx(simplify, onnx_save_path) 166 | 167 | 168 | ## 下面是检测电脑硬盘中的视频文件 169 | elif mode == "video": 170 | 171 | while True: 172 | video = input("Input video filename:") 173 | try: 174 | capture = cv2.VideoCapture(video) 175 | except: 176 | print("Open Error! Try again!") 177 | continue 178 | else: 179 | # capture = cv2.VideoCapture(video_path) 180 | if video_save_path != "": 181 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 182 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 183 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 184 | 185 | ref, frame = capture.read() 186 | if not ref: 187 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 188 | 189 | fps = 0.0 190 | while (True): 191 | t1 = time.time() 192 | # 读取某一帧 193 | ref, frame = capture.read() 194 | if not ref: 195 | break 196 | # 格式转变,BGRtoRGB 197 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 198 | # 转变成Image 199 | frame = Image.fromarray(np.uint8(frame)) 200 | # 进行检测 201 | frame = np.array(unet.detect_image(frame)) 202 | # RGBtoBGR满足opencv显示格式 203 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 204 | 205 | fps = (fps + (1. / (time.time() - t1))) / 2 206 | print("fps= %.2f" % (fps)) 207 | frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 208 | 209 | cv2.imshow("video", frame) 210 | c = cv2.waitKey(1) & 0xff 211 | if video_save_path != "": 212 | out.write(frame) 213 | 214 | if c == 27: 215 | capture.release() 216 | break 217 | print("Video Detection Done!") 218 | capture.release() 219 | if video_save_path != "": 220 | print("Save processed video to the path :" + video_save_path) 221 | out.release() 222 | cv2.destroyAllWindows() 223 | 224 | else: 225 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.") 226 | 227 | 228 | 229 | 230 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #----------------------------------------------------# 2 | # 将单张图片预测、摄像头检测和FPS测试功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #----------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from unet import Unet 12 | 13 | if __name__ == "__main__": 14 | #-------------------------------------------------------------------------# 15 | # 如果想要修改对应种类的颜色,到__init__函数里修改self.colors即可 16 | #-------------------------------------------------------------------------# 17 | unet = Unet() 18 | #----------------------------------------------------------------------------------------------------------# 19 | # mode用于指定测试的模式: 20 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 21 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 22 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 23 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 24 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 25 | #----------------------------------------------------------------------------------------------------------# 26 | mode = "predict" ### 检测图片的 27 | # mode = "video" ## 检测视频的 28 | # mode = "video" 29 | #-------------------------------------------------------------------------# 30 | # count 指定了是否进行目标的像素点计数(即面积)与比例计算 31 | # name_classes 区分的种类,和json_to_dataset里面的一样,用于打印种类和数量 32 | # 33 | # count、name_classes仅在mode='predict'时有效 34 | #-------------------------------------------------------------------------# 35 | count = False 36 | # name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 37 | #name_classes = ["_background_", "person", "car", "motorbike", "dustbin", "chair", "fire_hydrant", "tricycle", "bicycle","stone"] 38 | name_classes = ["background","cat"] 39 | #----------------------------------------------------------------------------------------------------------# 40 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 41 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 42 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 43 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 44 | # video_fps 用于保存的视频的fps 45 | # 46 | # video_path、video_save_path和video_fps仅在mode='video'时有效 47 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 48 | #----------------------------------------------------------------------------------------------------------# 49 | video_path = 0 50 | video_save_path = "" 51 | video_fps = 25.0 52 | #----------------------------------------------------------------------------------------------------------# 53 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 54 | # fps_image_path 用于指定测试的fps图片 55 | # 56 | # test_interval和fps_image_path仅在mode='fps'有效 57 | #----------------------------------------------------------------------------------------------------------# 58 | test_interval = 100 59 | fps_image_path = "E:/setup/UNet_Demo/UNet_Demo/NanKai/train/image/007-3939-200.jpg" 60 | #-------------------------------------------------------------------------# 61 | # dir_origin_path 指定了用于检测的图片的文件夹路径 62 | # dir_save_path 指定了检测完图片的保存路径 63 | # 64 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 65 | #-------------------------------------------------------------------------# 66 | dir_origin_path = "img/" 67 | dir_save_path = "img_out/" 68 | #-------------------------------------------------------------------------# 69 | # simplify 使用Simplify onnx 70 | # onnx_save_path 指定了onnx的保存路径 71 | #-------------------------------------------------------------------------# 72 | simplify = True 73 | onnx_save_path = "model_data/models.onnx" 74 | 75 | if mode == "predict": 76 | ''' 77 | predict.py有几个注意点 78 | 1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。 79 | 具体流程可以参考get_miou_prediction.py,在get_miou_prediction.py即实现了遍历。 80 | 2、如果想要保存,利用r_image.save("img.jpg")即可保存。 81 | 3、如果想要原图和分割图不混合,可以把blend参数设置成False。 82 | 4、如果想根据mask获取对应的区域,可以参考detect_image函数中,利用预测结果绘图的部分,判断每一个像素点的种类,然后根据种类获取对应的部分。 83 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3)) 84 | for c in range(self.num_classes): 85 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8') 86 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8') 87 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8') 88 | ''' 89 | while True: 90 | img = input('Input image filename:') 91 | try: 92 | image = Image.open(img) 93 | except: 94 | print('Open Error! Try again!') 95 | continue 96 | else: 97 | r_image = unet.detect_image(image, count=count, name_classes=name_classes) 98 | r_image.save('backend/static/label') 99 | 100 | 101 | ## 下面是检测电脑自带摄像头的 102 | elif mode == "video_camera": 103 | capture=cv2.VideoCapture(video_path) 104 | if video_save_path!="": 105 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 106 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 107 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 108 | 109 | ref, frame = capture.read() 110 | if not ref: 111 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 112 | 113 | fps = 0.0 114 | while(True): 115 | t1 = time.time() 116 | # 读取某一帧 117 | ref, frame = capture.read() 118 | if not ref: 119 | break 120 | # 格式转变,BGRtoRGB 121 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 122 | # 转变成Image 123 | frame = Image.fromarray(np.uint8(frame)) 124 | # 进行检测 125 | frame = np.array(unet.detect_image(frame)) 126 | # RGBtoBGR满足opencv显示格式 127 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 128 | 129 | fps = ( fps + (1./(time.time()-t1)) ) / 2 130 | print("fps= %.2f"%(fps)) 131 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 132 | 133 | cv2.imshow("video",frame) 134 | c= cv2.waitKey(1) & 0xff 135 | if video_save_path!="": 136 | out.write(frame) 137 | 138 | if c==27: 139 | capture.release() 140 | break 141 | print("Video Detection Done!") 142 | capture.release() 143 | if video_save_path!="": 144 | print("Save processed video to the path :" + video_save_path) 145 | out.release() 146 | cv2.destroyAllWindows() 147 | 148 | elif mode == "fps": 149 | img = Image.open('img/street.jpg') 150 | tact_time = unet.get_FPS(img, test_interval) 151 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 152 | 153 | elif mode == "dir_predict": 154 | import os 155 | from tqdm import tqdm 156 | 157 | img_names = os.listdir(dir_origin_path) 158 | for img_name in tqdm(img_names): 159 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 160 | image_path = os.path.join(dir_origin_path, img_name) 161 | image = Image.open(image_path) 162 | r_image = unet.detect_image(image) 163 | if not os.path.exists(dir_save_path): 164 | os.makedirs(dir_save_path) 165 | r_image.save(os.path.join(dir_save_path, img_name)) 166 | elif mode == "export_onnx": 167 | unet.convert_to_onnx(simplify, onnx_save_path) 168 | 169 | 170 | ## 下面是检测电脑硬盘中的视频文件 171 | elif mode == "video": 172 | 173 | while True: 174 | video = input("Input video filename:") 175 | try: 176 | capture = cv2.VideoCapture(video) 177 | except: 178 | print("Open Error! Try again!") 179 | continue 180 | else: 181 | # capture = cv2.VideoCapture(video_path) 182 | if video_save_path != "": 183 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 184 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 185 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 186 | 187 | ref, frame = capture.read() 188 | if not ref: 189 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 190 | 191 | fps = 0.0 192 | while (True): 193 | t1 = time.time() 194 | # 读取某一帧 195 | ref, frame = capture.read() 196 | if not ref: 197 | break 198 | # 格式转变,BGRtoRGB 199 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 200 | # 转变成Image 201 | frame = Image.fromarray(np.uint8(frame)) 202 | # 进行检测 203 | frame = np.array(unet.detect_image(frame)) 204 | # RGBtoBGR满足opencv显示格式 205 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 206 | 207 | fps = (fps + (1. / (time.time() - t1))) / 2 208 | print("fps= %.2f" % (fps)) 209 | frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 210 | 211 | cv2.imshow("video", frame) 212 | c = cv2.waitKey(1) & 0xff 213 | if video_save_path != "": 214 | out.write(frame) 215 | 216 | if c == 27: 217 | capture.release() 218 | break 219 | print("Video Detection Done!") 220 | capture.release() 221 | if video_save_path != "": 222 | print("Save processed video to the path :" + video_save_path) 223 | out.release() 224 | cv2.destroyAllWindows() 225 | 226 | else: 227 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.") 228 | 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from nets.unet_training import CE_Loss, Dice_loss, Focal_Loss 5 | from tqdm import tqdm 6 | 7 | from utils.utils import get_lr 8 | from utils.utils_metrics import f_score 9 | 10 | 11 | def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0): 12 | total_loss = 0 13 | total_f_score = 0 14 | 15 | val_loss = 0 16 | val_f_score = 0 17 | 18 | if local_rank == 0: 19 | print('Start Train') 20 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 21 | model_train.train() 22 | for iteration, batch in enumerate(gen): 23 | if iteration >= epoch_step: 24 | break 25 | imgs, pngs, labels = batch 26 | with torch.no_grad(): 27 | weights = torch.from_numpy(cls_weights) 28 | if cuda: 29 | imgs = imgs.cuda(local_rank) 30 | pngs = pngs.cuda(local_rank) 31 | labels = labels.cuda(local_rank) 32 | weights = weights.cuda(local_rank) 33 | 34 | optimizer.zero_grad() 35 | if not fp16: 36 | #----------------------# 37 | # 前向传播 38 | #----------------------# 39 | outputs = model_train(imgs) 40 | #----------------------# 41 | # 损失计算 42 | #----------------------# 43 | if focal_loss: 44 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 45 | else: 46 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 47 | 48 | if dice_loss: 49 | main_dice = Dice_loss(outputs, labels) 50 | loss = loss + main_dice 51 | 52 | with torch.no_grad(): 53 | #-------------------------------# 54 | # 计算f_score 55 | #-------------------------------# 56 | _f_score = f_score(outputs, labels) 57 | 58 | loss.backward() 59 | optimizer.step() 60 | else: 61 | from torch.cuda.amp import autocast 62 | with autocast(): 63 | #----------------------# 64 | # 前向传播 65 | #----------------------# 66 | outputs = model_train(imgs) 67 | #----------------------# 68 | # 损失计算 69 | #----------------------# 70 | if focal_loss: 71 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 72 | else: 73 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 74 | 75 | if dice_loss: 76 | main_dice = Dice_loss(outputs, labels) 77 | loss = loss + main_dice 78 | 79 | with torch.no_grad(): 80 | #-------------------------------# 81 | # 计算f_score 82 | #-------------------------------# 83 | _f_score = f_score(outputs, labels) 84 | 85 | #----------------------# 86 | # 反向传播 87 | #----------------------# 88 | scaler.scale(loss).backward() 89 | scaler.step(optimizer) 90 | scaler.update() 91 | 92 | total_loss += loss.item() 93 | total_f_score += _f_score.item() 94 | 95 | if local_rank == 0: 96 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), 97 | 'f_score' : total_f_score / (iteration + 1), 98 | 'lr' : get_lr(optimizer)}) 99 | pbar.update(1) 100 | 101 | if local_rank == 0: 102 | pbar.close() 103 | print('Finish Train') 104 | print('Start Validation') 105 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 106 | 107 | model_train.eval() 108 | for iteration, batch in enumerate(gen_val): 109 | if iteration >= epoch_step_val: 110 | break 111 | imgs, pngs, labels = batch 112 | with torch.no_grad(): 113 | weights = torch.from_numpy(cls_weights) 114 | if cuda: 115 | imgs = imgs.cuda(local_rank) 116 | pngs = pngs.cuda(local_rank) 117 | labels = labels.cuda(local_rank) 118 | weights = weights.cuda(local_rank) 119 | 120 | #----------------------# 121 | # 前向传播 122 | #----------------------# 123 | outputs = model_train(imgs) 124 | #----------------------# 125 | # 损失计算 126 | #----------------------# 127 | if focal_loss: 128 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 129 | else: 130 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 131 | 132 | if dice_loss: 133 | main_dice = Dice_loss(outputs, labels) 134 | loss = loss + main_dice 135 | #-------------------------------# 136 | # 计算f_score 137 | #-------------------------------# 138 | _f_score = f_score(outputs, labels) 139 | 140 | val_loss += loss.item() 141 | val_f_score += _f_score.item() 142 | 143 | if local_rank == 0: 144 | pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1), 145 | 'f_score' : val_f_score / (iteration + 1), 146 | 'lr' : get_lr(optimizer)}) 147 | pbar.update(1) 148 | 149 | if local_rank == 0: 150 | pbar.close() 151 | print('Finish Validation') 152 | loss_history.append_loss(epoch + 1, total_loss/ epoch_step, val_loss/ epoch_step_val) 153 | eval_callback.on_epoch_end(epoch + 1, model_train) 154 | print('Epoch:'+ str(epoch+1) + '/' + str(Epoch)) 155 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val)) 156 | 157 | #-----------------------------------------------# 158 | # 保存权值 159 | #-----------------------------------------------# 160 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 161 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch + 1), total_loss / epoch_step, val_loss / epoch_step_val))) 162 | 163 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 164 | print('Save best model to best_epoch_weights.pth') 165 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth")) 166 | 167 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth")) 168 | 169 | def fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0): 170 | total_loss = 0 171 | total_f_score = 0 172 | 173 | if local_rank == 0: 174 | print('Start Train') 175 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 176 | model_train.train() 177 | for iteration, batch in enumerate(gen): 178 | if iteration >= epoch_step: 179 | break 180 | imgs, pngs, labels = batch 181 | with torch.no_grad(): 182 | weights = torch.from_numpy(cls_weights) 183 | if cuda: 184 | imgs = imgs.cuda(local_rank) 185 | pngs = pngs.cuda(local_rank) 186 | labels = labels.cuda(local_rank) 187 | weights = weights.cuda(local_rank) 188 | 189 | optimizer.zero_grad() 190 | if not fp16: 191 | #----------------------# 192 | # 前向传播 193 | #----------------------# 194 | outputs = model_train(imgs) 195 | #----------------------# 196 | # 损失计算 197 | #----------------------# 198 | if focal_loss: 199 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 200 | else: 201 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 202 | 203 | if dice_loss: 204 | main_dice = Dice_loss(outputs, labels) 205 | loss = loss + main_dice 206 | 207 | with torch.no_grad(): 208 | #-------------------------------# 209 | # 计算f_score 210 | #-------------------------------# 211 | _f_score = f_score(outputs, labels) 212 | 213 | loss.backward() 214 | optimizer.step() 215 | else: 216 | from torch.cuda.amp import autocast 217 | with autocast(): 218 | #----------------------# 219 | # 前向传播 220 | #----------------------# 221 | outputs = model_train(imgs) 222 | #----------------------# 223 | # 损失计算 224 | #----------------------# 225 | if focal_loss: 226 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 227 | else: 228 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 229 | 230 | if dice_loss: 231 | main_dice = Dice_loss(outputs, labels) 232 | loss = loss + main_dice 233 | 234 | with torch.no_grad(): 235 | #-------------------------------# 236 | # 计算f_score 237 | #-------------------------------# 238 | _f_score = f_score(outputs, labels) 239 | 240 | #----------------------# 241 | # 反向传播 242 | #----------------------# 243 | scaler.scale(loss).backward() 244 | scaler.step(optimizer) 245 | scaler.update() 246 | 247 | total_loss += loss.item() 248 | total_f_score += _f_score.item() 249 | 250 | if local_rank == 0: 251 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), 252 | 'f_score' : total_f_score / (iteration + 1), 253 | 'lr' : get_lr(optimizer)}) 254 | pbar.update(1) 255 | 256 | if local_rank == 0: 257 | pbar.close() 258 | loss_history.append_loss(epoch + 1, total_loss/ epoch_step) 259 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 260 | print('Total Loss: %.3f' % (total_loss / epoch_step)) 261 | 262 | #-----------------------------------------------# 263 | # 保存权值 264 | #-----------------------------------------------# 265 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 266 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f.pth'%((epoch + 1), total_loss / epoch_step))) 267 | 268 | if len(loss_history.losses) <= 1 or (total_loss / epoch_step) <= min(loss_history.losses): 269 | print('Save best model to best_epoch_weights.pth') 270 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth")) 271 | 272 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import copy 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from torch import nn 11 | 12 | from nets.unet import Unet as unet 13 | from utils.utils import cvtColor, preprocess_input, resize_image, show_config 14 | 15 | 16 | #--------------------------------------------# 17 | # 使用自己训练好的模型预测需要修改2个参数 18 | # model_path和num_classes都需要修改! 19 | # 如果出现shape不匹配 20 | # 一定要注意训练时的model_path和num_classes数的修改 21 | #--------------------------------------------# 22 | class Unet(object): 23 | _defaults = { 24 | #-------------------------------------------------------------------# 25 | # model_path指向logs文件夹下的权值文件 26 | # 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。 27 | # 验证集损失较低不代表miou较高,仅代表该权值在验证集上泛化性能较好。 28 | #-------------------------------------------------------------------# 29 | "model_path" : '../model_data/best_epoch_weights.pth', 30 | #--------------------------------# 31 | # 所需要区分的类的个数+1 32 | #--------------------------------# 33 | "num_classes" : 5, 34 | #--------------------------------# 35 | # 所使用的的主干网络:vgg、resnet50 36 | #--------------------------------# 37 | "backbone" : "resnet50", 38 | #--------------------------------# 39 | # 输入图片的大小 40 | #--------------------------------# 41 | "input_shape" : [512, 512], 42 | #-------------------------------------------------# 43 | # mix_type参数用于控制检测结果的可视化方式 44 | # 45 | # mix_type = 0的时候代表原图与生成的图进行混合 46 | # mix_type = 1的时候代表仅保留生成的图 47 | # mix_type = 2的时候代表仅扣去背景,仅保留原图中的目标 48 | #-------------------------------------------------# 49 | "mix_type" : 0, 50 | #--------------------------------# 51 | # 是否使用Cuda 52 | # 没有GPU可以设置成False 53 | #--------------------------------# 54 | "cuda" : False, 55 | } 56 | 57 | #---------------------------------------------------# 58 | # 初始化UNET 59 | #---------------------------------------------------# 60 | def __init__(self, **kwargs): 61 | self.__dict__.update(self._defaults) 62 | for name, value in kwargs.items(): 63 | setattr(self, name, value) 64 | #---------------------------------------------------# 65 | # 画框设置不同的颜色 66 | #---------------------------------------------------# 67 | if self.num_classes <= 21: 68 | self.colors = [ (255, 255, 255), (0,0,0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 69 | (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 70 | (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), 71 | (128, 64, 12)] 72 | else: 73 | hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)] 74 | self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) 75 | self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors)) 76 | #---------------------------------------------------# 77 | # 获得模型 78 | #---------------------------------------------------# 79 | self.generate() 80 | 81 | show_config(**self._defaults) 82 | 83 | #---------------------------------------------------# 84 | # 获得所有的分类 85 | #---------------------------------------------------# 86 | def generate(self, onnx=False): 87 | self.net = unet(num_classes = self.num_classes, backbone=self.backbone) 88 | 89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 90 | self.net.load_state_dict(torch.load(self.model_path, map_location=device)) 91 | self.net = self.net.eval() 92 | print('{} model, and classes loaded.'.format(self.model_path)) 93 | if not onnx: 94 | if self.cuda: 95 | self.net = nn.DataParallel(self.net) 96 | self.net = self.net.cuda() 97 | 98 | #---------------------------------------------------# 99 | # 检测图片 100 | #---------------------------------------------------# 101 | def detect_image(self, image, count=False, name_classes=None): 102 | #---------------------------------------------------------# 103 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 104 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 105 | #---------------------------------------------------------# 106 | image = cvtColor(image) 107 | #---------------------------------------------------# 108 | # 对输入图像进行一个备份,后面用于绘图 109 | #---------------------------------------------------# 110 | old_img = copy.deepcopy(image) 111 | orininal_h = np.array(image).shape[0] 112 | orininal_w = np.array(image).shape[1] 113 | #---------------------------------------------------------# 114 | # 给图像增加灰条,实现不失真的resize 115 | # 也可以直接resize进行识别 116 | #---------------------------------------------------------# 117 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) 118 | #---------------------------------------------------------# 119 | # 添加上batch_size维度 120 | #---------------------------------------------------------# 121 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) 122 | 123 | with torch.no_grad(): 124 | images = torch.from_numpy(image_data) 125 | if self.cuda: 126 | images = images.cuda() 127 | 128 | #---------------------------------------------------# 129 | # 图片传入网络进行预测 130 | #---------------------------------------------------# 131 | pr = self.net(images)[0] 132 | #---------------------------------------------------# 133 | # 取出每一个像素点的种类 134 | #---------------------------------------------------# 135 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy() 136 | #--------------------------------------# 137 | # 将灰条部分截取掉 138 | #--------------------------------------# 139 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ 140 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] 141 | #---------------------------------------------------# 142 | # 进行图片的resize 143 | #---------------------------------------------------# 144 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR) 145 | #---------------------------------------------------# 146 | # 取出每一个像素点的种类 147 | #---------------------------------------------------# 148 | pr = pr.argmax(axis=-1) 149 | 150 | #---------------------------------------------------------# 151 | # 计数 152 | #---------------------------------------------------------# 153 | if count: 154 | classes_nums = np.zeros([self.num_classes]) 155 | total_points_num = orininal_h * orininal_w 156 | print('-' * 63) 157 | print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio")) 158 | print('-' * 63) 159 | for i in range(self.num_classes): 160 | num = np.sum(pr == i) 161 | ratio = num / total_points_num * 100 162 | if num > 0: 163 | print("|%25s | %15s | %14.2f%%|"%(str(name_classes[i]), str(num), ratio)) 164 | print('-' * 63) 165 | classes_nums[i] = num 166 | print("classes_nums:", classes_nums) 167 | 168 | if self.mix_type == 0: 169 | # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3)) 170 | # for c in range(self.num_classes): 171 | # seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8') 172 | # seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8') 173 | # seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8') 174 | seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1]) 175 | #------------------------------------------------# 176 | # 将新图片转换成Image的形式 177 | #------------------------------------------------# 178 | image = Image.fromarray(np.uint8(seg_img)) 179 | #------------------------------------------------# 180 | # 将新图与原图及进行混合 181 | #------------------------------------------------# 182 | image = Image.blend(old_img, image, 0.7) 183 | 184 | elif self.mix_type == 1: 185 | # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3)) 186 | # for c in range(self.num_classes): 187 | # seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8') 188 | # seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8') 189 | # seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8') 190 | seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1]) 191 | #------------------------------------------------# 192 | # 将新图片转换成Image的形式 193 | #------------------------------------------------# 194 | image = Image.fromarray(np.uint8(seg_img)) 195 | 196 | elif self.mix_type == 2: 197 | seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8') 198 | #------------------------------------------------# 199 | # 将新图片转换成Image的形式 200 | #------------------------------------------------# 201 | image = Image.fromarray(np.uint8(seg_img)) 202 | 203 | return image 204 | 205 | def get_FPS(self, image, test_interval): 206 | #---------------------------------------------------------# 207 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 208 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 209 | #---------------------------------------------------------# 210 | image = cvtColor(image) 211 | #---------------------------------------------------------# 212 | # 给图像增加灰条,实现不失真的resize 213 | # 也可以直接resize进行识别 214 | #---------------------------------------------------------# 215 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) 216 | #---------------------------------------------------------# 217 | # 添加上batch_size维度 218 | #---------------------------------------------------------# 219 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) 220 | 221 | with torch.no_grad(): 222 | images = torch.from_numpy(image_data) 223 | if self.cuda: 224 | images = images.cuda() 225 | 226 | #---------------------------------------------------# 227 | # 图片传入网络进行预测 228 | #---------------------------------------------------# 229 | pr = self.net(images)[0] 230 | #---------------------------------------------------# 231 | # 取出每一个像素点的种类 232 | #---------------------------------------------------# 233 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1) 234 | #--------------------------------------# 235 | # 将灰条部分截取掉 236 | #--------------------------------------# 237 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ 238 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] 239 | 240 | t1 = time.time() 241 | for _ in range(test_interval): 242 | with torch.no_grad(): 243 | #---------------------------------------------------# 244 | # 图片传入网络进行预测 245 | #---------------------------------------------------# 246 | pr = self.net(images)[0] 247 | #---------------------------------------------------# 248 | # 取出每一个像素点的种类 249 | #---------------------------------------------------# 250 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1) 251 | #--------------------------------------# 252 | # 将灰条部分截取掉 253 | #--------------------------------------# 254 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ 255 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] 256 | t2 = time.time() 257 | tact_time = (t2 - t1) / test_interval 258 | return tact_time 259 | 260 | def convert_to_onnx(self, simplify, model_path): 261 | import onnx 262 | self.generate(onnx=True) 263 | 264 | im = torch.zeros(1, 3, *self.input_shape).to('cpu') # image size(1, 3, 512, 512) BCHW 265 | input_layer_names = ["images"] 266 | output_layer_names = ["output"] 267 | 268 | # Export the model 269 | print(f'Starting export with onnx {onnx.__version__}.') 270 | torch.onnx.export(self.net, 271 | im, 272 | f = model_path, 273 | verbose = False, 274 | opset_version = 12, 275 | training = torch.onnx.TrainingMode.EVAL, 276 | do_constant_folding = True, 277 | input_names = input_layer_names, 278 | output_names = output_layer_names, 279 | dynamic_axes = None) 280 | 281 | # Checks 282 | model_onnx = onnx.load(model_path) # load onnx model 283 | onnx.checker.check_model(model_onnx) # check onnx model 284 | 285 | # Simplify onnx 286 | if simplify: 287 | import onnxsim 288 | print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.') 289 | model_onnx, check = onnxsim.simplify( 290 | model_onnx, 291 | dynamic_input_shape=False, 292 | input_shapes=None) 293 | assert check, 'assert check failed' 294 | onnx.save(model_onnx, model_path) 295 | 296 | print('Onnx model save as {}'.format(model_path)) 297 | 298 | def get_miou_png(self, image): 299 | #---------------------------------------------------------# 300 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 301 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 302 | #---------------------------------------------------------# 303 | image = cvtColor(image) 304 | orininal_h = np.array(image).shape[0] 305 | orininal_w = np.array(image).shape[1] 306 | #---------------------------------------------------------# 307 | # 给图像增加灰条,实现不失真的resize 308 | # 也可以直接resize进行识别 309 | #---------------------------------------------------------# 310 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) 311 | #---------------------------------------------------------# 312 | # 添加上batch_size维度 313 | #---------------------------------------------------------# 314 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) 315 | 316 | with torch.no_grad(): 317 | images = torch.from_numpy(image_data) 318 | if self.cuda: 319 | images = images.cuda() 320 | 321 | #---------------------------------------------------# 322 | # 图片传入网络进行预测 323 | #---------------------------------------------------# 324 | pr = self.net(images)[0] 325 | #---------------------------------------------------# 326 | # 取出每一个像素点的种类 327 | #---------------------------------------------------# 328 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy() 329 | #--------------------------------------# 330 | # 将灰条部分截取掉 331 | #--------------------------------------# 332 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ 333 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] 334 | #---------------------------------------------------# 335 | # 进行图片的resize 336 | #---------------------------------------------------# 337 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR) 338 | #---------------------------------------------------# 339 | # 取出每一个像素点的种类 340 | #---------------------------------------------------# 341 | pr = pr.argmax(axis=-1) 342 | 343 | image = Image.fromarray(np.uint8(pr)) 344 | return image 345 | -------------------------------------------------------------------------------- /train_medical.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | 11 | from nets.unet import Unet 12 | from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init 13 | from utils.callbacks import LossHistory 14 | from utils.dataloader_medical import UnetDataset, unet_dataset_collate 15 | from utils.utils import download_weights, show_config 16 | from utils.utils_fit import fit_one_epoch_no_val 17 | 18 | ''' 19 | 训练自己的语义分割模型一定需要注意以下几点: 20 | 1、该数据集是我根据网上找到的医药数据集特殊建立的训练文件,只是一个例子,用于展示数据集不是voc格式时要如何进行训练。 21 | 22 | 不可以计算miou等性能指标。只用于观看医药数据集的训练效果。 23 | 不可以计算miou等性能指标。 24 | 不可以计算miou等性能指标。 25 | 26 | 如果大家有自己的医药数据集需要训练,可以分为两种情况: 27 | a、没有标签的医药数据集: 28 | 请按照视频里面的数据集标注教程,首先利用labelme标注图片,转换成VOC格式后利用train.py进行训练。 29 | b、有标签的医药数据集: 30 | 将文件的标签格式进行转换,标签的每个像素点的值就是这个像素点所属的种类。 31 | 因此数据集的标签需要改成,背景的像素点值为0,目标的像素点值为1。 32 | 参考:https://github.com/bubbliiiing/segmentation-format-fix 33 | 34 | 2、损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。 35 | 损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接到对应的损失函数里面除上10000。 36 | 训练过程中的损失值会保存在logs文件夹下的loss_%Y_%m_%d_%H_%M_%S文件夹中 37 | 38 | 3、训练好的权值文件保存在logs文件夹中,每个训练世代(Epoch)包含若干训练步长(Step),每个训练步长(Step)进行一次梯度下降。 39 | 如果只是训练了几个Step是不会保存的,Epoch和Step的概念要捋清楚一下。 40 | ''' 41 | if __name__ == "__main__": 42 | #---------------------------------# 43 | # Cuda 是否使用Cuda 44 | # 没有GPU可以设置成False 45 | #---------------------------------# 46 | Cuda = True 47 | #---------------------------------------------------------------------# 48 | # distributed 用于指定是否使用单机多卡分布式运行 49 | # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 50 | # Windows系统下默认使用DP模式调用所有显卡,不支持DDP。 51 | # DP模式: 52 | # 设置 distributed = False 53 | # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python train_medical.py 54 | # DDP模式: 55 | # 设置 distributed = True 56 | # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train_medical.py 57 | #---------------------------------------------------------------------# 58 | distributed = False 59 | #---------------------------------------------------------------------# 60 | # sync_bn 是否使用sync_bn,DDP模式多卡可用 61 | #---------------------------------------------------------------------# 62 | sync_bn = False 63 | #---------------------------------------------------------------------# 64 | # fp16 是否使用混合精度训练 65 | # 可减少约一半的显存、需要pytorch1.7.1以上 66 | #---------------------------------------------------------------------# 67 | fp16 = False 68 | #-----------------------------------------------------# 69 | # num_classes 训练自己的数据集必须要修改的 70 | # 自己需要的分类个数+1,如2+1 71 | #-----------------------------------------------------# 72 | num_classes = 4+1 73 | #-----------------------------------------------------# 74 | # 主干网络选择 75 | # vgg 76 | # resnet50 77 | #-----------------------------------------------------# 78 | backbone = "resnet50" 79 | #----------------------------------------------------------------------------------------------------------------------------# 80 | # pretrained 是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的。 81 | # 如果设置了model_path,则主干的权值无需加载,pretrained的值无意义。 82 | # 如果不设置model_path,pretrained = True,此时仅加载主干开始训练。 83 | # 如果不设置model_path,pretrained = False,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。 84 | #----------------------------------------------------------------------------------------------------------------------------# 85 | pretrained = True 86 | #----------------------------------------------------------------------------------------------------------------------------# 87 | # 权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。 88 | # 模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。 89 | # 预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好 90 | # 训练自己的数据集时提示维度不匹配正常,预测的东西都不一样了自然维度不匹配 91 | # 92 | # 如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。 93 | # 同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。 94 | # 95 | # 当model_path = ''的时候不加载整个模型的权值。 96 | # 97 | # 此处使用的是整个模型的权重,因此是在train.py进行加载的,pretrain不影响此处的权值加载。 98 | # 如果想要让模型从主干的预训练权值开始训练,则设置model_path = '',pretrain = True,此时仅加载主干。 99 | # 如果想要让模型从0开始训练,则设置model_path = '',pretrain = Fasle,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。 100 | # 101 | # 一般来讲,网络从0开始的训练效果会很差,因为权值太过随机,特征提取效果不明显,因此非常、非常、非常不建议大家从0开始训练! 102 | # 如果一定要从0开始,可以了解imagenet数据集,首先训练分类模型,获得网络的主干部分权值,分类模型的 主干部分 和该模型通用,基于此进行训练。 103 | #----------------------------------------------------------------------------------------------------------------------------# 104 | model_path = "model_data/unet_resnet_voc.pth" 105 | #-----------------------------------------------------# 106 | # input_shape 输入图片的大小,32的倍数 107 | #-----------------------------------------------------# 108 | input_shape = [512, 512] 109 | 110 | #----------------------------------------------------------------------------------------------------------------------------# 111 | # 训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。 112 | # 冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,此时仅仅进行冻结训练。 113 | # 114 | # 在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整: 115 | # (一)从整个模型的预训练权重开始训练: 116 | # Adam: 117 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(冻结) 118 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(不冻结) 119 | # SGD: 120 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(冻结) 121 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(不冻结) 122 | # 其中:UnFreeze_Epoch可以在100-300之间调整。 123 | # (二)从主干网络的预训练权重开始训练: 124 | # Adam: 125 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(冻结) 126 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(不冻结) 127 | # SGD: 128 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 120,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(冻结) 129 | # Init_Epoch = 0,UnFreeze_Epoch = 120,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(不冻结) 130 | # 其中:由于从主干网络的预训练权重开始训练,主干的权值不一定适合语义分割,需要更多的训练跳出局部最优解。 131 | # UnFreeze_Epoch可以在120-300之间调整。 132 | # Adam相较于SGD收敛的快一些。因此UnFreeze_Epoch理论上可以小一点,但依然推荐更多的Epoch。 133 | # (三)batch_size的设置: 134 | # 在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。 135 | # 由于resnet50中有BatchNormalization层 136 | # 当主干为resnet50的时候batch_size不可为1 137 | # 正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。 138 | #----------------------------------------------------------------------------------------------------------------------------# 139 | #------------------------------------------------------------------# 140 | # 冻结阶段训练参数 141 | # 此时模型的主干被冻结了,特征提取网络不发生改变 142 | # 占用的显存较小,仅对网络进行微调 143 | # Init_Epoch 模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置: 144 | # Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100 145 | # 会跳过冻结阶段,直接从60代开始,并调整对应的学习率。 146 | # (断点续练时使用) 147 | # Freeze_Epoch 模型冻结训练的Freeze_Epoch 148 | # (当Freeze_Train=False时失效) 149 | # Freeze_batch_size 模型冻结训练的batch_size 150 | # (当Freeze_Train=False时失效) 151 | #------------------------------------------------------------------# 152 | Init_Epoch = 0 153 | Freeze_Epoch = 10 154 | Freeze_batch_size = 1 155 | #------------------------------------------------------------------# 156 | # 解冻阶段训练参数 157 | # 此时模型的主干不被冻结了,特征提取网络会发生改变 158 | # 占用的显存较大,网络所有的参数都会发生改变 159 | # UnFreeze_Epoch 模型总共训练的epoch 160 | # Unfreeze_batch_size 模型在解冻后的batch_size 161 | #------------------------------------------------------------------# 162 | UnFreeze_Epoch = 20 163 | Unfreeze_batch_size = 1 164 | #------------------------------------------------------------------# 165 | # Freeze_Train 是否进行冻结训练 166 | # 默认先冻结主干训练后解冻训练。 167 | #------------------------------------------------------------------# 168 | Freeze_Train = True 169 | 170 | #------------------------------------------------------------------# 171 | # 其它训练参数:学习率、优化器、学习率下降有关 172 | #------------------------------------------------------------------# 173 | #------------------------------------------------------------------# 174 | # Init_lr 模型的最大学习率 175 | # 当使用Adam优化器时建议设置 Init_lr=1e-4 176 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 177 | # Min_lr 模型的最小学习率,默认为最大学习率的0.01 178 | #------------------------------------------------------------------# 179 | Init_lr = 1e-4 180 | Min_lr = Init_lr * 0.01 181 | #------------------------------------------------------------------# 182 | # optimizer_type 使用到的优化器种类,可选的有adam、sgd 183 | # 当使用Adam优化器时建议设置 Init_lr=1e-4 184 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 185 | # momentum 优化器内部使用到的momentum参数 186 | # weight_decay 权值衰减,可防止过拟合 187 | # adam会导致weight_decay错误,使用adam时建议设置为0。 188 | #------------------------------------------------------------------# 189 | optimizer_type = "adam" 190 | momentum = 0.9 191 | weight_decay = 0 192 | #------------------------------------------------------------------# 193 | # lr_decay_type 使用到的学习率下降方式,可选的有'step'、'cos' 194 | #------------------------------------------------------------------# 195 | lr_decay_type = 'cos' 196 | #------------------------------------------------------------------# 197 | # save_period 多少个epoch保存一次权值 198 | #------------------------------------------------------------------# 199 | save_period = 5 200 | #------------------------------------------------------------------# 201 | # save_dir 权值与日志文件保存的文件夹 202 | #------------------------------------------------------------------# 203 | save_dir = 'logs' 204 | 205 | #------------------------------# 206 | # 数据集路径 207 | #------------------------------# 208 | VOCdevkit_path = 'NanKai' 209 | #------------------------------------------------------------------# 210 | # 建议选项: 211 | # 种类少(几类)时,设置为True 212 | # 种类多(十几类)时,如果batch_size比较大(10以上),那么设置为True 213 | # 种类多(十几类)时,如果batch_size比较小(10以下),那么设置为False 214 | #------------------------------------------------------------------# 215 | dice_loss = True 216 | #------------------------------------------------------------------# 217 | # 是否使用focal loss来防止正负样本不平衡 218 | #------------------------------------------------------------------# 219 | focal_loss = [1,3,3,3,3] 220 | #------------------------------------------------------------------# 221 | # 是否给不同种类赋予不同的损失权值,默认是平衡的。 222 | # 设置的话,注意设置成numpy形式的,长度和num_classes一样。 223 | # 如: 224 | # num_classes = 3 225 | # cls_weights = np.array([1, 2, 3], np.float32) 226 | #------------------------------------------------------------------# 227 | cls_weights = np.ones([num_classes], np.float32) 228 | #------------------------------------------------------------------# 229 | # num_workers 用于设置是否使用多线程读取数据,1代表关闭多线程 230 | # 开启后会加快数据读取速度,但是会占用更多内存 231 | # keras里开启多线程有些时候速度反而慢了许多 232 | # 在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。 233 | #------------------------------------------------------------------# 234 | num_workers = 4 235 | 236 | #------------------------------------------------------# 237 | # 设置用到的显卡 238 | #------------------------------------------------------# 239 | ngpus_per_node = torch.cuda.device_count() 240 | if distributed: 241 | dist.init_process_group(backend="nccl") 242 | local_rank = int(os.environ["LOCAL_RANK"]) 243 | rank = int(os.environ["RANK"]) 244 | device = torch.device("cuda", local_rank) 245 | if local_rank == 0: 246 | print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") 247 | print("Gpu Device Count : ", ngpus_per_node) 248 | else: 249 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 250 | local_rank = 0 251 | 252 | #----------------------------------------------------# 253 | # 下载预训练权重 254 | #----------------------------------------------------# 255 | if pretrained: 256 | if distributed: 257 | if local_rank == 0: 258 | download_weights(backbone) 259 | dist.barrier() 260 | else: 261 | download_weights(backbone) 262 | 263 | model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train() 264 | if not pretrained: 265 | weights_init(model) 266 | if model_path != '': 267 | #------------------------------------------------------# 268 | # 权值文件请看README,百度网盘下载 269 | #------------------------------------------------------# 270 | if local_rank == 0: 271 | print('Load weights {}.'.format(model_path)) 272 | 273 | #------------------------------------------------------# 274 | # 根据预训练权重的Key和模型的Key进行加载 275 | #------------------------------------------------------# 276 | model_dict = model.state_dict() 277 | pretrained_dict = torch.load(model_path, map_location = device) 278 | load_key, no_load_key, temp_dict = [], [], {} 279 | for k, v in pretrained_dict.items(): 280 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): 281 | temp_dict[k] = v 282 | load_key.append(k) 283 | else: 284 | no_load_key.append(k) 285 | model_dict.update(temp_dict) 286 | model.load_state_dict(model_dict) 287 | #------------------------------------------------------# 288 | # 显示没有匹配上的Key 289 | #------------------------------------------------------# 290 | if local_rank == 0: 291 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) 292 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) 293 | print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m") 294 | 295 | #----------------------# 296 | # 记录Loss 297 | #----------------------# 298 | if local_rank == 0: 299 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') 300 | log_dir = os.path.join(save_dir, "loss_" + str(time_str)) 301 | loss_history = LossHistory(log_dir, model, input_shape=input_shape, val_loss_flag = False) 302 | else: 303 | loss_history = None 304 | 305 | #------------------------------------------------------------------# 306 | # torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16 307 | # 因此torch1.2这里显示"could not be resolve" 308 | #------------------------------------------------------------------# 309 | if fp16: 310 | from torch.cuda.amp import GradScaler as GradScaler 311 | scaler = GradScaler() 312 | else: 313 | scaler = None 314 | 315 | model_train = model.train() 316 | #----------------------------# 317 | # 多卡同步Bn 318 | #----------------------------# 319 | if sync_bn and ngpus_per_node > 1 and distributed: 320 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) 321 | elif sync_bn: 322 | print("Sync_bn is not support in one gpu or not distributed.") 323 | 324 | if Cuda: 325 | if distributed: 326 | #----------------------------# 327 | # 多卡平行运行 328 | #----------------------------# 329 | model_train = model_train.cuda(local_rank) 330 | model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True) 331 | else: 332 | model_train = torch.nn.DataParallel(model) 333 | cudnn.benchmark = True 334 | model_train = model_train.cuda() 335 | 336 | #---------------------------# 337 | # 读取数据集对应的txt 338 | #---------------------------# 339 | with open(os.path.join(VOCdevkit_path, "save/train.txt"),"r") as f: 340 | train_lines = f.readlines() 341 | num_train = len(train_lines) 342 | 343 | if local_rank == 0: 344 | show_config( 345 | num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \ 346 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \ 347 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \ 348 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train 349 | ) 350 | #------------------------------------------------------# 351 | # 主干特征提取网络特征通用,冻结训练可以加快训练速度 352 | # 也可以在训练初期防止权值被破坏。 353 | # Init_Epoch为起始世代 354 | # Interval_Epoch为冻结训练的世代 355 | # Epoch总训练世代 356 | # 提示OOM或者显存不足请调小Batch_size 357 | #------------------------------------------------------# 358 | if True: 359 | UnFreeze_flag = False 360 | #------------------------------------# 361 | # 冻结一定部分训练 362 | #------------------------------------# 363 | if Freeze_Train: 364 | model.freeze_backbone() 365 | 366 | #-------------------------------------------------------------------# 367 | # 如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size 368 | #-------------------------------------------------------------------# 369 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size 370 | 371 | #-------------------------------------------------------------------# 372 | # 判断当前batch_size,自适应调整学习率 373 | #-------------------------------------------------------------------# 374 | nbs = 16 375 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 376 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 377 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 378 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 379 | 380 | #---------------------------------------# 381 | # 根据optimizer_type选择优化器 382 | #---------------------------------------# 383 | optimizer = { 384 | 'adam' : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay), 385 | 'sgd' : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay) 386 | }[optimizer_type] 387 | 388 | #---------------------------------------# 389 | # 获得学习率下降的公式 390 | #---------------------------------------# 391 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 392 | 393 | #---------------------------------------# 394 | # 判断每一个世代的长度 395 | #---------------------------------------# 396 | epoch_step = num_train // batch_size 397 | 398 | if epoch_step == 0: 399 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 400 | 401 | train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path) 402 | 403 | if distributed: 404 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,) 405 | batch_size = batch_size // ngpus_per_node 406 | shuffle = False 407 | else: 408 | train_sampler = None 409 | shuffle = True 410 | 411 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 412 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler) 413 | 414 | #---------------------------------------# 415 | # 开始模型训练 416 | #---------------------------------------# 417 | for epoch in range(Init_Epoch, UnFreeze_Epoch): 418 | #---------------------------------------# 419 | # 如果模型有冻结学习部分 420 | # 则解冻,并设置参数 421 | #---------------------------------------# 422 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train: 423 | batch_size = Unfreeze_batch_size 424 | 425 | #-------------------------------------------------------------------# 426 | # 判断当前batch_size,自适应调整学习率 427 | #-------------------------------------------------------------------# 428 | nbs = 16 429 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 430 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 431 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 432 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 433 | #---------------------------------------# 434 | # 获得学习率下降的公式 435 | #---------------------------------------# 436 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 437 | 438 | model.unfreeze_backbone() 439 | 440 | epoch_step = num_train // batch_size 441 | 442 | if epoch_step == 0: 443 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 444 | 445 | if distributed: 446 | batch_size = batch_size // ngpus_per_node 447 | 448 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 449 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler) 450 | 451 | UnFreeze_flag = True 452 | 453 | if distributed: 454 | train_sampler.set_epoch(epoch) 455 | 456 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch) 457 | 458 | fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank) 459 | 460 | if distributed: 461 | dist.barrier() 462 | 463 | if local_rank == 0: 464 | loss_history.writer.close() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | 11 | from nets.unet import Unet 12 | from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init 13 | from utils.callbacks import LossHistory, EvalCallback 14 | from utils.dataloader import UnetDataset, unet_dataset_collate 15 | from utils.utils import download_weights, show_config 16 | from utils.utils_fit import fit_one_epoch 17 | 18 | ''' 19 | 训练自己的语义分割模型一定需要注意以下几点: 20 | 1、训练前仔细检查自己的格式是否满足要求,该库要求数据集格式为VOC格式,需要准备好的内容有输入图片和标签 21 | 输入图片为.jpg图片,无需固定大小,传入训练前会自动进行resize。 22 | 灰度图会自动转成RGB图片进行训练,无需自己修改。 23 | 输入图片如果后缀非jpg,需要自己批量转成jpg后再开始训练。 24 | 25 | 标签为png图片,无需固定大小,传入训练前会自动进行resize。 26 | 由于许多同学的数据集是网络上下载的,标签格式并不符合,需要再度处理。一定要注意!标签的每个像素点的值就是这个像素点所属的种类。 27 | 网上常见的数据集总共对输入图片分两类,背景的像素点值为0,目标的像素点值为255。这样的数据集可以正常运行但是预测是没有效果的! 28 | 需要改成,背景的像素点值为0,目标的像素点值为1。 29 | 如果格式有误,参考:https://github.com/bubbliiiing/segmentation-format-fix 30 | 31 | 2、损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。 32 | 损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接到对应的损失函数里面除上10000。 33 | 训练过程中的损失值会保存在logs文件夹下的loss_%Y_%m_%d_%H_%M_%S文件夹中 34 | 35 | 3、训练好的权值文件保存在logs文件夹中,每个训练世代(Epoch)包含若干训练步长(Step),每个训练步长(Step)进行一次梯度下降。 36 | 如果只是训练了几个Step是不会保存的,Epoch和Step的概念要捋清楚一下。 37 | ''' 38 | if __name__ == "__main__": 39 | #---------------------------------# 40 | # Cuda 是否使用Cuda 41 | # 没有GPU可以设置成False 42 | #---------------------------------# 43 | Cuda = True 44 | #---------------------------------------------------------------------# 45 | # distributed 用于指定是否使用单机多卡分布式运行 46 | # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 47 | # Windows系统下默认使用DP模式调用所有显卡,不支持DDP。 48 | # DP模式: 49 | # 设置 distributed = False 50 | # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python train.py 51 | # DDP模式: 52 | # 设置 distributed = True 53 | # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py 54 | #---------------------------------------------------------------------# 55 | distributed = False 56 | #---------------------------------------------------------------------# 57 | # sync_bn 是否使用sync_bn,DDP模式多卡可用 58 | #---------------------------------------------------------------------# 59 | sync_bn = False 60 | #---------------------------------------------------------------------# 61 | # fp16 是否使用混合精度训练 62 | # 可减少约一半的显存、需要pytorch1.7.1以上 63 | #---------------------------------------------------------------------# 64 | fp16 = False 65 | #-----------------------------------------------------# 66 | # num_classes 训练自己的数据集必须要修改的 67 | # 自己需要的分类个数+1,如2+1 68 | #-----------------------------------------------------# 69 | num_classes = 2 70 | #-----------------------------------------------------# 71 | # 主干网络选择 72 | # vgg 73 | # resnet50 74 | #-----------------------------------------------------# 75 | backbone = "vgg" 76 | #----------------------------------------------------------------------------------------------------------------------------# 77 | # pretrained 是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的。 78 | # 如果设置了model_path,则主干的权值无需加载,pretrained的值无意义。 79 | # 如果不设置model_path,pretrained = True,此时仅加载主干开始训练。 80 | # 如果不设置model_path,pretrained = False,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。 81 | #----------------------------------------------------------------------------------------------------------------------------# 82 | pretrained = False 83 | #----------------------------------------------------------------------------------------------------------------------------# 84 | # 权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。 85 | # 模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。 86 | # 预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好 87 | # 训练自己的数据集时提示维度不匹配正常,预测的东西都不一样了自然维度不匹配 88 | # 89 | # 如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。 90 | # 同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。 91 | # 92 | # 当model_path = ''的时候不加载整个模型的权值。 93 | # 94 | # 此处使用的是整个模型的权重,因此是在train.py进行加载的,pretrain不影响此处的权值加载。 95 | # 如果想要让模型从主干的预训练权值开始训练,则设置model_path = '',pretrain = True,此时仅加载主干。 96 | # 如果想要让模型从0开始训练,则设置model_path = '',pretrain = Fasle,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。 97 | # 98 | # 一般来讲,网络从0开始的训练效果会很差,因为权值太过随机,特征提取效果不明显,因此非常、非常、非常不建议大家从0开始训练! 99 | # 如果一定要从0开始,可以了解imagenet数据集,首先训练分类模型,获得网络的主干部分权值,分类模型的 主干部分 和该模型通用,基于此进行训练。 100 | #----------------------------------------------------------------------------------------------------------------------------# 101 | model_path = "model_data/unet_vgg_voc.pth" 102 | #-----------------------------------------------------# 103 | # input_shape 输入图片的大小,32的倍数 104 | #-----------------------------------------------------# 105 | input_shape = [512, 512] 106 | 107 | #----------------------------------------------------------------------------------------------------------------------------# 108 | # 训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。 109 | # 冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,此时仅仅进行冻结训练。 110 | # 111 | # 在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整: 112 | # (一)从整个模型的预训练权重开始训练: 113 | # Adam: 114 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(冻结) 115 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(不冻结) 116 | # SGD: 117 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(冻结) 118 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(不冻结) 119 | # 其中:UnFreeze_Epoch可以在100-300之间调整。 120 | # (二)从主干网络的预训练权重开始训练: 121 | # Adam: 122 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(冻结) 123 | # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-4,weight_decay = 0。(不冻结) 124 | # SGD: 125 | # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 120,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(冻结) 126 | # Init_Epoch = 0,UnFreeze_Epoch = 120,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 1e-4。(不冻结) 127 | # 其中:由于从主干网络的预训练权重开始训练,主干的权值不一定适合语义分割,需要更多的训练跳出局部最优解。 128 | # UnFreeze_Epoch可以在120-300之间调整。 129 | # Adam相较于SGD收敛的快一些。因此UnFreeze_Epoch理论上可以小一点,但依然推荐更多的Epoch。 130 | # (三)batch_size的设置: 131 | # 在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。 132 | # 由于resnet50中有BatchNormalization层 133 | # 当主干为resnet50的时候batch_size不可为1 134 | # 正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。 135 | #----------------------------------------------------------------------------------------------------------------------------# 136 | #------------------------------------------------------------------# 137 | # 冻结阶段训练参数 138 | # 此时模型的主干被冻结了,特征提取网络不发生改变 139 | # 占用的显存较小,仅对网络进行微调 140 | # Init_Epoch 模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置: 141 | # Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100 142 | # 会跳过冻结阶段,直接从60代开始,并调整对应的学习率。 143 | # (断点续练时使用) 144 | # Freeze_Epoch 模型冻结训练的Freeze_Epoch 145 | # (当Freeze_Train=False时失效) 146 | # Freeze_batch_size 模型冻结训练的batch_size 147 | # (当Freeze_Train=False时失效) 148 | #------------------------------------------------------------------# 149 | Init_Epoch = 0 150 | Freeze_Epoch = 100 151 | Freeze_batch_size = 2 152 | #------------------------------------------------------------------# 153 | # 解冻阶段训练参数 154 | # 此时模型的主干不被冻结了,特征提取网络会发生改变 155 | # 占用的显存较大,网络所有的参数都会发生改变 156 | # UnFreeze_Epoch 模型总共训练的epoch 157 | # Unfreeze_batch_size 模型在解冻后的batch_size 158 | #------------------------------------------------------------------# 159 | UnFreeze_Epoch = 300 160 | Unfreeze_batch_size = 2 161 | #------------------------------------------------------------------# 162 | # Freeze_Train 是否进行冻结训练 163 | # 默认先冻结主干训练后解冻训练。 164 | #------------------------------------------------------------------# 165 | Freeze_Train = True 166 | 167 | #------------------------------------------------------------------# 168 | # 其它训练参数:学习率、优化器、学习率下降有关 169 | #------------------------------------------------------------------# 170 | #------------------------------------------------------------------# 171 | # Init_lr 模型的最大学习率 172 | # 当使用Adam优化器时建议设置 Init_lr=1e-4 173 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 174 | # Min_lr 模型的最小学习率,默认为最大学习率的0.01 175 | #------------------------------------------------------------------# 176 | Init_lr = 1e-4 177 | Min_lr = Init_lr * 0.01 178 | #------------------------------------------------------------------# 179 | # optimizer_type 使用到的优化器种类,可选的有adam、sgd 180 | # 当使用Adam优化器时建议设置 Init_lr=1e-4 181 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 182 | # momentum 优化器内部使用到的momentum参数 183 | # weight_decay 权值衰减,可防止过拟合 184 | # adam会导致weight_decay错误,使用adam时建议设置为0。 185 | #------------------------------------------------------------------# 186 | optimizer_type = "adam" 187 | momentum = 0.9 188 | weight_decay = 0 189 | #------------------------------------------------------------------# 190 | # lr_decay_type 使用到的学习率下降方式,可选的有'step'、'cos' 191 | #------------------------------------------------------------------# 192 | lr_decay_type = 'cos' 193 | #------------------------------------------------------------------# 194 | # save_period 多少个epoch保存一次权值 195 | #------------------------------------------------------------------# 196 | save_period = 5 197 | #------------------------------------------------------------------# 198 | # save_dir 权值与日志文件保存的文件夹 199 | #------------------------------------------------------------------# 200 | save_dir = 'logs' 201 | #------------------------------------------------------------------# 202 | # eval_flag 是否在训练时进行评估,评估对象为验证集 203 | # eval_period 代表多少个epoch评估一次,不建议频繁的评估 204 | # 评估需要消耗较多的时间,频繁评估会导致训练非常慢 205 | # 此处获得的mAP会与get_map.py获得的会有所不同,原因有二: 206 | # (一)此处获得的mAP为验证集的mAP。 207 | # (二)此处设置评估参数较为保守,目的是加快评估速度。 208 | #------------------------------------------------------------------# 209 | eval_flag = True 210 | eval_period = 5 211 | 212 | #------------------------------# 213 | # 数据集路径 214 | #------------------------------# 215 | VOCdevkit_path = 'Nankai' 216 | #------------------------------------------------------------------# 217 | # 建议选项: 218 | # 种类少(几类)时,设置为True 219 | # 种类多(十几类)时,如果batch_size比较大(10以上),那么设置为True 220 | # 种类多(十几类)时,如果batch_size比较小(10以下),那么设置为False 221 | #------------------------------------------------------------------# 222 | dice_loss = False 223 | #------------------------------------------------------------------# 224 | # 是否使用focal loss来防止正负样本不平衡 225 | #------------------------------------------------------------------# 226 | focal_loss = False 227 | #------------------------------------------------------------------# 228 | # 是否给不同种类赋予不同的损失权值,默认是平衡的。 229 | # 设置的话,注意设置成numpy形式的,长度和num_classes一样。 230 | # 如: 231 | # num_classes = 3 232 | # cls_weights = np.array([1, 2, 3], np.float32) 233 | #------------------------------------------------------------------# 234 | cls_weights = np.ones([num_classes], np.float32) 235 | #------------------------------------------------------------------# 236 | # num_workers 用于设置是否使用多线程读取数据,1代表关闭多线程 237 | # 开启后会加快数据读取速度,但是会占用更多内存 238 | # keras里开启多线程有些时候速度反而慢了许多 239 | # 在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。 240 | #------------------------------------------------------------------# 241 | num_workers = 4 242 | 243 | #------------------------------------------------------# 244 | # 设置用到的显卡 245 | #------------------------------------------------------# 246 | ngpus_per_node = torch.cuda.device_count() 247 | if distributed: 248 | dist.init_process_group(backend="nccl") 249 | local_rank = int(os.environ["LOCAL_RANK"]) 250 | rank = int(os.environ["RANK"]) 251 | device = torch.device("cuda", local_rank) 252 | if local_rank == 0: 253 | print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") 254 | print("Gpu Device Count : ", ngpus_per_node) 255 | else: 256 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 257 | local_rank = 0 258 | 259 | #----------------------------------------------------# 260 | # 下载预训练权重 261 | #----------------------------------------------------# 262 | if pretrained: 263 | if distributed: 264 | if local_rank == 0: 265 | download_weights(backbone) 266 | dist.barrier() 267 | else: 268 | download_weights(backbone) 269 | 270 | model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train() 271 | if not pretrained: 272 | weights_init(model) 273 | if model_path != '': 274 | #------------------------------------------------------# 275 | # 权值文件请看README,百度网盘下载 276 | #------------------------------------------------------# 277 | if local_rank == 0: 278 | print('Load weights {}.'.format(model_path)) 279 | 280 | #------------------------------------------------------# 281 | # 根据预训练权重的Key和模型的Key进行加载 282 | #------------------------------------------------------# 283 | model_dict = model.state_dict() 284 | pretrained_dict = torch.load(model_path, map_location = device) 285 | load_key, no_load_key, temp_dict = [], [], {} 286 | for k, v in pretrained_dict.items(): 287 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): 288 | temp_dict[k] = v 289 | load_key.append(k) 290 | else: 291 | no_load_key.append(k) 292 | model_dict.update(temp_dict) 293 | model.load_state_dict(model_dict) 294 | #------------------------------------------------------# 295 | # 显示没有匹配上的Key 296 | #------------------------------------------------------# 297 | if local_rank == 0: 298 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) 299 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) 300 | print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m") 301 | 302 | #----------------------# 303 | # 记录Loss 304 | #----------------------# 305 | if local_rank == 0: 306 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') 307 | log_dir = os.path.join(save_dir, "loss_" + str(time_str)) 308 | loss_history = LossHistory(log_dir, model, input_shape=input_shape) 309 | else: 310 | loss_history = None 311 | 312 | #------------------------------------------------------------------# 313 | # torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16 314 | # 因此torch1.2这里显示"could not be resolve" 315 | #------------------------------------------------------------------# 316 | if fp16: 317 | from torch.cuda.amp import GradScaler as GradScaler 318 | scaler = GradScaler() 319 | else: 320 | scaler = None 321 | 322 | model_train = model.train() 323 | #----------------------------# 324 | # 多卡同步Bn 325 | #----------------------------# 326 | if sync_bn and ngpus_per_node > 1 and distributed: 327 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) 328 | elif sync_bn: 329 | print("Sync_bn is not support in one gpu or not distributed.") 330 | 331 | if Cuda: 332 | if distributed: 333 | #----------------------------# 334 | # 多卡平行运行 335 | #----------------------------# 336 | model_train = model_train.cuda(local_rank) 337 | model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True) 338 | else: 339 | model_train = torch.nn.DataParallel(model) 340 | cudnn.benchmark = True 341 | model_train = model_train.cuda() 342 | 343 | #---------------------------# 344 | # 读取数据集对应的txt 345 | #---------------------------# 346 | with open(os.path.join(VOCdevkit_path, "save/train.txt"),"r") as f: 347 | train_lines = f.readlines() 348 | with open(os.path.join(VOCdevkit_path, "save/val.txt"),"r") as f: 349 | val_lines = f.readlines() 350 | num_train = len(train_lines) 351 | num_val = len(val_lines) 352 | 353 | if local_rank == 0: 354 | show_config( 355 | num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \ 356 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \ 357 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \ 358 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val 359 | ) 360 | #------------------------------------------------------# 361 | # 主干特征提取网络特征通用,冻结训练可以加快训练速度 362 | # 也可以在训练初期防止权值被破坏。 363 | # Init_Epoch为起始世代 364 | # Interval_Epoch为冻结训练的世代 365 | # Epoch总训练世代 366 | # 提示OOM或者显存不足请调小Batch_size 367 | #------------------------------------------------------# 368 | if True: 369 | UnFreeze_flag = False 370 | #------------------------------------# 371 | # 冻结一定部分训练 372 | #------------------------------------# 373 | if Freeze_Train: 374 | model.freeze_backbone() 375 | 376 | #-------------------------------------------------------------------# 377 | # 如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size 378 | #-------------------------------------------------------------------# 379 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size 380 | 381 | #-------------------------------------------------------------------# 382 | # 判断当前batch_size,自适应调整学习率 383 | #-------------------------------------------------------------------# 384 | nbs = 16 385 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 386 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 387 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 388 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 389 | 390 | #---------------------------------------# 391 | # 根据optimizer_type选择优化器 392 | #---------------------------------------# 393 | optimizer = { 394 | 'adam' : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay), 395 | 'sgd' : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay) 396 | }[optimizer_type] 397 | 398 | #---------------------------------------# 399 | # 获得学习率下降的公式 400 | #---------------------------------------# 401 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 402 | 403 | #---------------------------------------# 404 | # 判断每一个世代的长度 405 | #---------------------------------------# 406 | epoch_step = num_train // batch_size 407 | epoch_step_val = num_val // batch_size 408 | 409 | if epoch_step == 0 or epoch_step_val == 0: 410 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 411 | 412 | train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path) 413 | val_dataset = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path) 414 | 415 | if distributed: 416 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,) 417 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,) 418 | batch_size = batch_size // ngpus_per_node 419 | shuffle = False 420 | else: 421 | train_sampler = None 422 | val_sampler = None 423 | shuffle = True 424 | 425 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 426 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler) 427 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 428 | drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler) 429 | 430 | #----------------------# 431 | # 记录eval的map曲线 432 | #----------------------# 433 | if local_rank == 0: 434 | eval_callback = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \ 435 | eval_flag=eval_flag, period=eval_period) 436 | else: 437 | eval_callback = None 438 | 439 | #---------------------------------------# 440 | # 开始模型训练 441 | #---------------------------------------# 442 | for epoch in range(Init_Epoch, UnFreeze_Epoch): 443 | #---------------------------------------# 444 | # 如果模型有冻结学习部分 445 | # 则解冻,并设置参数 446 | #---------------------------------------# 447 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train: 448 | batch_size = Unfreeze_batch_size 449 | 450 | #-------------------------------------------------------------------# 451 | # 判断当前batch_size,自适应调整学习率 452 | #-------------------------------------------------------------------# 453 | nbs = 16 454 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 455 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 456 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 457 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 458 | #---------------------------------------# 459 | # 获得学习率下降的公式 460 | #---------------------------------------# 461 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) 462 | 463 | model.unfreeze_backbone() 464 | 465 | epoch_step = num_train // batch_size 466 | epoch_step_val = num_val // batch_size 467 | 468 | if epoch_step == 0 or epoch_step_val == 0: 469 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 470 | 471 | if distributed: 472 | batch_size = batch_size // ngpus_per_node 473 | 474 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 475 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler) 476 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 477 | drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler) 478 | 479 | UnFreeze_flag = True 480 | 481 | if distributed: 482 | train_sampler.set_epoch(epoch) 483 | 484 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch) 485 | 486 | fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, 487 | epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank) 488 | 489 | if distributed: 490 | dist.barrier() 491 | 492 | if local_rank == 0: 493 | loss_history.writer.close() 494 | --------------------------------------------------------------------------------