├── nets
│   ├── __init__.py
│   ├── vgg.py
│   ├── unet.py
│   ├── unet_training.py
│   └── resnet.py
├── utils
│   ├── __init__.py
│   ├── ceshi.py
│   ├── see.py
│   ├── utils.py
│   ├── tile.py
│   ├── fix.py
│   ├── change.py
│   ├── dataloader.py
│   ├── dataloader_medical.py
│   ├── callbacks.py
│   ├── utils_metrics.py
│   └── utils_fit.py
├── requirements.txt
├── .idea
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── UNet_Demo.iml
│   └── workspace.xml
├── README.md
├── LICENSE
├── voc_annotation_medical.py
├── .gitignore
├── get_miou.py
├── voc_annotation.py
├── predictLabel.py
├── predict.py
├── unet.py
├── train_medical.py
└── train.py
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/utils/ceshi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | print(torch.__version__)
3 | print(torch.cuda.is_available())
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.2.1
2 | numpy==1.17.0
3 | matplotlib==3.1.2
4 | opencv_python==4.1.2.30
5 | torch
6 | torchvision
7 | tqdm==4.60.0
8 | Pillow==8.2.0
9 | h5py==2.10.0
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diabetic Retinopathy Lesion Segmentation and Classification
2 | # 1. Dataset Preparation
3 | ## 1.1 Dataset Download
4 | First, prepare a dataset for diabetic-lesion segmentation. Here we use the DDR dataset from Nankai University to train the model; it covers three tasks: classification, segmentation, and object detection.
5 | Download link: [OIA-DDR dataset](https://news.nankai.edu.cn/ywsd/system/2019/12/07/030036739.shtml)
6 | ## 1.2 Dataset Preprocessing
7 | Preprocessing consists of two parts.
8 | First, for the segmentation set, merge the four lesion label maps into a single image for the later multi-label segmentation, by running the `merge()` routine in `utils/change.py` (a sketch is shown at the end of this README).
9 | Then split the classification set into its four categories so the dataset can be read.
10 | # 2. Model Training
11 | ## 2.1 Model Selection
12 | Given the small sample size of fundus lesions, we choose U-Net as the network architecture for lesion segmentation.
13 | The feature-extraction backbone is ResNet-50; the pretrained weights can be downloaded from Baidu Netdisk.
14 | Link: https://pan.baidu.com/s/1A22fC5cPRb74gqrpq7O9-A
15 | Extraction code: 6n2c
16 | ## 2.2 Model Training
17 | To train, run train.py; set the corresponding hyperparameters inside train.py (see that file for details).
18 | ## 2.3 Model Prediction
19 | Run get_miou.py to obtain the segmentation results, generate the corresponding metrics, and save the output.
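20 |
21 | A minimal sketch of the merge step, mirroring `mergeData()` in `utils/change.py`; it assumes four binary 0/1 masks with identical file names under the `EX_01`/`HE_01`/`SE_01`/`MA_01` folders, and the file name below is a placeholder:
22 | ```python
23 | import numpy as np
24 | from PIL import Image
25 |
26 | # Class codes: background=0, EX=1, HE=2, SE=3, MA=4
27 | name = "007-1774-100.png"
28 | ex = np.array(Image.open("EX_01/" + name))
29 | he = np.array(Image.open("HE_01/" + name))
30 | se = np.array(Image.open("SE_01/" + name))
31 | ma = np.array(Image.open("MA_01/" + name))
32 | merged = ex + he * 2 + se * 3 + ma * 4
33 | merged[merged > 4] = 4  # clip overlapping lesions to class 4
34 | Image.fromarray(merged.astype(np.uint8)).save("num_label/" + name)
35 | ```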
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Bubbliiiing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/see.py:
--------------------------------------------------------------------------------
1 |
2 | # showPixelValue.py
3 |
4 | import cv2
5 |
6 | def getMousePos(imgName):
7 |     def onmouse(event, x, y, flags, param):
8 |         cv2.imshow("img", img)
9 |         # if event == cv2.EVENT_MOUSEMOVE:
10 |         #     print(img[y, x], " pos: ", x, " x ", y)
11 |
12 |         # Double-click the middle mouse button to print the position and pixel value
13 |         if event == cv2.EVENT_MBUTTONDBLCLK:
14 |             print("(%s,%s)" % (x, y))
15 |             print(img[y, x])
16 |
17 |     cv2.namedWindow("img", cv2.WINDOW_NORMAL)
18 |     img = cv2.imread(imgName)
19 |     print(img[img > 4])
20 |     print(img.shape)
21 |
22 |     cv2.setMouseCallback("img", onmouse)
23 |
24 |     if cv2.waitKey() & 0xFF == 27:  # press Esc to exit
25 |         cv2.destroyAllWindows()
26 |
27 |
28 | def showPixelValue(imgName):
29 |     img = cv2.imread(imgName)
30 |     def onmouse(event, x, y, flags, param):
31 |         if event == cv2.EVENT_MOUSEMOVE:
32 |             print(img[y, x])
33 |
34 |     cv2.namedWindow("img")
35 |     cv2.setMouseCallback("img", onmouse)
36 |     cv2.imshow("img", img)
37 |     if cv2.waitKey() == ord('q'):  # press 'q' to exit
38 |         cv2.destroyAllWindows()
39 |
40 | if __name__ == '__main__':
41 |     arg1 = 'E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\num_label\\007-6361-400.png'
42 |     # E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\EX_01\\007-1774-100.png
43 |     # E:\\setup\\UNet_Demo\\eyedetect\\miou_out\\HE_res\\007-6361-400.png
44 |     getMousePos(arg1)
45 |
--------------------------------------------------------------------------------
/voc_annotation_medical.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | #----------------------------------------------------------------------#
5 | #   The medical-dataset example has no validation set
6 | #----------------------------------------------------------------------#
7 | trainval_percent = 1
8 | train_percent = 1
9 | #-------------------------------------------------------#
10 | #   Points to the folder containing the medical dataset
11 | #   Defaults to Medical_Datasets in the root directory
12 | #-------------------------------------------------------#
13 | VOCdevkit_path = 'Nankai'
14 |
15 | if __name__ == "__main__":
16 |     random.seed(0)
17 |     print("Generate txt in ImageSets.")
18 |     segfilepath = os.path.join(VOCdevkit_path, 'train/label/EX')
19 |     saveBasePath = os.path.join(VOCdevkit_path, 'save')
20 |     os.makedirs(saveBasePath, exist_ok=True)
21 |
22 |     temp_seg = os.listdir(segfilepath)
23 |     total_seg = []
24 |     for seg in temp_seg:
25 |         if seg.endswith(".jpg"):
26 |             total_seg.append(seg)
27 |
28 |     num = len(total_seg)
29 |     indices = range(num)
30 |     tv = int(num * trainval_percent)
31 |     tr = int(tv * train_percent)
32 |     trainval = random.sample(indices, tv)
33 |     train = random.sample(trainval, tr)
34 |
35 |     print("train and val size", tv)
36 |     print("train size", tr)
37 |     ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
38 |     ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
39 |     ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
40 |     fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
41 |
42 |     for i in indices:
43 |         name = total_seg[i][:-4] + '\n'
44 |         if i in trainval:
45 |             ftrainval.write(name)
46 |             if i in train:
47 |                 ftrain.write(name)
48 |             else:
49 |                 fval.write(name)
50 |         else:
51 |             ftest.write(name)
52 |
53 |     ftrainval.close()
54 |     ftrain.close()
55 |     fval.close()
56 |     ftest.close()
57 |     print("Generate txt in ImageSets done.")
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | #---------------------------------------------------------#
5 | #   Convert the image to RGB so grayscale images do not
6 | #   error out at prediction time. The code only supports
7 | #   RGB prediction; all other image types are converted.
8 | #---------------------------------------------------------#
9 | def cvtColor(image):
10 |     if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
11 |         return image
12 |     else:
13 |         image = image.convert('RGB')
14 |         return image
15 |
16 | #---------------------------------------------------#
17 | #   Resize the input image (letterbox with gray bars)
18 | #---------------------------------------------------#
19 | def resize_image(image, size):
20 |     iw, ih = image.size
21 |     w, h = size
22 |
23 |     scale = min(w/iw, h/ih)
24 |     nw = int(iw*scale)
25 |     nh = int(ih*scale)
26 |
27 |     image = image.resize((nw, nh), Image.BICUBIC)
28 |     new_image = Image.new('RGB', size, (128, 128, 128))
29 |     new_image.paste(image, ((w-nw)//2, (h-nh)//2))
30 |
31 |     return new_image, nw, nh
32 |
33 | def resize_image1(image, size):  # no-op variant: keeps the original size
34 |     iw, ih = image.size
35 |     return image, iw, ih
36 |
37 | #---------------------------------------------------#
38 | #   Get the current learning rate
39 | #---------------------------------------------------#
40 | def get_lr(optimizer):
41 |     for param_group in optimizer.param_groups:
42 |         return param_group['lr']
43 |
44 | def preprocess_input(image):
45 |     image /= 255.0
46 |     return image
47 |
48 | def show_config(**kwargs):
49 |     print('Configurations:')
50 |     print('-' * 70)
51 |     print('|%25s | %40s|' % ('keys', 'values'))
52 |     print('-' * 70)
53 |     for key, value in kwargs.items():
54 |         print('|%25s | %40s|' % (str(key), str(value)))
55 |     print('-' * 70)
56 |
57 | def download_weights(backbone, model_dir="./model_data"):
58 |     import os
59 |     from torch.hub import load_state_dict_from_url
60 |
61 |     download_urls = {
62 |         'vgg'      : 'https://download.pytorch.org/models/vgg16-397923af.pth',
63 |         'resnet50' : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth'
64 |     }
65 |     url = download_urls[backbone]
66 |
67 |     if not os.path.exists(model_dir):
68 |         os.makedirs(model_dir)
69 |     load_state_dict_from_url(url, model_dir)
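70 |
71 | if __name__ == "__main__":
72 |     # Usage sketch (added; not part of the original file). Letterboxes an image
73 |     # into 512x512 with gray padding; "example.jpg" is a placeholder file name.
74 |     img = Image.open("example.jpg")
75 |     boxed, nw, nh = resize_image(img, (512, 512))
76 |     print(boxed.size, (nw, nh))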
--------------------------------------------------------------------------------
/utils/tile.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 |
6 | # Input and output directories
7 | input = r'Nankai/train/image'
8 | input_label = r'Nankai/train/label/num_label'
9 | out = r'Nankai/train/image_tile/'
10 | output_label = r'Nankai/train/label/tile_label/'
11 | if not os.path.exists(out):
12 |     os.makedirs(out)
13 | if not os.path.exists(output_label):
14 |     os.makedirs(output_label)
15 |
16 | # Split an image into fixed-size blocks, mirror-padding the borders as needed
17 | def tile(img, outdir, filename, output_label):
18 |     height, width, channels = img.shape
19 |     # Size of each block
20 |     block_size = 512
21 |
22 |     # How many blocks fit along the width and the height
23 |     num_blocks_wide = int(np.ceil(width / block_size))
24 |     num_blocks_high = int(np.ceil(height / block_size))
25 |     # Number of pixels to pad along each dimension
26 |     pad_width = ((0, num_blocks_high * block_size - height),
27 |                  (0, num_blocks_wide * block_size - width),
28 |                  (0, 0))
29 |     # Mirror-pad the image
30 |     padded_img = np.pad(img, pad_width, mode='reflect')
31 |     # Cut out the blocks and save them to disk
32 |     for i in range(num_blocks_high):
33 |         for j in range(num_blocks_wide):
34 |             # Position of the current block inside the padded image
35 |             start_x = j * block_size
36 |             start_y = i * block_size
37 |
38 |             # Extract the current block
39 |             block = padded_img[start_y:start_y+block_size, start_x:start_x+block_size, :]
40 |             # Skip blocks that are almost entirely black
41 |             if len(block[block != 0]) < 100 * 100:
42 |                 continue
43 |             # Save the current block
44 |             cv2.imwrite(outdir + f'{filename}_{i*num_blocks_wide+j}.jpg', block)
45 |             cv2.imwrite(output_label + f'{filename}_{i*num_blocks_wide+j}.png', block)
46 |
47 | win_size = 64
48 | step = 64
49 | def Spilt():
50 |     # Open the image to be split
51 |     img = Image.open('result.jpg')
52 |
53 |     # Image size
54 |     width, height = img.size
55 |
56 |     # Walk over the image and cut it into fixed-size blocks
57 |     for i in range(0, height-win_size, step):
58 |         for j in range(0, width-win_size, step):
59 |             # Crop the block at the current position
60 |             box = (j, i, j+win_size, i+win_size)
61 |             img_block = img.crop(box)
62 |
63 |             # Keep the block only if at most 10% of its pixels are black
64 |             img_data = np.array(img_block)
65 |             if np.mean(img_data == 0) <= 0.1:
66 |                 img_block.save(os.path.join('Nankai/label', f'block_{i}_{j}.jpg'))
67 |
68 | if __name__ == '__main__':
69 |     Spilt()
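70 |
71 |     # Usage sketch for tile() (added; not in the original script). The file name
72 |     # is a placeholder following the dataset naming pattern:
73 |     # img = cv2.imread(os.path.join(input, '007-1774-100.jpg'))
74 |     # tile(img, out, '007-1774-100', output_label)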
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore map, miou, datasets
2 | map_out/
3 | miou_out/
4 | miou_train/
5 | VOCdevkit/
6 | datasets/
7 | Medical_Datasets/
8 | DDR dataset/
9 | logs/
10 | model_data/
11 | .temp_miou_out/
12 | Nankai/
13 | Video/
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 | *.pth
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 | db.sqlite3-journal
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109 | __pypackages__/
110 |
111 | # Celery stuff
112 | celerybeat-schedule
113 | celerybeat.pid
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 | env/
122 | venv/
123 | ENV/
124 | env.bak/
125 | venv.bak/
126 |
127 | # Spyder project settings
128 | .spyderproject
129 | .spyproject
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
--------------------------------------------------------------------------------
/get_miou.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 | from tqdm import tqdm
5 |
6 | from unet import Unet
7 | from utils.utils_metrics import compute_mIoU, show_results
8 |
9 | '''
10 | Notes on metric evaluation:
11 | 1. The images this file generates are grayscale. Their values are small, so they look almost entirely black when viewed as JPGs; this is normal.
12 | 2. This file computes the mIoU of the validation set. The repo currently uses the test set as the validation set and does not split off a separate test set.
13 | 3. Only models trained on VOC-format data can use this file to compute mIoU.
14 | '''
15 | if __name__ == "__main__":
16 | #---------------------------------------------------------------------------#
17 |     #   miou_mode specifies what this file computes when it runs
18 |     #   miou_mode = 0: the full pipeline, i.e. produce predictions, then compute mIoU.
19 |     #   miou_mode = 1: only produce the prediction results.
20 |     #   miou_mode = 2: only compute mIoU.
21 | #---------------------------------------------------------------------------#
22 | miou_mode = 2
23 | #------------------------------#
24 |     #   Number of classes + 1, e.g. 2 + 1
25 | #------------------------------#
26 | num_classes = 5
27 | #--------------------------------------------#
28 |     #   Classes to distinguish, same as in json_to_dataset
29 | #--------------------------------------------#
30 | # name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
31 | name_classes = ["background","EX", "HE", "MA", "SE"]
32 | #-------------------------------------------------------#
33 |     #   Points to the folder containing the VOC dataset
34 |     #   Defaults to the VOC dataset in the root directory
35 | #-------------------------------------------------------#
36 | VOCdevkit_path = 'NanKai'
37 |
38 | image_ids = open(os.path.join(VOCdevkit_path, "save/trainval.txt"),'r').read().splitlines()
39 | gt_dir = os.path.join(VOCdevkit_path, "train/label/num_label/")
40 | miou_out_path = "miou_train"
41 | pred_dir = os.path.join(miou_out_path, 'num_res')
42 |
43 | if miou_mode == 0 or miou_mode == 1:
44 | if not os.path.exists(pred_dir):
45 | os.makedirs(pred_dir)
46 | print("Load model.")
47 | unet = Unet()
48 | print("Load model done.")
49 |
50 | print("Get predict result.")
51 | for image_id in tqdm(image_ids):
52 | image_path = os.path.join(VOCdevkit_path, "train/image/"+image_id+".jpg")
53 | image = Image.open(image_path)
54 | image = unet.get_miou_png(image)
55 | image.save(os.path.join(pred_dir, image_id + ".png"))
56 | print("Get predict result done.")
57 |
58 | if miou_mode == 0 or miou_mode == 2:
59 | print("Get miou.")
60 |         hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes)  # run the mIoU computation
61 | print("Get miou done.")
62 | show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
--------------------------------------------------------------------------------
/nets/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.hub import load_state_dict_from_url
3 |
4 |
5 | class VGG(nn.Module):
6 | def __init__(self, features, num_classes=1000):
7 | super(VGG, self).__init__()
8 | self.features = features
9 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
10 | self.classifier = nn.Sequential(
11 | nn.Linear(512 * 7 * 7, 4096),
12 | nn.ReLU(True),
13 | nn.Dropout(),
14 | nn.Linear(4096, 4096),
15 | nn.ReLU(True),
16 | nn.Dropout(),
17 | nn.Linear(4096, num_classes),
18 | )
19 | self._initialize_weights()
20 |
21 | def forward(self, x):
22 | # x = self.features(x)
23 | # x = self.avgpool(x)
24 | # x = torch.flatten(x, 1)
25 | # x = self.classifier(x)
26 | feat1 = self.features[ :4 ](x)
27 | feat2 = self.features[4 :9 ](feat1)
28 | feat3 = self.features[9 :16](feat2)
29 | feat4 = self.features[16:23](feat3)
30 | feat5 = self.features[23:-1](feat4)
31 | return [feat1, feat2, feat3, feat4, feat5]
32 |
33 | def _initialize_weights(self):
34 | for m in self.modules():
35 | if isinstance(m, nn.Conv2d):
36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
37 | if m.bias is not None:
38 | nn.init.constant_(m.bias, 0)
39 | elif isinstance(m, nn.BatchNorm2d):
40 | nn.init.constant_(m.weight, 1)
41 | nn.init.constant_(m.bias, 0)
42 | elif isinstance(m, nn.Linear):
43 | nn.init.normal_(m.weight, 0, 0.01)
44 | nn.init.constant_(m.bias, 0)
45 |
46 |
47 | def make_layers(cfg, batch_norm=False, in_channels = 3):
48 | layers = []
49 | for v in cfg:
50 | if v == 'M':
51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 | else:
53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 | if batch_norm:
55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
56 | else:
57 | layers += [conv2d, nn.ReLU(inplace=True)]
58 | in_channels = v
59 | return nn.Sequential(*layers)
60 | # 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 128,128,256 -> 64,64,256
61 | # 64,64,512 -> 32,32,512 -> 32,32,512
62 | cfgs = {
63 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
64 | }
65 |
66 |
67 | def VGG16(pretrained, in_channels = 3, **kwargs):
68 | model = VGG(make_layers(cfgs["D"], batch_norm = False, in_channels = in_channels), **kwargs)
69 | if pretrained:
70 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir="./model_data")
71 | model.load_state_dict(state_dict)
72 |
73 | del model.avgpool
74 | del model.classifier
75 | return model
76 |
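77 | if __name__ == "__main__":
78 |     # Quick shape check (added sketch; not part of the original file): the five
79 |     # feature maps for a 512x512 input should match the size comment above.
80 |     import torch
81 |     net = VGG16(pretrained=False)
82 |     feats = net(torch.randn(1, 3, 512, 512))
83 |     for f in feats:
84 |         print(f.shape)  # channels 64, 128, 256, 512, 512 at strides 1, 2, 4, 8, 16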
--------------------------------------------------------------------------------
/utils/fix.py:
--------------------------------------------------------------------------------
1 | #--------------------------------------------------------#
2 | #   This file adjusts the format of the labels
3 | #--------------------------------------------------------#
4 | import os
5 |
6 | import numpy as np
7 | from PIL import Image
8 | from tqdm import tqdm
9 |
10 | #-----------------------------------------------------------------------------------#
11 | #   Origin_SegmentationClass_path   path to the original labels
12 | #   Out_SegmentationClass_path      path for the output labels
13 | #   The processed labels are grayscale; if the output values are small they are hard to see.
14 | #-----------------------------------------------------------------------------------#
15 | Origin_SegmentationClass_path = r"E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE_new"
16 | Out_SegmentationClass_path = r"E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE_01"
17 |
18 | #-----------------------------------------------------------------------------------#
19 | #   Origin_Point_Value   pixel values in the original labels
20 | #   Out_Point_Value      corresponding pixel values in the output labels
21 | #   Origin_Point_Value must correspond one-to-one with Out_Point_Value.
22 | #   For example, with:
23 | #   Origin_Point_Value = np.array([0, 255]); Out_Point_Value = np.array([0, 1])
24 | #   pixels with value 0 in the original label stay 0, and pixels with value 255 become 1.
25 | #
26 | #   The example only remaps two pixel values; more are possible, e.g.:
27 | #   Origin_Point_Value = np.array([0, 128, 255]); Out_Point_Value = np.array([0, 1, 2])
28 | #
29 | #   The entries can also be arrays (when the label values are RGB pixels), e.g.:
30 | #   Origin_Point_Value = np.array([[0, 0, 0], [1, 1, 1]]); Out_Point_Value = np.array([0, 1])
31 | #-----------------------------------------------------------------------------------#
32 | Origin_Point_Value = np.array([0, 255])
33 | Out_Point_Value = np.array([0, 1])
34 |
35 | if __name__ == "__main__":
36 |     if not os.path.exists(Out_SegmentationClass_path):
37 |         os.makedirs(Out_SegmentationClass_path)
38 |
39 |     #---------------------------#
40 |     #   Iterate over the labels and remap the values
41 |     #---------------------------#
42 |     png_names = os.listdir(Origin_SegmentationClass_path)
43 |     print("Iterating over all labels.")
44 |     for png_name in tqdm(png_names):
45 |         png = Image.open(os.path.join(Origin_SegmentationClass_path, png_name))
46 |         w, h = png.size
47 |
48 |         png = np.array(png)
49 |         out_png = np.zeros([h, w])
50 |         for i in range(len(Origin_Point_Value)):
51 |             mask = png[:, :] == Origin_Point_Value[i]
52 |             if len(np.shape(mask)) > 2:
53 |                 mask = mask.all(-1)
54 |             out_png[mask] = Out_Point_Value[i]
55 |
56 |         out_png = Image.fromarray(np.array(out_png, np.uint8))
57 |         out_png.save(os.path.join(Out_SegmentationClass_path, png_name))
58 |
59 |     #-------------------------------------#
60 |     #   Count how many pixels have each value in the output
61 |     #-------------------------------------#
62 |     print("Counting the pixel values of the output images.")
63 |     classes_nums = np.zeros([256], dtype=int)
64 |     for png_name in tqdm(png_names):
65 |         png_file_name = os.path.join(Out_SegmentationClass_path, png_name)
66 |         if not os.path.exists(png_file_name):
67 |             raise ValueError("Label image %s not found; check that the file exists at that path and has a .png suffix." % (png_file_name))
68 |
69 |         png = np.array(Image.open(png_file_name), np.uint8)
70 |         classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
71 |
72 |     print("Pixel values and their counts:")
73 |     print('-' * 37)
74 |     print("| %15s | %15s |" % ("Key", "Value"))
75 |     print('-' * 37)
76 |     for i in range(256):
77 |         if classes_nums[i] > 0:
78 |             print("| %15s | %15s |" % (str(i), str(classes_nums[i])))
79 |     print('-' * 37)
--------------------------------------------------------------------------------
/nets/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from nets.resnet import resnet50
5 | from nets.vgg import VGG16
6 |
7 |
8 | class unetUp(nn.Module):
9 | def __init__(self, in_size, out_size):
10 | super(unetUp, self).__init__()
11 | self.conv1 = nn.Conv2d(in_size, out_size, kernel_size = 3, padding = 1)
12 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size = 3, padding = 1)
13 | self.up = nn.UpsamplingBilinear2d(scale_factor = 2)
14 | self.relu = nn.ReLU(inplace = True)
15 |
16 | def forward(self, inputs1, inputs2):
17 | outputs = torch.cat([inputs1, self.up(inputs2)], 1)
18 | outputs = self.conv1(outputs)
19 | outputs = self.relu(outputs)
20 | outputs = self.conv2(outputs)
21 | outputs = self.relu(outputs)
22 | return outputs
23 |
24 | class Unet(nn.Module):
25 | def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'):
26 | super(Unet, self).__init__()
27 | if backbone == 'vgg':
28 | self.vgg = VGG16(pretrained = pretrained)
29 | in_filters = [192, 384, 768, 1024]
30 | elif backbone == "resnet50":
31 | self.resnet = resnet50(pretrained = pretrained)
32 | in_filters = [192, 512, 1024, 3072]
33 | else:
34 | raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
35 | out_filters = [64, 128, 256, 512]
36 |
37 | # upsampling
38 | # 64,64,512
39 | self.up_concat4 = unetUp(in_filters[3], out_filters[3])
40 | # 128,128,256
41 | self.up_concat3 = unetUp(in_filters[2], out_filters[2])
42 | # 256,256,128
43 | self.up_concat2 = unetUp(in_filters[1], out_filters[1])
44 | # 512,512,64
45 | self.up_concat1 = unetUp(in_filters[0], out_filters[0])
46 |
47 | if backbone == 'resnet50':
48 | self.up_conv = nn.Sequential(
49 | nn.UpsamplingBilinear2d(scale_factor = 2),
50 | nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1),
51 | nn.ReLU(),
52 | nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1),
53 | nn.ReLU(),
54 | )
55 | else:
56 | self.up_conv = None
57 |
58 | self.final = nn.Conv2d(out_filters[0], num_classes, 1)
59 |
60 | self.backbone = backbone
61 |
62 | def forward(self, inputs):
63 | if self.backbone == "vgg":
64 | [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
65 | elif self.backbone == "resnet50":
66 | [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)
67 |
68 | up4 = self.up_concat4(feat4, feat5)
69 | up3 = self.up_concat3(feat3, up4)
70 | up2 = self.up_concat2(feat2, up3)
71 | up1 = self.up_concat1(feat1, up2)
72 |
73 |         if self.up_conv is not None:
74 | up1 = self.up_conv(up1)
75 |
76 | final = self.final(up1)
77 |
78 | return final
79 |
80 | def freeze_backbone(self):
81 | if self.backbone == "vgg":
82 | for param in self.vgg.parameters():
83 | param.requires_grad = False
84 | elif self.backbone == "resnet50":
85 | for param in self.resnet.parameters():
86 | param.requires_grad = False
87 |
88 | def unfreeze_backbone(self):
89 | if self.backbone == "vgg":
90 | for param in self.vgg.parameters():
91 | param.requires_grad = True
92 | elif self.backbone == "resnet50":
93 | for param in self.resnet.parameters():
94 | param.requires_grad = True
95 |
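96 | if __name__ == "__main__":
97 |     # Smoke test (added sketch; not part of the original file). For the 5-class
98 |     # lesion task the output should be (batch, num_classes, H, W).
99 |     model = Unet(num_classes=5, pretrained=False, backbone='resnet50')
100 |     out = model(torch.randn(2, 3, 512, 512))
101 |     print(out.shape)  # torch.Size([2, 5, 512, 512])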
--------------------------------------------------------------------------------
/voc_annotation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | #-------------------------------------------------------#
9 | #   Modify trainval_percent if you want a test set
10 | #   Modify train_percent to change the validation split (9:1)
11 | #
12 | #   The repo currently uses the test set as the validation
13 | #   set and does not split off a separate test set.
14 | #-------------------------------------------------------#
15 | trainval_percent = 1
16 | train_percent = 0.9
17 | #-------------------------------------------------------#
18 | #   Points to the folder containing the VOC dataset
19 | #   Defaults to the VOC dataset in the root directory
20 | #-------------------------------------------------------#
21 | VOCdevkit_path = 'Nankai'
22 |
23 | if __name__ == "__main__":
24 |     random.seed(0)
25 |     print("Generate txt in ImageSets.")
26 |     segfilepath = os.path.join(VOCdevkit_path, 'train/label/EX_01')
27 |     saveBasePath = os.path.join(VOCdevkit_path, 'save')
28 |     os.makedirs(saveBasePath, exist_ok=True)
29 |
30 |     temp_seg = os.listdir(segfilepath)
31 |     total_seg = []
32 |     for seg in temp_seg:
33 |         if seg.endswith(".png"):
34 |             total_seg.append(seg)
35 |
36 |     num = len(total_seg)
37 |     indices = range(num)
38 |     tv = int(num * trainval_percent)
39 |     tr = int(tv * train_percent)
40 |     trainval = random.sample(indices, tv)
41 |     train = random.sample(trainval, tr)
42 |
43 |     print("train and val size", tv)
44 |     print("train size", tr)
45 |     ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
46 |     ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
47 |     ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
48 |     fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
49 |
50 |     for i in indices:
51 |         name = total_seg[i][:-4] + '\n'
52 |         if i in trainval:
53 |             ftrainval.write(name)
54 |             if i in train:
55 |                 ftrain.write(name)
56 |             else:
57 |                 fval.write(name)
58 |         else:
59 |             ftest.write(name)
60 |
61 |     ftrainval.close()
62 |     ftrain.close()
63 |     fval.close()
64 |     ftest.close()
65 |     print("Generate txt in ImageSets done.")
66 |
67 |     print("Check datasets format, this may take a while.")
68 |     classes_nums = np.zeros([256], dtype=int)
69 |     for i in tqdm(indices):
70 |         name = total_seg[i]
71 |         png_file_name = os.path.join(segfilepath, name)
72 |         if not os.path.exists(png_file_name):
73 |             raise ValueError("Label image %s not found; check that the file exists at that path and has a .png suffix." % (png_file_name))
74 |
75 |         png = np.array(Image.open(png_file_name), np.uint8)
76 |         if len(np.shape(png)) > 2:
77 |             print("The shape of label image %s is %s; it is not a grayscale or 8-bit palette image, please check the dataset format carefully." % (name, str(np.shape(png))))
78 |             print("A label image must be grayscale or an 8-bit palette image, where each pixel value is the class of that pixel.")
79 |
80 |         classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
81 |
82 |     print("Pixel values and their counts:")
83 |     print('-' * 37)
84 |     print("| %15s | %15s |" % ("Key", "Value"))
85 |     print('-' * 37)
86 |     for i in range(256):
87 |         if classes_nums[i] > 0:
88 |             print("| %15s | %15s |" % (str(i), str(classes_nums[i])))
89 |     print('-' * 37)
90 |
91 |     if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
92 |         print("Detected that the label pixel values contain only 0 and 255; the data format is wrong.")
93 |         print("For binary segmentation, labels must use pixel value 0 for background and 1 for the target.")
94 |     elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
95 |         print("Detected that the labels contain only background pixels; the data format is wrong, please check the dataset format carefully.")
96 |
97 |     print("Images in JPEGImages should be .jpg files and images in SegmentationClass should be .png files.")
98 |     print("If the format is wrong, see:")
99 |     print("https://github.com/bubbliiiing/segmentation-format-fix")
--------------------------------------------------------------------------------
/utils/change.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import cv2
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 |
8 | def tif_to_png(image_path, save_path):
9 |     """
10 |     :param image_path: *.tif image path
11 |     :param save_path: *.png image path
12 |     :return:
13 |     """
14 |     img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
15 |     filename = image_path.split('/')[-1].split('_')[0]
16 |     filename = filename.split('.')[0]
17 |     save_path = save_path + '/' + filename + '.png'
18 |     cv2.imwrite(save_path, img)
19 |
20 | def changePixel():
21 |     path = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/label/EX_new/'
22 |     savedpath = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/label/EX_res/'
23 |     filelist = os.listdir(path)
24 |     for item in filelist:
25 |         im = Image.open(path + item)  # open the image
26 |         width = im.size[0]            # image width
27 |         height = im.size[1]           # image height
28 |         # Turn every non-black pixel into (1, 1, 1)
29 |         for x in range(width):
30 |             for y in range(height):
31 |                 r, g, b = im.getpixel((x, y))
32 |                 if r + g + b > 0:
33 |                     im.putpixel((x, y), (1, 1, 1))
34 |         im = im.convert('RGB')
35 |         im.save(savedpath + item)
36 |
37 | def gray():
38 |     path = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/image/'
39 |     savepath = r'E:/setup/UNet_Demo/UNet_Demo/NanKai/train/new_image/'
40 |     filelist = os.listdir(path)
41 |     for item in filelist:
42 |         image = cv2.imread(path + item, cv2.IMREAD_GRAYSCALE)
43 |         cv2.imwrite(savepath + item, image)
44 |
45 | def hw(strJpgFile, strSaveDir, width=512, height=512):
46 |     img_src = Image.open(strJpgFile)
47 |     img_dst = img_src.resize((width, height), Image.LANCZOS)  # good anti-aliasing while keeping edges sharp
48 |     img_dst.save(os.path.join(strSaveDir, os.path.basename(strJpgFile)))
49 |
50 | def mergeData(ex_data, ha_data, se_data, ma_data, pic_name, save_path):
51 |     # Merge four binary masks into one multi-class label:
52 |     # EX=1, HE=2, SE=3, MA=4; overlapping pixels are clipped to 4
53 |     ex = np.array(Image.open(ex_data + pic_name))
54 |     ha = np.array(Image.open(ha_data + pic_name))
55 |     se = np.array(Image.open(se_data + pic_name))
56 |     ma = np.array(Image.open(ma_data + pic_name))
57 |     ex = ex + ha*2 + se*3 + ma*4
58 |     ex[ex > 4] = 4
59 |     ex = Image.fromarray(ex)
60 |     ex.save(save_path + pic_name)
61 |
62 | def merge():
63 |     ex_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\EX_01\\"
64 |     ha_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\HE_01\\"
65 |     se_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\SE_01\\"
66 |     ma_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\MA_01\\"
67 |     save_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\num_label\\"
68 |     image_files = os.listdir(ex_path)
69 |     for item in image_files:
70 |         mergeData(ex_path, ha_path, se_path, ma_path, item, save_path)
71 |
72 | def getPix():
73 |     # Count how many pixels of each class are in the merged labels
74 |     save_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\train\\label\\num_label\\"
75 |     image_files = os.listdir(save_path)
76 |     res = [0, 0, 0, 0, 0]
77 |     for item in image_files:
78 |         img = np.array(Image.open(save_path + item))
79 |         res[0] += len(img[img == 0])
80 |         res[1] += len(img[img == 1])
81 |         res[2] += len(img[img == 2])
82 |         res[3] += len(img[img == 3])
83 |         res[4] += len(img[img == 4])
84 |     print(res)
85 |
86 | def getonePix():
87 |     # Count background (0) and foreground (255) pixels in a binary label folder
88 |     save_path = "E:\\setup\\UNet_Demo\\eyedetect\\NanKai\\test\\label\\MA\\"
89 |     image_files = os.listdir(save_path)
90 |     res = [0, 0]
91 |     for item in image_files:
92 |         img = np.array(Image.open(save_path + item))
93 |         res[0] += len(img[img == 0])
94 |         res[1] += len(img[img == 255])
95 |     print(res)
96 |
97 | if __name__ == '__main__':
98 |     root_path = r'E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE/'
99 |     save_path = r'E:/setup/UNet_Demo/eyedetect/NanKai/train/label/SE_new'
100 |     image_files = os.listdir(root_path)
101 |     getonePix()
102 |     # merge()
103 |     # for item in image_files:
104 |     #     tif_to_png(root_path + item, save_path)
105 |     # gray()
--------------------------------------------------------------------------------
/nets/unet_training.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def CE_Loss(inputs, target, cls_weights, num_classes=21):
10 | n, c, h, w = inputs.size()
11 | nt, ht, wt = target.size()
12 | if h != ht and w != wt:
13 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
14 |
15 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
16 | temp_target = target.view(-1)
17 |
18 | CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
19 | return CE_loss
20 |
21 | def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
22 | n, c, h, w = inputs.size()
23 | nt, ht, wt = target.size()
24 | if h != ht and w != wt:
25 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
26 |
27 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
28 | temp_target = target.view(-1)
29 |
30 | logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
31 | pt = torch.exp(logpt)
32 | if alpha is not None:
33 | logpt *= alpha
34 | loss = -((1 - pt) ** gamma) * logpt
35 | loss = loss.mean()
36 | return loss
37 |
38 | def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
39 | n, c, h, w = inputs.size()
40 | nt, ht, wt, ct = target.size()
41 | if h != ht and w != wt:
42 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
43 |
44 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
45 | temp_target = target.view(n, -1, ct)
46 |
47 | #--------------------------------------------#
48 |     #   Compute the Dice loss
49 | #--------------------------------------------#
50 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
51 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp
52 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
53 |
54 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
55 | dice_loss = 1 - torch.mean(score)
56 | return dice_loss
57 |
58 | def weights_init(net, init_type='normal', init_gain=0.02):
59 | def init_func(m):
60 | classname = m.__class__.__name__
61 | if hasattr(m, 'weight') and classname.find('Conv') != -1:
62 | if init_type == 'normal':
63 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
64 | elif init_type == 'xavier':
65 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
66 | elif init_type == 'kaiming':
67 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
68 | elif init_type == 'orthogonal':
69 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
70 | else:
71 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
72 | elif classname.find('BatchNorm2d') != -1:
73 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
74 | torch.nn.init.constant_(m.bias.data, 0.0)
75 | print('initialize network with %s type' % init_type)
76 | net.apply(init_func)
77 |
78 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
79 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
80 | if iters <= warmup_total_iters:
81 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
82 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
83 | elif iters >= total_iters - no_aug_iter:
84 | lr = min_lr
85 | else:
86 | lr = min_lr + 0.5 * (lr - min_lr) * (
87 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
88 | )
89 | return lr
90 |
91 | def step_lr(lr, decay_rate, step_size, iters):
92 | if step_size < 1:
93 |             raise ValueError("step_size must be above 1.")
94 | n = iters // step_size
95 | out_lr = lr * decay_rate ** n
96 | return out_lr
97 |
98 | if lr_decay_type == "cos":
99 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
100 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
101 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
102 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
103 | else:
104 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
105 | step_size = total_iters / step_num
106 | func = partial(step_lr, lr, decay_rate, step_size)
107 |
108 | return func
109 |
110 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
111 | lr = lr_scheduler_func(epoch)
112 | for param_group in optimizer.param_groups:
113 | param_group['lr'] = lr
114 |
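115 | if __name__ == "__main__":
116 |     # Usage sketch (added; not part of the original file). CE_Loss expects raw
117 |     # logits (n, c, h, w) and an integer label map (n, h, w); Dice_loss expects
118 |     # the one-hot label map (n, h, w, c + 1) produced by the dataloaders.
119 |     num_classes = 5
120 |     inputs = torch.randn(2, num_classes, 64, 64)
121 |     target = torch.randint(0, num_classes, (2, 64, 64))
122 |     one_hot = F.one_hot(target, num_classes + 1).float()
123 |     weights = torch.ones(num_classes)
124 |     print(CE_Loss(inputs, target, weights, num_classes=num_classes))
125 |     print(Dice_loss(inputs, one_hot))
126 |
127 |     # Build a cosine schedule and step the learning rate once per epoch.
128 |     model = nn.Conv2d(3, num_classes, 1)
129 |     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
130 |     scheduler = get_lr_scheduler("cos", lr=1e-4, min_lr=1e-6, total_iters=100)
131 |     for epoch in range(100):
132 |         set_optimizer_lr(optimizer, scheduler, epoch)
133 |     print(optimizer.param_groups[0]['lr'])  # ends close to min_lr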
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data.dataset import Dataset
8 |
9 | from utils.utils import cvtColor, preprocess_input
10 |
11 |
12 | class UnetDataset(Dataset):
13 | def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
14 | super(UnetDataset, self).__init__()
15 | self.annotation_lines = annotation_lines
16 | self.length = len(annotation_lines)
17 | self.input_shape = input_shape
18 | self.num_classes = num_classes
19 | self.train = train
20 | self.dataset_path = dataset_path
21 |
22 | def __len__(self):
23 | return self.length
24 |
25 | def __getitem__(self, index):
26 | annotation_line = self.annotation_lines[index]
27 | name = annotation_line.split()[0]
28 |
29 | #-------------------------------#
30 |         #   Read the image and its label from disk
31 | #-------------------------------#
32 | jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "train/image"), name + ".jpg"))
33 | png = Image.open(os.path.join(os.path.join(self.dataset_path, "train/label/EX_01"), name + ".png"))
34 | #-------------------------------#
35 |         #   Data augmentation
36 | #-------------------------------#
37 | jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
38 |
39 | jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
40 | png = np.array(png)
41 | png[png >= self.num_classes] = self.num_classes
42 | #-------------------------------------------------------#
43 |         #   Convert to one-hot form
44 |         #   The +1 is needed because some VOC labels have a white border,
45 |         #   and the extra class makes it easy to ignore that border.
46 | #-------------------------------------------------------#
47 | seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]
48 | seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
49 |
50 | return jpg, png, seg_labels
51 |
52 | def rand(self, a=0, b=1):
53 | return np.random.rand() * (b - a) + a
54 |
55 | def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
56 | image = cvtColor(image)
57 | label = Image.fromarray(np.array(label))
58 | #------------------------------#
59 |         #   Get the image size and the target size
60 | #------------------------------#
61 | iw, ih = image.size
62 | h, w = input_shape
63 |
64 | if not random:
65 | iw, ih = image.size
66 | scale = min(w/iw, h/ih)
67 | nw = int(iw*scale)
68 | nh = int(ih*scale)
69 |
70 | image = image.resize((nw,nh), Image.BICUBIC)
71 | new_image = Image.new('RGB', [w, h], (128,128,128))
72 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
73 |
74 | label = label.resize((nw,nh), Image.NEAREST)
75 | new_label = Image.new('L', [w, h], (0))
76 | new_label.paste(label, ((w-nw)//2, (h-nh)//2))
77 | return new_image, new_label
78 |
79 | #------------------------------------------#
80 |         #   Scale the image and distort its aspect ratio
81 | #------------------------------------------#
82 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
83 | scale = self.rand(0.25, 2)
84 | if new_ar < 1:
85 | nh = int(scale*h)
86 | nw = int(nh*new_ar)
87 | else:
88 | nw = int(scale*w)
89 | nh = int(nw/new_ar)
90 | image = image.resize((nw,nh), Image.BICUBIC)
91 | label = label.resize((nw,nh), Image.NEAREST)
92 |
93 | #------------------------------------------#
94 |         #   Flip the image
95 | #------------------------------------------#
96 | flip = self.rand()<.5
97 | if flip:
98 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
99 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
100 |
101 | #------------------------------------------#
102 |         #   Pad the leftover area with gray bars
103 | #------------------------------------------#
104 | dx = int(self.rand(0, w-nw))
105 | dy = int(self.rand(0, h-nh))
106 | new_image = Image.new('RGB', (w,h), (128,128,128))
107 | new_label = Image.new('L', (w,h), (0))
108 | new_image.paste(image, (dx, dy))
109 | new_label.paste(label, (dx, dy))
110 | image = new_image
111 | label = new_label
112 |
113 | image_data = np.array(image, np.uint8)
114 | #---------------------------------#
115 |         #   Color-space jitter
116 |         #   Compute the jitter parameters
117 | #---------------------------------#
118 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
119 | #---------------------------------#
120 | # 将图像转到HSV上
121 | #---------------------------------#
122 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
123 | dtype = image_data.dtype
124 | #---------------------------------#
125 |         #   Apply the transform
126 | #---------------------------------#
127 | x = np.arange(0, 256, dtype=r.dtype)
128 | lut_hue = ((x * r[0]) % 180).astype(dtype)
129 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
130 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
131 |
132 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
133 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
134 |
135 | return image_data, label
136 |
137 | # Used as collate_fn by the DataLoader
138 | def unet_dataset_collate(batch):
139 | images = []
140 | pngs = []
141 | seg_labels = []
142 | for img, png, labels in batch:
143 | images.append(img)
144 | pngs.append(png)
145 | seg_labels.append(labels)
146 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
147 | pngs = torch.from_numpy(np.array(pngs)).long()
148 | seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
149 | return images, pngs, seg_labels
150 |
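151 | if __name__ == "__main__":
152 |     # Wiring sketch (added; not part of the original file). The annotation lines
153 |     # come from the txt files written by voc_annotation.py; paths are examples.
154 |     from torch.utils.data import DataLoader
155 |     with open("Nankai/save/train.txt") as f:
156 |         lines = f.readlines()
157 |     dataset = UnetDataset(lines, input_shape=[512, 512], num_classes=2, train=True, dataset_path="Nankai")
158 |     loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=unet_dataset_collate)
159 |     images, pngs, seg_labels = next(iter(loader))
160 |     print(images.shape, pngs.shape, seg_labels.shape)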
--------------------------------------------------------------------------------
/utils/dataloader_medical.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data.dataset import Dataset
8 | import matplotlib.pyplot as plt
9 | from utils.utils import cvtColor, preprocess_input
10 |
11 |
12 | class UnetDataset(Dataset):
13 | def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
14 | super(UnetDataset, self).__init__()
15 | self.annotation_lines = annotation_lines
16 | self.length = len(annotation_lines)
17 | self.input_shape = input_shape
18 | self.num_classes = num_classes
19 | self.train = train
20 | self.dataset_path = dataset_path
21 |
22 | def __len__(self):
23 | return self.length
24 |
25 | def __getitem__(self, index):
26 | annotation_line = self.annotation_lines[index]
27 | name = annotation_line.split()[0]
28 |
29 | #-------------------------------#
30 |         #   Read the image and its label from disk
31 | #-------------------------------#
32 | jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "train/new_image"), name + ".jpg"))
33 | png = Image.open(os.path.join(os.path.join(self.dataset_path, "train/label/EX_01"), name + ".png"))
34 | #-------------------------------#
35 |         #   Data augmentation
36 | #-------------------------------#
37 | jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
38 | #plt.imshow(jpg)
39 | jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
40 | png = np.array(png)
41 | #-------------------------------------------------------#
42 |         #   Labels here are handled differently from ordinary VOC labels:
43 |         #   pixels with value 0 are marked as target pixels (class 1).
44 | #-------------------------------------------------------#
45 | modify_png = np.zeros_like(png)
46 | modify_png[png==0] = 1
47 | seg_labels = modify_png
48 | seg_labels = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
49 | seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
50 |
51 | return jpg, modify_png, seg_labels
52 |
53 | def rand(self, a=0, b=1):
54 | return np.random.rand() * (b - a) + a
55 |
56 | def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
57 | image = cvtColor(image)
58 | label = Image.fromarray(np.array(label))
59 | #------------------------------#
60 |         #   Get the image size and the target size
61 | #------------------------------#
62 | iw, ih = image.size
63 | h, w = input_shape
64 |
65 | if not random:
66 | iw, ih = image.size
67 | scale = min(w/iw, h/ih)
68 | nw = int(iw*scale)
69 | nh = int(ih*scale)
70 |
71 | image = image.resize((nw,nh), Image.BICUBIC)
72 | new_image = Image.new('RGB', [w, h], (128,128,128))
73 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
74 |
75 | label = label.resize((nw,nh), Image.NEAREST)
76 | new_label = Image.new('L', [w, h], (0))
77 | new_label.paste(label, ((w-nw)//2, (h-nh)//2))
78 | return new_image, new_label
79 |
80 | #------------------------------------------#
81 |         #   Scale the image and distort its aspect ratio
82 | #------------------------------------------#
83 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
84 | scale = self.rand(0.25, 2)
85 | if new_ar < 1:
86 | nh = int(scale*h)
87 | nw = int(nh*new_ar)
88 | else:
89 | nw = int(scale*w)
90 | nh = int(nw/new_ar)
91 | image = image.resize((nw,nh), Image.BICUBIC)
92 | label = label.resize((nw,nh), Image.NEAREST)
93 |
94 | #------------------------------------------#
95 |         #   Flip the image
96 | #------------------------------------------#
97 | flip = self.rand()<.5
98 | if flip:
99 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
100 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
101 |
102 | #------------------------------------------#
103 |         #   Pad the leftover area with gray bars
104 | #------------------------------------------#
105 | dx = int(self.rand(0, w-nw))
106 | dy = int(self.rand(0, h-nh))
107 | new_image = Image.new('RGB', (w,h), (128,128,128))
108 | new_label = Image.new('L', (w,h), (0))
109 | new_image.paste(image, (dx, dy))
110 | new_label.paste(label, (dx, dy))
111 | image = new_image
112 | label = new_label
113 |
114 | image_data = np.array(image, np.uint8)
115 | #---------------------------------#
116 |         #   Color-space jitter
117 |         #   Compute the jitter parameters
118 | #---------------------------------#
119 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
120 | #---------------------------------#
121 |         #   Convert the image to HSV
122 | #---------------------------------#
123 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
124 | dtype = image_data.dtype
125 | #---------------------------------#
126 |         #   Apply the transform
127 | #---------------------------------#
128 | x = np.arange(0, 256, dtype=r.dtype)
129 | lut_hue = ((x * r[0]) % 180).astype(dtype)
130 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
131 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
132 |
133 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
134 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
135 |
136 | return image_data, label
137 |
138 | # Used as collate_fn by the DataLoader
139 | def unet_dataset_collate(batch):
140 | images = []
141 | pngs = []
142 | seg_labels = []
143 | for img, png, labels in batch:
144 | images.append(img)
145 | pngs.append(png)
146 | seg_labels.append(labels)
147 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
148 | pngs = torch.from_numpy(np.array(pngs)).long()
149 | seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
150 | return images, pngs, seg_labels
151 |
--------------------------------------------------------------------------------
/nets/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=dilation, groups=groups, bias=False, dilation=dilation)
10 |
11 |
12 | def conv1x1(in_planes, out_planes, stride=1):
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
20 | base_width=64, dilation=1, norm_layer=None):
21 | super(BasicBlock, self).__init__()
22 | if norm_layer is None:
23 | norm_layer = nn.BatchNorm2d
24 | if groups != 1 or base_width != 64:
25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
26 | if dilation > 1:
27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
28 | self.conv1 = conv3x3(inplanes, planes, stride)
29 | self.bn1 = norm_layer(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = norm_layer(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | identity = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 |
46 | if self.downsample is not None:
47 | identity = self.downsample(x)
48 |
49 | out += identity
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
58 | base_width=64, dilation=1, norm_layer=None):
59 | super(Bottleneck, self).__init__()
60 | if norm_layer is None:
61 | norm_layer = nn.BatchNorm2d
62 | width = int(planes * (base_width / 64.)) * groups
63 |         # 1x1 convolution to reduce the number of channels
64 |         self.conv1 = conv1x1(inplanes, width)
65 |         self.bn1 = norm_layer(width)
66 |         # 3x3 convolution for feature extraction
67 |         self.conv2 = conv3x3(width, width, stride, groups, dilation)
68 |         self.bn2 = norm_layer(width)
69 |         # 1x1 convolution to expand the channels back
70 | self.conv3 = conv1x1(width, planes * self.expansion)
71 | self.bn3 = norm_layer(planes * self.expansion)
72 |
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | identity = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | identity = self.downsample(x)
93 |
94 | out += identity
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 | def __init__(self, block, layers, num_classes=1000):
102 | #-----------------------------------------------------------#
103 |         #   Assume the input image is 600,600,3
104 |         #   when resnet50 is used
105 | #-----------------------------------------------------------#
106 | self.inplanes = 64
107 | super(ResNet, self).__init__()
108 | # 600,600,3 -> 300,300,64
109 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
110 | self.bn1 = nn.BatchNorm2d(64)
111 | self.relu = nn.ReLU(inplace=True)
112 | # 300,300,64 -> 150,150,64
113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
114 | # 150,150,64 -> 150,150,256
115 | self.layer1 = self._make_layer(block, 64, layers[0])
116 | # 150,150,256 -> 75,75,512
117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
118 | # 75,75,512 -> 38,38,1024
119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
120 | # 38,38,1024 -> 19,19,2048
121 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
122 |
123 | self.avgpool = nn.AvgPool2d(7)
124 | self.fc = nn.Linear(512 * block.expansion, num_classes)
125 |
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
129 | m.weight.data.normal_(0, math.sqrt(2. / n))
130 | elif isinstance(m, nn.BatchNorm2d):
131 | m.weight.data.fill_(1)
132 | m.bias.data.zero_()
133 |
134 | def _make_layer(self, block, planes, blocks, stride=1):
135 | downsample = None
136 | if stride != 1 or self.inplanes != planes * block.expansion:
137 | downsample = nn.Sequential(
138 | nn.Conv2d(self.inplanes, planes * block.expansion,
139 | kernel_size=1, stride=stride, bias=False),
140 | nn.BatchNorm2d(planes * block.expansion),
141 | )
142 |
143 | layers = []
144 | layers.append(block(self.inplanes, planes, stride, downsample))
145 | self.inplanes = planes * block.expansion
146 | for i in range(1, blocks):
147 | layers.append(block(self.inplanes, planes))
148 |
149 | return nn.Sequential(*layers)
150 |
151 | def forward(self, x):
152 | # x = self.conv1(x)
153 | # x = self.bn1(x)
154 | # x = self.relu(x)
155 | # x = self.maxpool(x)
156 |
157 | # x = self.layer1(x)
158 | # x = self.layer2(x)
159 | # x = self.layer3(x)
160 | # x = self.layer4(x)
161 |
162 | # x = self.avgpool(x)
163 | # x = x.view(x.size(0), -1)
164 | # x = self.fc(x)
165 |
166 | x = self.conv1(x)
167 | x = self.bn1(x)
168 | feat1 = self.relu(x)
169 |
170 | x = self.maxpool(feat1)
171 | feat2 = self.layer1(x)
172 |
173 | feat3 = self.layer2(feat2)
174 | feat4 = self.layer3(feat3)
175 | feat5 = self.layer4(feat4)
176 | return [feat1, feat2, feat3, feat4, feat5]
177 |
178 | def resnet50(pretrained=False, **kwargs):
179 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
180 | if pretrained:
181 | model.load_state_dict(model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', model_dir='model_data'), strict=False)
182 |
183 | del model.avgpool
184 | del model.fc
185 | return model
186 |
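187 | if __name__ == "__main__":
188 |     # Quick shape check (added sketch; not part of the original file). For a
189 |     # 512x512 input the five feature maps should be 64@256, 256@128, 512@64,
190 |     # 1024@32 and 2048@16 (channels @ spatial size).
191 |     import torch
192 |     net = resnet50(pretrained=False)
193 |     for f in net(torch.randn(1, 3, 512, 512)):
194 |         print(f.shape)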
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | matplotlib.use('Agg')
8 | from matplotlib import pyplot as plt
9 | import scipy.signal
10 |
11 | import cv2
12 | import shutil
13 | import numpy as np
14 |
15 | from PIL import Image
16 | from tqdm import tqdm
17 | from tensorboardX import SummaryWriter
18 | from .utils import cvtColor, preprocess_input, resize_image
19 | from .utils_metrics import compute_mIoU
20 |
21 |
22 | class LossHistory():
23 | def __init__(self, log_dir, model, input_shape, val_loss_flag=True):
24 | self.log_dir = log_dir
25 | self.val_loss_flag = val_loss_flag
26 |
27 | self.losses = []
28 | if self.val_loss_flag:
29 | self.val_loss = []
30 |
31 | os.makedirs(self.log_dir, exist_ok=True)
32 | self.writer = SummaryWriter(self.log_dir)
33 | try:
34 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
35 | self.writer.add_graph(model, dummy_input)
36 | except:
37 | pass
38 |
39 | def append_loss(self, epoch, loss, val_loss = None):
40 | if not os.path.exists(self.log_dir):
41 | os.makedirs(self.log_dir)
42 | self.losses.append(loss)
43 | if self.val_loss_flag:
44 | self.val_loss.append(val_loss)
45 |
46 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
47 | f.write(str(loss))
48 | f.write("\n")
49 | if self.val_loss_flag:
50 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
51 | f.write(str(val_loss))
52 | f.write("\n")
53 |
54 | self.writer.add_scalar('loss', loss, epoch)
55 | if self.val_loss_flag:
56 | self.writer.add_scalar('val_loss', val_loss, epoch)
57 |
58 | self.loss_plot()
59 |
60 | def loss_plot(self):
61 | iters = range(len(self.losses))
62 |
63 | plt.figure()
64 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
65 | if self.val_loss_flag:
66 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
67 |
68 | try:
69 | if len(self.losses) < 25:
70 | num = 5
71 | else:
72 | num = 15
73 |
74 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
75 | if self.val_loss_flag:
76 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
77 | except:
78 | pass
79 |
80 | plt.grid(True)
81 | plt.xlabel('Epoch')
82 | plt.ylabel('Loss')
83 | plt.legend(loc="upper right")
84 |
85 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
86 |
87 | plt.cla()
88 | plt.close("all")
89 |
90 | class EvalCallback():
91 | def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
92 | miou_out_path=".temp_miou_out", eval_flag=True, period=1):
93 | super(EvalCallback, self).__init__()
94 |
95 | self.net = net
96 | self.input_shape = input_shape
97 | self.num_classes = num_classes
98 | self.image_ids = image_ids
99 | self.dataset_path = dataset_path
100 | self.log_dir = log_dir
101 | self.cuda = cuda
102 | self.miou_out_path = miou_out_path
103 | self.eval_flag = eval_flag
104 | self.period = period
105 |
106 | self.image_ids = [image_id.split()[0] for image_id in image_ids]
107 | self.mious = [0]
108 | self.epoches = [0]
109 | if self.eval_flag:
110 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
111 | f.write(str(0))
112 | f.write("\n")
113 |
114 | def get_miou_png(self, image):
115 | #---------------------------------------------------------#
116 | # Convert the image to RGB here to avoid errors when predicting on grayscale images.
117 | # The code only supports RGB prediction; all other image types are converted to RGB.
118 | #---------------------------------------------------------#
119 | image = cvtColor(image)
120 | orininal_h = np.array(image).shape[0]
121 | orininal_w = np.array(image).shape[1]
122 | #---------------------------------------------------------#
123 | # Pad the image with gray bars for a distortion-free resize.
124 | # A plain resize would also work for recognition.
125 | #---------------------------------------------------------#
126 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
127 | #---------------------------------------------------------#
128 | # Add the batch_size dimension
129 | #---------------------------------------------------------#
130 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
131 |
132 | with torch.no_grad():
133 | images = torch.from_numpy(image_data)
134 | if self.cuda:
135 | images = images.cuda()
136 |
137 | #---------------------------------------------------#
138 | # Feed the image into the network for prediction
139 | #---------------------------------------------------#
140 | pr = self.net(images)[0]
141 | #---------------------------------------------------#
142 | # Take the class of every pixel
143 | #---------------------------------------------------#
144 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
145 | #--------------------------------------#
146 | # Crop off the gray-bar padding
147 | #--------------------------------------#
148 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
149 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
150 | #---------------------------------------------------#
151 | # Resize back to the original image size
152 | #---------------------------------------------------#
153 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
154 | #---------------------------------------------------#
155 | # Take the class of every pixel
156 | #---------------------------------------------------#
157 | pr = pr.argmax(axis=-1)
158 |
159 | image = Image.fromarray(np.uint8(pr))
160 | return image
161 |
162 | def on_epoch_end(self, epoch, model_eval):
163 | if epoch % self.period == 0 and self.eval_flag:
164 | self.net = model_eval
165 | gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
166 | pred_dir = os.path.join(self.miou_out_path, 'detection-results')
167 | if not os.path.exists(self.miou_out_path):
168 | os.makedirs(self.miou_out_path)
169 | if not os.path.exists(pred_dir):
170 | os.makedirs(pred_dir)
171 | print("Get miou.")
172 | for image_id in tqdm(self.image_ids):
173 | #-------------------------------#
174 | # Read the image from file
175 | #-------------------------------#
176 | image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/"+image_id+".jpg")
177 | image = Image.open(image_path)
178 | #------------------------------#
179 | # Generate the prediction PNG
180 | #------------------------------#
181 | image = self.get_miou_png(image)
182 | image.save(os.path.join(pred_dir, image_id + ".png"))
183 |
184 | print("Calculate miou.")
185 | _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # run the mIoU computation
186 | temp_miou = np.nanmean(IoUs) * 100
187 |
188 | self.mious.append(temp_miou)
189 | self.epoches.append(epoch)
190 |
191 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
192 | f.write(str(temp_miou))
193 | f.write("\n")
194 |
195 | plt.figure()
196 | plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou')
197 |
198 | plt.grid(True)
199 | plt.xlabel('Epoch')
200 | plt.ylabel('Miou')
201 | plt.title('A Miou Curve')
202 | plt.legend(loc="upper right")
203 |
204 | plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
205 | plt.cla()
206 | plt.close("all")
207 |
208 | print("Get miou done.")
209 | shutil.rmtree(self.miou_out_path)
210 |
--------------------------------------------------------------------------------
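A minimal sketch of driving the LossHistory callback above; the log directory and the placeholder losses are hypothetical:

```python
# Hypothetical wiring for LossHistory; it creates log_dir, a TensorBoard
# writer inside it, and updates the loss logs/curves on every append_loss call.
from nets.unet import Unet
from utils.callbacks import LossHistory

model = Unet(num_classes=5, backbone="resnet50")
loss_history = LossHistory("logs/demo_run", model, input_shape=[512, 512])
for epoch in range(3):
    train_loss = 1.0 / (epoch + 1)   # placeholder for the real epoch loss
    val_loss = 1.2 / (epoch + 1)     # placeholder for the real validation loss
    loss_history.append_loss(epoch + 1, train_loss, val_loss)
# Produces epoch_loss.txt, epoch_val_loss.txt, TensorBoard scalars and epoch_loss.png
```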
/utils/utils_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from os.path import join
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from PIL import Image
10 |
11 |
12 | def f_score(inputs, target, beta=1, smooth = 1e-5, threshold = 0.5):
13 | n, c, h, w = inputs.size()
14 | nt, ht, wt, ct = target.size()
15 | if h != ht or w != wt:
16 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
17 |
18 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
19 | temp_target = target.view(n, -1, ct)
20 |
21 | #--------------------------------------------#
22 | # Compute the dice coefficient (the F-score with beta=1)
23 | #--------------------------------------------#
24 | temp_inputs = torch.gt(temp_inputs, threshold).float()
25 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
26 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp
27 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
28 |
29 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
30 | score = torch.mean(score)
31 | return score
32 |
33 | # Let the label have width W and height H
34 | def fast_hist(a, b, n):
35 | #--------------------------------------------------------------------------------#
36 | # a is the label flattened to a 1-D array, shape (H×W,); b is the prediction flattened to a 1-D array, shape (H×W,)
37 | #--------------------------------------------------------------------------------#
38 | k = (a >= 0) & (a < n)
39 | #--------------------------------------------------------------------------------#
40 | # np.bincount counts how often each of the n**2 values 0..n**2-1 occurs; reshaped, the result has shape (n, n)
41 | # Entries on the diagonal are the correctly classified pixels
42 | #--------------------------------------------------------------------------------#
43 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
44 |
45 | def per_class_iu(hist):
46 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)
47 |
48 | def per_class_PA_Recall(hist):
49 | return np.diag(hist) / np.maximum(hist.sum(1), 1)
50 |
51 | def per_class_Precision(hist):
52 | return np.diag(hist) / np.maximum(hist.sum(0), 1)
53 |
54 | def per_Accuracy(hist):
55 | return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)
56 |
57 | def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
58 | print('Num classes', num_classes)
59 | #-----------------------------------------#
60 | # Create an all-zeros confusion matrix
61 | #-----------------------------------------#
62 | hist = np.zeros((num_classes, num_classes))
63 |
64 | #------------------------------------------------#
65 | # Build the list of validation ground-truth label paths for easy reading
66 | # Build the list of validation prediction paths for easy reading
67 | #------------------------------------------------#
68 | gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
69 | pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]
70 |
71 | #------------------------------------------------#
72 | # Read each (image, label) pair
73 | #------------------------------------------------#
74 | for ind in range(len(gt_imgs)):
75 | #------------------------------------------------#
76 | # Read one segmentation result and convert it to a numpy array
77 | #------------------------------------------------#
78 | pred = np.array(Image.open(pred_imgs[ind]))
79 | #------------------------------------------------#
80 | # Read the corresponding label and convert it to a numpy array
81 | #------------------------------------------------#
82 | label = np.array(Image.open(gt_imgs[ind]))
83 |
84 | # If the segmentation result and the label differ in size, skip this image
85 | if len(label.flatten()) != len(pred.flatten()):
86 | print(
87 | 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
88 | len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
89 | pred_imgs[ind]))
90 | continue
91 |
92 | #------------------------------------------------#
93 | # Accumulate this image's num_classes×num_classes hist matrix
94 | #------------------------------------------------#
95 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
96 | # Every 10 images, print the running mIoU over all classes so far
97 | if name_classes is not None and ind > 0 and ind % 10 == 0:
98 | print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
99 | ind,
100 | len(gt_imgs),
101 | 100 * np.nanmean(per_class_iu(hist)),
102 | 100 * np.nanmean(per_class_PA_Recall(hist)),
103 | 100 * per_Accuracy(hist)
104 | )
105 | )
106 | #------------------------------------------------#
107 | # Compute the per-class IoU over the whole validation set
108 | #------------------------------------------------#
109 | IoUs = per_class_iu(hist)
110 | PA_Recall = per_class_PA_Recall(hist)
111 | Precision = per_class_Precision(hist)
112 | #------------------------------------------------#
113 | # Print the IoU for every class
114 | #------------------------------------------------#
115 | if name_classes is not None:
116 | for ind_class in range(num_classes):
117 | print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
118 | + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))
119 |
120 | #-----------------------------------------------------------------#
121 | # Average the per-class IoU over all validation images, ignoring NaN values
122 | #-----------------------------------------------------------------#
123 | print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
124 | return np.array(hist, int), IoUs, PA_Recall, Precision
125 |
126 | def adjust_axes(r, t, fig, axes):
127 | bb = t.get_window_extent(renderer=r)
128 | text_width_inches = bb.width / fig.dpi
129 | current_fig_width = fig.get_figwidth()
130 | new_fig_width = current_fig_width + text_width_inches
131 | proportion = new_fig_width / current_fig_width
132 | x_lim = axes.get_xlim()
133 | axes.set_xlim([x_lim[0], x_lim[1] * proportion])
134 |
135 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
136 | fig = plt.gcf()
137 | axes = plt.gca()
138 | print(values)
139 | plt.barh(range(len(values)), values, color='royalblue')
140 | plt.title("All data", fontsize=tick_font_size + 2)
141 | plt.xlabel(x_label, fontsize=tick_font_size)
142 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
143 | r = fig.canvas.get_renderer()
144 | for i, val in enumerate(values):
145 | str_val = " " + str(val)
146 | if val < 1.0:
147 | str_val = " {0:.2f}".format(val)
148 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
149 | if i == (len(values)-1):
150 | adjust_axes(r, t, fig, axes)
151 |
152 | fig.tight_layout()
153 | fig.savefig(output_path)
154 | if plt_show:
155 | plt.show()
156 | plt.close()
157 |
158 | def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12):
159 | print(IoUs)
160 | classes=['IoUs','PA_Recall','Precision']
161 | value=[]
162 | value.append(IoUs[1])
163 | value.append(PA_Recall[1])
164 | value.append(Precision[1])
165 | draw_plot_func(value, classes, "mIoU = {0:.3f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \
166 | os.path.join(miou_out_path, "All.png"), tick_font_size = tick_font_size, plt_show = True)
167 | draw_plot_func(IoUs, name_classes, "mIoU = {0:.3f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \
168 | os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True)
169 | print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))
170 |
171 | draw_plot_func(PA_Recall, name_classes, "mPA = {0:.3f}%".format(np.nanmean(PA_Recall)*100.0), "Pixel Accuracy", \
172 | os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False)
173 | print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))
174 |
175 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.3f}%".format(np.nanmean(PA_Recall)*100.0), "Recall", \
176 | os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False)
177 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))
178 |
179 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.3f}%".format(np.nanmean(Precision)*100.0), "Precision", \
180 | os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False)
181 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))
182 |
183 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
184 | writer = csv.writer(f)
185 | writer_list = []
186 | writer_list.append([' '] + [str(c) for c in name_classes])
187 | for i in range(len(hist)):
188 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
189 | writer.writerows(writer_list)
190 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))
191 |
--------------------------------------------------------------------------------
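To make the bincount trick inside fast_hist concrete, here is a small self-contained sketch on toy 2×2 arrays (the values are made up):

```python
# Toy sketch: confusion matrix and per-class IoU via fast_hist / per_class_iu.
import numpy as np
from utils.utils_metrics import fast_hist, per_class_iu

label = np.array([[0, 0], [1, 1]])   # ground truth with two classes
pred  = np.array([[0, 1], [1, 1]])   # prediction with one wrong pixel
hist = fast_hist(label.flatten(), pred.flatten(), 2)
print(hist)                # [[1 1], [0 2]] - rows: ground truth, columns: prediction
print(per_class_iu(hist))  # per-class IoU: [0.5, 0.6666...]
```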
/predictLabel.py:
--------------------------------------------------------------------------------
1 | #----------------------------------------------------#
2 | # Integrates single-image prediction, camera detection and
3 | # FPS testing into one py file; switch modes via `mode`.
4 | #----------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from unet import Unet
12 |
13 | def predict(file_name):
14 | #-------------------------------------------------------------------------#
15 | # To change the color assigned to each class, edit self.colors in __init__
16 | #-------------------------------------------------------------------------#
17 | unet = Unet()
18 | #----------------------------------------------------------------------------------------------------------#
19 | # mode selects the test mode:
20 | # 'predict' single-image prediction; to customize the flow (saving images, cropping objects, ...), read the detailed comments below first
21 | # 'video' video detection, from a camera or a video file; see the comments below
22 | # 'fps' fps test, using img/street.jpg; see the comments below
23 | # 'dir_predict' detect every image in a folder and save the results; defaults to reading img and saving to img_out; see the comments below
24 | # 'export_onnx' export the model to onnx; requires pytorch 1.7.1 or higher
25 | #----------------------------------------------------------------------------------------------------------#
26 | mode = "predict" ### image detection
27 | # mode = "video" ## video detection
29 | #-------------------------------------------------------------------------#
30 | # count whether to count target pixels (i.e. area) and compute their ratio
31 | # name_classes the classes to distinguish, same as in json_to_dataset; used to print class names and counts
32 | #
33 | # count and name_classes only take effect when mode='predict'
34 | #-------------------------------------------------------------------------#
35 | count = False
36 | # name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
37 | #name_classes = ["_background_", "person", "car", "motorbike", "dustbin", "chair", "fire_hydrant", "tricycle", "bicycle","stone"]
38 | name_classes = ["background","EX", "HE", "MA", "SE"]
39 | #----------------------------------------------------------------------------------------------------------#
40 | # video_path the video path; video_path=0 means use the camera
41 | # to detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the root directory
42 | # video_save_path where to save the video; video_save_path="" means do not save
43 | # to save the video, set e.g. video_save_path = "yyy.mp4" to save yyy.mp4 in the root directory
44 | # video_fps fps of the saved video
45 | #
46 | # video_path, video_save_path and video_fps only take effect when mode='video'
47 | # when saving a video, exit with ctrl+c or run to the last frame to complete the save
48 | #----------------------------------------------------------------------------------------------------------#
49 | video_path = 0
50 | video_save_path = ""
51 | video_fps = 25.0
52 | #----------------------------------------------------------------------------------------------------------#
53 | # test_interval how many detections to run when measuring fps; in theory, a larger test_interval gives a more accurate fps
54 | # fps_image_path the image used for the fps test
55 | #
56 | # test_interval and fps_image_path only take effect when mode='fps'
57 | #----------------------------------------------------------------------------------------------------------#
58 | test_interval = 100
59 | fps_image_path = "E:/setup/UNet_Demo/UNet_Demo/NanKai/train/image/007-3939-200.jpg"
60 | #-------------------------------------------------------------------------#
61 | # dir_origin_path folder of images to detect
62 | # dir_save_path where to save the detected images
63 | #
64 | # dir_origin_path and dir_save_path only take effect when mode='dir_predict'
65 | #-------------------------------------------------------------------------#
66 | dir_origin_path = "img/"
67 | dir_save_path = "img_out/"
68 | #-------------------------------------------------------------------------#
69 | # simplify whether to simplify the onnx model with onnx-simplifier
70 | # onnx_save_path where to save the onnx model
71 | #-------------------------------------------------------------------------#
72 | simplify = True
73 | onnx_save_path = "model_data/models.onnx"
74 |
75 | if mode == "predict":
76 | '''
77 | A few notes about predict.py:
78 | 1. This code cannot batch-predict directly. For batch prediction, walk a folder with os.listdir() and open each image with Image.open.
79 | See get_miou_prediction.py for the full flow; it already implements the traversal.
80 | 2. To save the result, use r_image.save("img.jpg").
81 | 3. To keep the original image and the segmentation unblended, set the blend parameter to False.
82 | 4. To extract the region for a given mask class, see the drawing section of detect_image: it determines each pixel's class and can pull out the matching region, e.g.
83 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
84 | for c in range(self.num_classes):
85 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8')
86 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8')
87 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8')
88 | '''
89 | img = 'static/result/'+file_name
90 | try:
91 | image = Image.open(img)
92 | except:
93 | print('Open Error! Try again!')
94 | else:
95 | r_image = unet.detect_image(image, count=count, name_classes=name_classes)
96 | r_image.save('static/label/'+file_name)
97 |
98 |
99 | ## The following handles the computer's built-in camera
100 | elif mode == "video_camera":
101 | capture=cv2.VideoCapture(video_path)
102 | if video_save_path!="":
103 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
104 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
105 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
106 |
107 | ref, frame = capture.read()
108 | if not ref:
109 | raise ValueError("Failed to read from the camera (video). Check that the camera is installed correctly (or that the video path is filled in correctly).")
110 |
111 | fps = 0.0
112 | while(True):
113 | t1 = time.time()
114 | # Read one frame
115 | ref, frame = capture.read()
116 | if not ref:
117 | break
118 | # Convert the format, BGR to RGB
119 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
120 | # Convert to a PIL Image
121 | frame = Image.fromarray(np.uint8(frame))
122 | # Run detection
123 | frame = np.array(unet.detect_image(frame))
124 | # RGB back to BGR for OpenCV display
125 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
126 |
127 | fps = ( fps + (1./(time.time()-t1)) ) / 2
128 | print("fps= %.2f"%(fps))
129 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
130 |
131 | cv2.imshow("video",frame)
132 | c= cv2.waitKey(1) & 0xff
133 | if video_save_path!="":
134 | out.write(frame)
135 |
136 | if c==27:
137 | capture.release()
138 | break
139 | print("Video Detection Done!")
140 | capture.release()
141 | if video_save_path!="":
142 | print("Save processed video to the path :" + video_save_path)
143 | out.release()
144 | cv2.destroyAllWindows()
145 |
146 | elif mode == "fps":
147 | img = Image.open('img/street.jpg')
148 | tact_time = unet.get_FPS(img, test_interval)
149 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
150 |
151 | elif mode == "dir_predict":
152 | import os
153 | from tqdm import tqdm
154 |
155 | img_names = os.listdir(dir_origin_path)
156 | for img_name in tqdm(img_names):
157 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
158 | image_path = os.path.join(dir_origin_path, img_name)
159 | image = Image.open(image_path)
160 | r_image = unet.detect_image(image)
161 | if not os.path.exists(dir_save_path):
162 | os.makedirs(dir_save_path)
163 | r_image.save(os.path.join(dir_save_path, img_name))
164 | elif mode == "export_onnx":
165 | unet.convert_to_onnx(simplify, onnx_save_path)
166 |
167 |
168 | ## The following handles video files on disk
169 | elif mode == "video":
170 |
171 | while True:
172 | video = input("Input video filename:")
173 | try:
174 | capture = cv2.VideoCapture(video)
175 | except:
176 | print("Open Error! Try again!")
177 | continue
178 | else:
179 | # capture = cv2.VideoCapture(video_path)
180 | if video_save_path != "":
181 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
182 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
183 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
184 |
185 | ref, frame = capture.read()
186 | if not ref:
187 | raise ValueError("Failed to read from the camera (video). Check that the camera is installed correctly (or that the video path is filled in correctly).")
188 |
189 | fps = 0.0
190 | while (True):
191 | t1 = time.time()
192 | # Read one frame
193 | ref, frame = capture.read()
194 | if not ref:
195 | break
196 | # Convert the format, BGR to RGB
197 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
198 | # Convert to a PIL Image
199 | frame = Image.fromarray(np.uint8(frame))
200 | # Run detection
201 | frame = np.array(unet.detect_image(frame))
202 | # RGB back to BGR for OpenCV display
203 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
204 |
205 | fps = (fps + (1. / (time.time() - t1))) / 2
206 | print("fps= %.2f" % (fps))
207 | frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
208 |
209 | cv2.imshow("video", frame)
210 | c = cv2.waitKey(1) & 0xff
211 | if video_save_path != "":
212 | out.write(frame)
213 |
214 | if c == 27:
215 | capture.release()
216 | break
217 | print("Video Detection Done!")
218 | capture.release()
219 | if video_save_path != "":
220 | print("Save processed video to the path :" + video_save_path)
221 | out.release()
222 | cv2.destroyAllWindows()
223 |
224 | else:
225 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'video_camera', 'fps', 'dir_predict' or 'export_onnx'.")
226 |
227 |
228 |
229 |
230 |
--------------------------------------------------------------------------------
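A minimal sketch of calling the predict helper above from other code; the file name is hypothetical, and per the function body the input is read from static/result/ and the labeled output is written to static/label/:

```python
# Hypothetical call, e.g. from a small web backend. Assumes the weights
# referenced by Unet._defaults exist and that static/result/<file> is present.
from predictLabel import predict

predict("007-3939-200.jpg")  # reads static/result/007-3939-200.jpg,
                             # writes static/label/007-3939-200.jpg
```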
/predict.py:
--------------------------------------------------------------------------------
1 | #----------------------------------------------------#
2 | # Integrates single-image prediction, camera detection and
3 | # FPS testing into one py file; switch modes via `mode`.
4 | #----------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from unet import Unet
12 |
13 | if __name__ == "__main__":
14 | #-------------------------------------------------------------------------#
15 | # To change the color assigned to each class, edit self.colors in __init__
16 | #-------------------------------------------------------------------------#
17 | unet = Unet()
18 | #----------------------------------------------------------------------------------------------------------#
19 | # mode selects the test mode:
20 | # 'predict' single-image prediction; to customize the flow (saving images, cropping objects, ...), read the detailed comments below first
21 | # 'video' video detection, from a camera or a video file; see the comments below
22 | # 'fps' fps test, using img/street.jpg; see the comments below
23 | # 'dir_predict' detect every image in a folder and save the results; defaults to reading img and saving to img_out; see the comments below
24 | # 'export_onnx' export the model to onnx; requires pytorch 1.7.1 or higher
25 | #----------------------------------------------------------------------------------------------------------#
26 | mode = "predict" ### image detection
27 | # mode = "video" ## video detection
29 | #-------------------------------------------------------------------------#
30 | # count whether to count target pixels (i.e. area) and compute their ratio
31 | # name_classes the classes to distinguish, same as in json_to_dataset; used to print class names and counts
32 | #
33 | # count and name_classes only take effect when mode='predict'
34 | #-------------------------------------------------------------------------#
35 | count = False
36 | # name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
37 | #name_classes = ["_background_", "person", "car", "motorbike", "dustbin", "chair", "fire_hydrant", "tricycle", "bicycle","stone"]
38 | name_classes = ["background","cat"]
39 | #----------------------------------------------------------------------------------------------------------#
40 | # video_path the video path; video_path=0 means use the camera
41 | # to detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the root directory
42 | # video_save_path where to save the video; video_save_path="" means do not save
43 | # to save the video, set e.g. video_save_path = "yyy.mp4" to save yyy.mp4 in the root directory
44 | # video_fps fps of the saved video
45 | #
46 | # video_path, video_save_path and video_fps only take effect when mode='video'
47 | # when saving a video, exit with ctrl+c or run to the last frame to complete the save
48 | #----------------------------------------------------------------------------------------------------------#
49 | video_path = 0
50 | video_save_path = ""
51 | video_fps = 25.0
52 | #----------------------------------------------------------------------------------------------------------#
53 | # test_interval how many detections to run when measuring fps; in theory, a larger test_interval gives a more accurate fps
54 | # fps_image_path the image used for the fps test
55 | #
56 | # test_interval and fps_image_path only take effect when mode='fps'
57 | #----------------------------------------------------------------------------------------------------------#
58 | test_interval = 100
59 | fps_image_path = "E:/setup/UNet_Demo/UNet_Demo/NanKai/train/image/007-3939-200.jpg"
60 | #-------------------------------------------------------------------------#
61 | # dir_origin_path folder of images to detect
62 | # dir_save_path where to save the detected images
63 | #
64 | # dir_origin_path and dir_save_path only take effect when mode='dir_predict'
65 | #-------------------------------------------------------------------------#
66 | dir_origin_path = "img/"
67 | dir_save_path = "img_out/"
68 | #-------------------------------------------------------------------------#
69 | # simplify whether to simplify the onnx model with onnx-simplifier
70 | # onnx_save_path where to save the onnx model
71 | #-------------------------------------------------------------------------#
72 | simplify = True
73 | onnx_save_path = "model_data/models.onnx"
74 |
75 | if mode == "predict":
76 | '''
77 | A few notes about predict.py:
78 | 1. This code cannot batch-predict directly. For batch prediction, walk a folder with os.listdir() and open each image with Image.open.
79 | See get_miou_prediction.py for the full flow; it already implements the traversal.
80 | 2. To save the result, use r_image.save("img.jpg").
81 | 3. To keep the original image and the segmentation unblended, set the blend parameter to False.
82 | 4. To extract the region for a given mask class, see the drawing section of detect_image: it determines each pixel's class and can pull out the matching region, e.g.
83 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
84 | for c in range(self.num_classes):
85 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8')
86 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8')
87 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8')
88 | '''
89 | while True:
90 | img = input('Input image filename:')
91 | try:
92 | image = Image.open(img)
93 | except:
94 | print('Open Error! Try again!')
95 | continue
96 | else:
97 | r_image = unet.detect_image(image, count=count, name_classes=name_classes)
98 | r_image.save('backend/static/label/' + img.split('/')[-1])  # save under the input file's name, mirroring predictLabel.py
99 |
100 |
101 | ## The following handles the computer's built-in camera
102 | elif mode == "video_camera":
103 | capture=cv2.VideoCapture(video_path)
104 | if video_save_path!="":
105 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
106 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
107 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
108 |
109 | ref, frame = capture.read()
110 | if not ref:
111 | raise ValueError("Failed to read from the camera (video). Check that the camera is installed correctly (or that the video path is filled in correctly).")
112 |
113 | fps = 0.0
114 | while(True):
115 | t1 = time.time()
116 | # Read one frame
117 | ref, frame = capture.read()
118 | if not ref:
119 | break
120 | # Convert the format, BGR to RGB
121 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
122 | # Convert to a PIL Image
123 | frame = Image.fromarray(np.uint8(frame))
124 | # Run detection
125 | frame = np.array(unet.detect_image(frame))
126 | # RGB back to BGR for OpenCV display
127 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
128 |
129 | fps = ( fps + (1./(time.time()-t1)) ) / 2
130 | print("fps= %.2f"%(fps))
131 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
132 |
133 | cv2.imshow("video",frame)
134 | c= cv2.waitKey(1) & 0xff
135 | if video_save_path!="":
136 | out.write(frame)
137 |
138 | if c==27:
139 | capture.release()
140 | break
141 | print("Video Detection Done!")
142 | capture.release()
143 | if video_save_path!="":
144 | print("Save processed video to the path :" + video_save_path)
145 | out.release()
146 | cv2.destroyAllWindows()
147 |
148 | elif mode == "fps":
149 | img = Image.open('img/street.jpg')
150 | tact_time = unet.get_FPS(img, test_interval)
151 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
152 |
153 | elif mode == "dir_predict":
154 | import os
155 | from tqdm import tqdm
156 |
157 | img_names = os.listdir(dir_origin_path)
158 | for img_name in tqdm(img_names):
159 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
160 | image_path = os.path.join(dir_origin_path, img_name)
161 | image = Image.open(image_path)
162 | r_image = unet.detect_image(image)
163 | if not os.path.exists(dir_save_path):
164 | os.makedirs(dir_save_path)
165 | r_image.save(os.path.join(dir_save_path, img_name))
166 | elif mode == "export_onnx":
167 | unet.convert_to_onnx(simplify, onnx_save_path)
168 |
169 |
170 | ## The following handles video files on disk
171 | elif mode == "video":
172 |
173 | while True:
174 | video = input("Input video filename:")
175 | try:
176 | capture = cv2.VideoCapture(video)
177 | except:
178 | print("Open Error! Try again!")
179 | continue
180 | else:
181 | # capture = cv2.VideoCapture(video_path)
182 | if video_save_path != "":
183 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
184 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
185 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
186 |
187 | ref, frame = capture.read()
188 | if not ref:
189 | raise ValueError("Failed to read from the camera (video). Check that the camera is installed correctly (or that the video path is filled in correctly).")
190 |
191 | fps = 0.0
192 | while (True):
193 | t1 = time.time()
194 | # Read one frame
195 | ref, frame = capture.read()
196 | if not ref:
197 | break
198 | # Convert the format, BGR to RGB
199 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
200 | # Convert to a PIL Image
201 | frame = Image.fromarray(np.uint8(frame))
202 | # Run detection
203 | frame = np.array(unet.detect_image(frame))
204 | # RGB back to BGR for OpenCV display
205 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
206 |
207 | fps = (fps + (1. / (time.time() - t1))) / 2
208 | print("fps= %.2f" % (fps))
209 | frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
210 |
211 | cv2.imshow("video", frame)
212 | c = cv2.waitKey(1) & 0xff
213 | if video_save_path != "":
214 | out.write(frame)
215 |
216 | if c == 27:
217 | capture.release()
218 | break
219 | print("Video Detection Done!")
220 | capture.release()
221 | if video_save_path != "":
222 | print("Save processed video to the path :" + video_save_path)
223 | out.release()
224 | cv2.destroyAllWindows()
225 |
226 | else:
227 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'video_camera', 'fps', 'dir_predict' or 'export_onnx'.")
228 |
229 |
230 |
231 |
232 |
--------------------------------------------------------------------------------
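As note 1 of the docstring above suggests, batch prediction can be assembled from os.listdir and Image.open; a minimal sketch with hypothetical folder names:

```python
# Minimal batch-prediction sketch following note 1 of the docstring.
import os
from PIL import Image
from unet import Unet

unet = Unet()
in_dir, out_dir = "img/", "img_out/"   # hypothetical folders
os.makedirs(out_dir, exist_ok=True)
for name in os.listdir(in_dir):
    if name.lower().endswith((".jpg", ".jpeg", ".png")):
        r_image = unet.detect_image(Image.open(os.path.join(in_dir, name)))
        r_image.save(os.path.join(out_dir, name))
```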
/utils/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from nets.unet_training import CE_Loss, Dice_loss, Focal_Loss
5 | from tqdm import tqdm
6 |
7 | from utils.utils import get_lr
8 | from utils.utils_metrics import f_score
9 |
10 |
11 | def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
12 | total_loss = 0
13 | total_f_score = 0
14 |
15 | val_loss = 0
16 | val_f_score = 0
17 |
18 | if local_rank == 0:
19 | print('Start Train')
20 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
21 | model_train.train()
22 | for iteration, batch in enumerate(gen):
23 | if iteration >= epoch_step:
24 | break
25 | imgs, pngs, labels = batch
26 | with torch.no_grad():
27 | weights = torch.from_numpy(cls_weights)
28 | if cuda:
29 | imgs = imgs.cuda(local_rank)
30 | pngs = pngs.cuda(local_rank)
31 | labels = labels.cuda(local_rank)
32 | weights = weights.cuda(local_rank)
33 |
34 | optimizer.zero_grad()
35 | if not fp16:
36 | #----------------------#
37 | # Forward pass
38 | #----------------------#
39 | outputs = model_train(imgs)
40 | #----------------------#
41 | # Loss computation
42 | #----------------------#
43 | if focal_loss:
44 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
45 | else:
46 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
47 |
48 | if dice_loss:
49 | main_dice = Dice_loss(outputs, labels)
50 | loss = loss + main_dice
51 |
52 | with torch.no_grad():
53 | #-------------------------------#
54 | # Compute the f_score
55 | #-------------------------------#
56 | _f_score = f_score(outputs, labels)
57 |
58 | loss.backward()
59 | optimizer.step()
60 | else:
61 | from torch.cuda.amp import autocast
62 | with autocast():
63 | #----------------------#
64 | # Forward pass
65 | #----------------------#
66 | outputs = model_train(imgs)
67 | #----------------------#
68 | # Loss computation
69 | #----------------------#
70 | if focal_loss:
71 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
72 | else:
73 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
74 |
75 | if dice_loss:
76 | main_dice = Dice_loss(outputs, labels)
77 | loss = loss + main_dice
78 |
79 | with torch.no_grad():
80 | #-------------------------------#
81 | # Compute the f_score
82 | #-------------------------------#
83 | _f_score = f_score(outputs, labels)
84 |
85 | #----------------------#
86 | # Backward pass
87 | #----------------------#
88 | scaler.scale(loss).backward()
89 | scaler.step(optimizer)
90 | scaler.update()
91 |
92 | total_loss += loss.item()
93 | total_f_score += _f_score.item()
94 |
95 | if local_rank == 0:
96 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
97 | 'f_score' : total_f_score / (iteration + 1),
98 | 'lr' : get_lr(optimizer)})
99 | pbar.update(1)
100 |
101 | if local_rank == 0:
102 | pbar.close()
103 | print('Finish Train')
104 | print('Start Validation')
105 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
106 |
107 | model_train.eval()
108 | for iteration, batch in enumerate(gen_val):
109 | if iteration >= epoch_step_val:
110 | break
111 | imgs, pngs, labels = batch
112 | with torch.no_grad():
113 | weights = torch.from_numpy(cls_weights)
114 | if cuda:
115 | imgs = imgs.cuda(local_rank)
116 | pngs = pngs.cuda(local_rank)
117 | labels = labels.cuda(local_rank)
118 | weights = weights.cuda(local_rank)
119 |
120 | #----------------------#
121 | # Forward pass
122 | #----------------------#
123 | outputs = model_train(imgs)
124 | #----------------------#
125 | # Loss computation
126 | #----------------------#
127 | if focal_loss:
128 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
129 | else:
130 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
131 |
132 | if dice_loss:
133 | main_dice = Dice_loss(outputs, labels)
134 | loss = loss + main_dice
135 | #-------------------------------#
136 | # Compute the f_score
137 | #-------------------------------#
138 | _f_score = f_score(outputs, labels)
139 |
140 | val_loss += loss.item()
141 | val_f_score += _f_score.item()
142 |
143 | if local_rank == 0:
144 | pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1),
145 | 'f_score' : val_f_score / (iteration + 1),
146 | 'lr' : get_lr(optimizer)})
147 | pbar.update(1)
148 |
149 | if local_rank == 0:
150 | pbar.close()
151 | print('Finish Validation')
152 | loss_history.append_loss(epoch + 1, total_loss/ epoch_step, val_loss/ epoch_step_val)
153 | eval_callback.on_epoch_end(epoch + 1, model_train)
154 | print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
155 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
156 |
157 | #-----------------------------------------------#
158 | # Save the weights
159 | #-----------------------------------------------#
160 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
161 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch + 1), total_loss / epoch_step, val_loss / epoch_step_val)))
162 |
163 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
164 | print('Save best model to best_epoch_weights.pth')
165 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
166 |
167 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
168 |
169 | def fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
170 | total_loss = 0
171 | total_f_score = 0
172 |
173 | if local_rank == 0:
174 | print('Start Train')
175 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
176 | model_train.train()
177 | for iteration, batch in enumerate(gen):
178 | if iteration >= epoch_step:
179 | break
180 | imgs, pngs, labels = batch
181 | with torch.no_grad():
182 | weights = torch.from_numpy(cls_weights)
183 | if cuda:
184 | imgs = imgs.cuda(local_rank)
185 | pngs = pngs.cuda(local_rank)
186 | labels = labels.cuda(local_rank)
187 | weights = weights.cuda(local_rank)
188 |
189 | optimizer.zero_grad()
190 | if not fp16:
191 | #----------------------#
192 | # Forward pass
193 | #----------------------#
194 | outputs = model_train(imgs)
195 | #----------------------#
196 | # Loss computation
197 | #----------------------#
198 | if focal_loss:
199 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
200 | else:
201 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
202 |
203 | if dice_loss:
204 | main_dice = Dice_loss(outputs, labels)
205 | loss = loss + main_dice
206 |
207 | with torch.no_grad():
208 | #-------------------------------#
209 | # Compute the f_score
210 | #-------------------------------#
211 | _f_score = f_score(outputs, labels)
212 |
213 | loss.backward()
214 | optimizer.step()
215 | else:
216 | from torch.cuda.amp import autocast
217 | with autocast():
218 | #----------------------#
219 | # Forward pass
220 | #----------------------#
221 | outputs = model_train(imgs)
222 | #----------------------#
223 | # Loss computation
224 | #----------------------#
225 | if focal_loss:
226 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
227 | else:
228 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
229 |
230 | if dice_loss:
231 | main_dice = Dice_loss(outputs, labels)
232 | loss = loss + main_dice
233 |
234 | with torch.no_grad():
235 | #-------------------------------#
236 | # Compute the f_score
237 | #-------------------------------#
238 | _f_score = f_score(outputs, labels)
239 |
240 | #----------------------#
241 | # Backward pass
242 | #----------------------#
243 | scaler.scale(loss).backward()
244 | scaler.step(optimizer)
245 | scaler.update()
246 |
247 | total_loss += loss.item()
248 | total_f_score += _f_score.item()
249 |
250 | if local_rank == 0:
251 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
252 | 'f_score' : total_f_score / (iteration + 1),
253 | 'lr' : get_lr(optimizer)})
254 | pbar.update(1)
255 |
256 | if local_rank == 0:
257 | pbar.close()
258 | loss_history.append_loss(epoch + 1, total_loss/ epoch_step)
259 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
260 | print('Total Loss: %.3f' % (total_loss / epoch_step))
261 |
262 | #-----------------------------------------------#
263 | # Save the weights
264 | #-----------------------------------------------#
265 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
266 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f.pth'%((epoch + 1), total_loss / epoch_step)))
267 |
268 | if len(loss_history.losses) <= 1 or (total_loss / epoch_step) <= min(loss_history.losses):
269 | print('Save best model to best_epoch_weights.pth')
270 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
271 |
272 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
--------------------------------------------------------------------------------
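The fp16 branch above relies on torch.cuda.amp. A minimal self-contained sketch of the scale/backward/step/update pattern it uses, with a toy model so it also runs on CPU (where the scaler degrades to a no-op):

```python
# Sketch of the AMP pattern used in fit_one_epoch, on a toy linear model.
import torch
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_amp)   # pass-through when disabled

x, y = torch.randn(8, 4), torch.randn(8, 2)
optimizer.zero_grad()
with autocast(enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)          # unscales gradients, then calls optimizer.step()
scaler.update()                 # adapt the scale factor for the next iteration
```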
/unet.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import copy
3 | import time
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from PIL import Image
10 | from torch import nn
11 |
12 | from nets.unet import Unet as unet
13 | from utils.utils import cvtColor, preprocess_input, resize_image, show_config
14 |
15 |
16 | #--------------------------------------------#
17 | # To predict with your own trained model, two
18 | # parameters must be changed: model_path and num_classes!
19 | # If a shape mismatch occurs, make sure model_path
20 | # and num_classes match the values used in training.
21 | #--------------------------------------------#
22 | class Unet(object):
23 | _defaults = {
24 | #-------------------------------------------------------------------#
25 | # model_path points to a weights file in the logs folder.
26 | # After training there are several weights files in logs; pick one with a low validation loss.
27 | # A low validation loss does not guarantee a high miou, only that the weights generalize well on the validation set.
28 | #-------------------------------------------------------------------#
29 | "model_path" : '../model_data/best_epoch_weights.pth',
30 | #--------------------------------#
31 | # Number of classes to distinguish + 1
32 | #--------------------------------#
33 | "num_classes" : 5,
34 | #--------------------------------#
35 | # The backbone network used: vgg or resnet50
36 | #--------------------------------#
37 | "backbone" : "resnet50",
38 | #--------------------------------#
39 | # Input image size
40 | #--------------------------------#
41 | "input_shape" : [512, 512],
42 | #-------------------------------------------------#
43 | # mix_type controls how detection results are visualized
44 | #
45 | # mix_type = 0 blends the original image with the generated mask
46 | # mix_type = 1 keeps only the generated mask
47 | # mix_type = 2 removes the background, keeping only the targets from the original image
48 | #-------------------------------------------------#
49 | "mix_type" : 0,
50 | #--------------------------------#
51 | # Whether to use Cuda
52 | # Set to False if there is no GPU
53 | #--------------------------------#
54 | "cuda" : False,
55 | }
56 |
57 | #---------------------------------------------------#
58 | # Initialize the UNet
59 | #---------------------------------------------------#
60 | def __init__(self, **kwargs):
61 | self.__dict__.update(self._defaults)
62 | for name, value in kwargs.items():
63 | setattr(self, name, value)
64 | #---------------------------------------------------#
65 | # Set a different color for each class
66 | #---------------------------------------------------#
67 | if self.num_classes <= 21:
68 | self.colors = [ (255, 255, 255), (0,0,0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
69 | (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
70 | (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
71 | (128, 64, 12)]
72 | else:
73 | hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
74 | self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
75 | self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
76 | #---------------------------------------------------#
77 | # Build the model
78 | #---------------------------------------------------#
79 | self.generate()
80 |
81 | show_config(**self._defaults)
82 |
83 | #---------------------------------------------------#
84 | # Build the network and load the weights
85 | #---------------------------------------------------#
86 | def generate(self, onnx=False):
87 | self.net = unet(num_classes = self.num_classes, backbone=self.backbone)
88 |
89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90 | self.net.load_state_dict(torch.load(self.model_path, map_location=device))
91 | self.net = self.net.eval()
92 | print('{} model, and classes loaded.'.format(self.model_path))
93 | if not onnx:
94 | if self.cuda:
95 | self.net = nn.DataParallel(self.net)
96 | self.net = self.net.cuda()
97 |
98 | #---------------------------------------------------#
99 | # Detect an image
100 | #---------------------------------------------------#
101 | def detect_image(self, image, count=False, name_classes=None):
102 | #---------------------------------------------------------#
103 | # Convert the image to RGB here to avoid errors when predicting on grayscale images.
104 | # The code only supports RGB prediction; all other image types are converted to RGB.
105 | #---------------------------------------------------------#
106 | image = cvtColor(image)
107 | #---------------------------------------------------#
108 | # Make a backup of the input image for drawing later
109 | #---------------------------------------------------#
110 | old_img = copy.deepcopy(image)
111 | orininal_h = np.array(image).shape[0]
112 | orininal_w = np.array(image).shape[1]
113 | #---------------------------------------------------------#
114 | # Pad the image with gray bars for a distortion-free resize.
115 | # A plain resize would also work for recognition.
116 | #---------------------------------------------------------#
117 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
118 | #---------------------------------------------------------#
119 | # Add the batch_size dimension
120 | #---------------------------------------------------------#
121 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
122 |
123 | with torch.no_grad():
124 | images = torch.from_numpy(image_data)
125 | if self.cuda:
126 | images = images.cuda()
127 |
128 | #---------------------------------------------------#
129 | # Feed the image into the network for prediction
130 | #---------------------------------------------------#
131 | pr = self.net(images)[0]
132 | #---------------------------------------------------#
133 | # Take the class of every pixel
134 | #---------------------------------------------------#
135 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
136 | #--------------------------------------#
137 | # Crop off the gray-bar padding
138 | #--------------------------------------#
139 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
140 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
141 | #---------------------------------------------------#
142 | # Resize back to the original image size
143 | #---------------------------------------------------#
144 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
145 | #---------------------------------------------------#
146 | # Take the class of every pixel
147 | #---------------------------------------------------#
148 | pr = pr.argmax(axis=-1)
149 |
150 | #---------------------------------------------------------#
151 | # Per-class pixel counting
152 | #---------------------------------------------------------#
153 | if count:
154 | classes_nums = np.zeros([self.num_classes])
155 | total_points_num = orininal_h * orininal_w
156 | print('-' * 63)
157 | print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio"))
158 | print('-' * 63)
159 | for i in range(self.num_classes):
160 | num = np.sum(pr == i)
161 | ratio = num / total_points_num * 100
162 | if num > 0:
163 | print("|%25s | %15s | %14.2f%%|"%(str(name_classes[i]), str(num), ratio))
164 | print('-' * 63)
165 | classes_nums[i] = num
166 | print("classes_nums:", classes_nums)
167 |
168 | if self.mix_type == 0:
169 | # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
170 | # for c in range(self.num_classes):
171 | # seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
172 | # seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
173 | # seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
174 | seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
175 | #------------------------------------------------#
176 | # Convert the new image to Image form
177 | #------------------------------------------------#
178 | image = Image.fromarray(np.uint8(seg_img))
179 | #------------------------------------------------#
180 | # Blend the new image with the original image
181 | #------------------------------------------------#
182 | image = Image.blend(old_img, image, 0.7)
183 |
184 | elif self.mix_type == 1:
185 | # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
186 | # for c in range(self.num_classes):
187 | # seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
188 | # seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
189 | # seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
190 | seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
191 | #------------------------------------------------#
192 | # Convert the new image to Image form
193 | #------------------------------------------------#
194 | image = Image.fromarray(np.uint8(seg_img))
195 |
196 | elif self.mix_type == 2:
197 | seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
198 | #------------------------------------------------#
199 | # Convert the new image to Image form
200 | #------------------------------------------------#
201 | image = Image.fromarray(np.uint8(seg_img))
202 |
203 | return image
204 |
205 | def get_FPS(self, image, test_interval):
206 | #---------------------------------------------------------#
207 | # Convert the image to RGB here to avoid errors when predicting on grayscale images.
208 | # The code only supports RGB prediction; all other image types are converted to RGB.
209 | #---------------------------------------------------------#
210 | image = cvtColor(image)
211 | #---------------------------------------------------------#
212 | # Pad the image with gray bars for a distortion-free resize.
213 | # A plain resize would also work for recognition.
214 | #---------------------------------------------------------#
215 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
216 | #---------------------------------------------------------#
217 | # Add the batch_size dimension
218 | #---------------------------------------------------------#
219 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
220 |
221 | with torch.no_grad():
222 | images = torch.from_numpy(image_data)
223 | if self.cuda:
224 | images = images.cuda()
225 |
226 | #---------------------------------------------------#
227 | # Feed the image into the network for prediction
228 | #---------------------------------------------------#
229 | pr = self.net(images)[0]
230 | #---------------------------------------------------#
231 | # Take the class of every pixel
232 | #---------------------------------------------------#
233 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1)
234 | #--------------------------------------#
235 | # Crop off the gray-bar padding
236 | #--------------------------------------#
237 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
238 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
239 |
240 | t1 = time.time()
241 | for _ in range(test_interval):
242 | with torch.no_grad():
243 | #---------------------------------------------------#
244 | # Feed the image into the network for prediction
245 | #---------------------------------------------------#
246 | pr = self.net(images)[0]
247 | #---------------------------------------------------#
248 | # Take the class of every pixel
249 | #---------------------------------------------------#
250 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1)
251 | #--------------------------------------#
252 | # Crop off the gray-bar padding
253 | #--------------------------------------#
254 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
255 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
256 | t2 = time.time()
257 | tact_time = (t2 - t1) / test_interval
258 | return tact_time
259 |
260 | def convert_to_onnx(self, simplify, model_path):
261 | import onnx
262 | self.generate(onnx=True)
263 |
264 | im = torch.zeros(1, 3, *self.input_shape).to('cpu') # image size(1, 3, 512, 512) BCHW
265 | input_layer_names = ["images"]
266 | output_layer_names = ["output"]
267 |
268 | # Export the model
269 | print(f'Starting export with onnx {onnx.__version__}.')
270 | torch.onnx.export(self.net,
271 | im,
272 | f = model_path,
273 | verbose = False,
274 | opset_version = 12,
275 | training = torch.onnx.TrainingMode.EVAL,
276 | do_constant_folding = True,
277 | input_names = input_layer_names,
278 | output_names = output_layer_names,
279 | dynamic_axes = None)
280 |
281 | # Checks
282 | model_onnx = onnx.load(model_path) # load onnx model
283 | onnx.checker.check_model(model_onnx) # check onnx model
284 |
285 | # Simplify onnx
286 | if simplify:
287 | import onnxsim
288 | print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
289 | model_onnx, check = onnxsim.simplify(
290 | model_onnx,
291 | dynamic_input_shape=False,
292 | input_shapes=None)
293 | assert check, 'assert check failed'
294 | onnx.save(model_onnx, model_path)
295 |
296 | print('ONNX model saved as {}'.format(model_path))
297 |
298 | def get_miou_png(self, image):
299 | #---------------------------------------------------------#
300 | #   Convert the input image to RGB here to avoid errors when predicting on grayscale images.
301 | #   The code only supports prediction on RGB images; all other image types are converted to RGB.
302 | #---------------------------------------------------------#
303 | image = cvtColor(image)
304 | orininal_h = np.array(image).shape[0]
305 | orininal_w = np.array(image).shape[1]
306 | #---------------------------------------------------------#
307 | #   Pad the image with gray bars for a distortion-free resize
308 | #   A plain resize could also be used for inference
309 | #---------------------------------------------------------#
310 | image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
311 | #---------------------------------------------------------#
312 | #   Add the batch_size dimension
313 | #---------------------------------------------------------#
314 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
315 |
316 | with torch.no_grad():
317 | images = torch.from_numpy(image_data)
318 | if self.cuda:
319 | images = images.cuda()
320 |
321 | #---------------------------------------------------#
322 | #   Feed the image into the network for prediction
323 | #---------------------------------------------------#
324 | pr = self.net(images)[0]
325 | #---------------------------------------------------#
326 | #   Get the per-pixel class probabilities with softmax (the argmax comes later)
327 | #---------------------------------------------------#
328 | pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
329 | #--------------------------------------#
330 | #   Crop out the gray-bar padding
331 | #--------------------------------------#
332 | pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
333 | int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
334 | #---------------------------------------------------#
335 | #   Resize the prediction back to the original image size
336 | #---------------------------------------------------#
337 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
338 | #---------------------------------------------------#
339 | #   Take the class of each pixel
340 | #---------------------------------------------------#
341 | pr = pr.argmax(axis=-1)
342 |
343 | image = Image.fromarray(np.uint8(pr))
344 | return image
345 |
--------------------------------------------------------------------------------
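Once `convert_to_onnx` above has written the model file, inference no longer needs PyTorch. The following is a minimal sketch, assuming the `onnxruntime` package (not listed in requirements.txt) and input preprocessing identical to the PyTorch path (resize with gray-bar padding, `preprocess_input`, transpose to BCHW); `model.onnx` is a placeholder path.

```python
import numpy as np
import onnxruntime as ort

# Load the exported model; "model.onnx" stands in for the model_path
# that was passed to convert_to_onnx.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# image_data must be preprocessed exactly as in get_miou_png:
# float32, shape (1, 3, 512, 512), matching the "images"/"output" names set at export.
image_data = np.zeros((1, 3, 512, 512), dtype=np.float32)  # stand-in for a real image
pr = session.run(["output"], {"images": image_data})[0][0]

# argmax over the class axis gives per-pixel class ids; softmax is unnecessary here
# because it does not change the argmax.
mask = pr.transpose(1, 2, 0).argmax(axis=-1)
print(mask.shape)  # (512, 512)
```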
/train_medical.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 |
4 | import numpy as np
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import torch.distributed as dist
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 |
11 | from nets.unet import Unet
12 | from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init
13 | from utils.callbacks import LossHistory
14 | from utils.dataloader_medical import UnetDataset, unet_dataset_collate
15 | from utils.utils import download_weights, show_config
16 | from utils.utils_fit import fit_one_epoch_no_val
17 |
18 | '''
19 | When training your own semantic segmentation model, be sure to pay attention to the following points:
20 | 1. This training script was built specifically for a medical dataset I found online. It is only an example, showing how to train when the dataset is not in VOC format.
21 |
22 |    It cannot compute mIoU or other performance metrics; it is only for inspecting training results on the medical dataset.
23 |    It cannot compute mIoU or other performance metrics.
24 |    It cannot compute mIoU or other performance metrics.
25 |
26 |    If you have your own medical dataset to train on, there are two cases:
27 |    a. A medical dataset without labels:
28 |       Follow the dataset annotation tutorial in the video: first annotate the images with labelme, convert them to VOC format, and then train with train.py.
29 |    b. A medical dataset with labels:
30 |       Convert the label format so that the value of each pixel is the class that pixel belongs to.
31 |       The labels therefore need to be changed so that background pixels have value 0 and target pixels have value 1.
32 |       Reference: https://github.com/bubbliiiing/segmentation-format-fix
33 |
34 | 2. The loss value is used to judge convergence; what matters is the trend, i.e. the validation loss keeps decreasing. If the validation loss basically stops changing, the model has essentially converged.
35 |    The absolute size of the loss has no particular meaning; large or small depends only on how the loss is computed, and being close to 0 is not inherently better. If you want the loss to look nicer, you can divide by 10000 inside the corresponding loss function.
36 |    Loss values produced during training are saved in the loss_%Y_%m_%d_%H_%M_%S folder under the logs folder.
37 |
38 | 3. Trained weight files are saved in the logs folder. Each training epoch (Epoch) contains several training steps (Step), and each step performs one gradient-descent update.
39 |    Nothing is saved after only a few steps; make sure the concepts of Epoch and Step are clear.
40 | '''
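# Illustrative sketch for note 1b above (added comment, not part of the original script):
# converting a binary 0/255 mask into the 0/1 class-index labels this repo expects,
# using only packages from requirements.txt; the file names are placeholders.
#
#     import numpy as np
#     from PIL import Image
#
#     mask  = np.array(Image.open("old_label.png"))
#     label = (mask > 0).astype(np.uint8)        # background stays 0, target becomes 1
#     Image.fromarray(label).save("new_label.png")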
41 | if __name__ == "__main__":
42 | #---------------------------------#
43 | #   Cuda            Whether to use CUDA
44 | #                   Set to False if you have no GPU
45 | #---------------------------------#
46 | Cuda = True
47 | #---------------------------------------------------------------------#
48 | #   distributed     Whether to use single-machine multi-GPU distributed training
49 | #                   The terminal commands are only supported on Ubuntu. CUDA_VISIBLE_DEVICES selects the GPUs under Ubuntu.
50 | #                   Windows defaults to DP mode using all GPUs; DDP is not supported.
51 | #   DP mode:
52 | #       Set distributed = False
53 | #       Run: CUDA_VISIBLE_DEVICES=0,1 python train_medical.py
54 | #   DDP mode:
55 | #       Set distributed = True
56 | #       Run: CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train_medical.py
57 | #---------------------------------------------------------------------#
58 | distributed = False
59 | #---------------------------------------------------------------------#
60 | #   sync_bn     Whether to use sync_bn; available in multi-GPU DDP mode
61 | #---------------------------------------------------------------------#
62 | sync_bn = False
63 | #---------------------------------------------------------------------#
64 | #   fp16        Whether to use mixed-precision training
65 | #               Roughly halves GPU memory usage; requires PyTorch 1.7.1 or above
66 | #---------------------------------------------------------------------#
67 | fp16 = False
68 | #-----------------------------------------------------#
69 | #   num_classes     Must be changed when training on your own dataset
70 | #                   Number of classes you need + 1, e.g. 2+1
71 | #-----------------------------------------------------#
72 | num_classes = 4+1
73 | #-----------------------------------------------------#
74 | #   Backbone network selection
75 | # vgg
76 | # resnet50
77 | #-----------------------------------------------------#
78 | backbone = "resnet50"
79 | #----------------------------------------------------------------------------------------------------------------------------#
80 | #   pretrained      Whether to use the backbone's pretrained weights. These are the backbone's weights, so they are loaded when the model is constructed.
81 | #                   If model_path is set, the backbone weights do not need to be loaded and the value of pretrained is meaningless.
82 | #                   If model_path is not set and pretrained = True, only the backbone is loaded before training starts.
83 | #                   If model_path is not set, pretrained = False and Freeze_Train = False, training starts from scratch with no backbone-freezing phase.
84 | #----------------------------------------------------------------------------------------------------------------------------#
85 | pretrained = True
86 | #----------------------------------------------------------------------------------------------------------------------------#
87 | #   See the README for how to download the weight file; it is available from the network drive. A model's pretrained weights are universal across datasets, because the features are universal.
88 | #   The important part of the pretrained weights is the backbone feature-extraction network, which is used for feature extraction.
89 | #   Pretrained weights are necessary in 99% of cases. Without them the backbone weights are too random, feature extraction is ineffective, and training results will be poor.
90 | #   Dimension-mismatch warnings when training on your own dataset are normal: what is being predicted differs, so the dimensions naturally do not match.
91 | #
92 | #   If training was interrupted, you can set model_path to a weight file in the logs folder to load partially trained weights back in.
93 | #   At the same time, adjust the freeze-phase or unfreeze-phase parameters below to keep the epochs continuous.
94 | #
95 | #   When model_path = '', no weights for the whole model are loaded.
96 | #
97 | #   Here the weights of the whole model are used, so they are loaded in this script; pretrained does not affect the loading here.
98 | #   To start from the backbone's pretrained weights, set model_path = '' and pretrained = True; only the backbone is then loaded.
99 | #   To train from scratch, set model_path = '', pretrained = False and Freeze_Train = False; training then starts from scratch with no backbone-freezing phase.
100 | #
101 | #   Generally speaking, training from scratch gives poor results because the weights are too random and feature extraction is ineffective; starting from 0 is very, very, very strongly discouraged!
102 | #   If you must start from scratch, look into the ImageNet dataset: first train a classification model to obtain backbone weights; that backbone is shared with this model, so you can train on top of it.
103 | #----------------------------------------------------------------------------------------------------------------------------#
104 | model_path = "model_data/unet_resnet_voc.pth"
105 | #-----------------------------------------------------#
106 | #   input_shape     Size of the input image; must be a multiple of 32
107 | #-----------------------------------------------------#
108 | input_shape = [512, 512]
109 |
110 | #----------------------------------------------------------------------------------------------------------------------------#
111 | #   Training is split into two phases: a freeze phase and an unfreeze phase. The freeze phase exists to accommodate users with limited machine performance.
112 | #   Freeze training uses less GPU memory; with a very weak GPU you can set Freeze_Epoch equal to UnFreeze_Epoch to do freeze training only.
113 | #
114 | #   Some suggested parameter settings are provided here; adjust them flexibly according to your own needs:
115 | #   (1) Training from the pretrained weights of the whole model:
116 | #       Adam:
117 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (freeze)
118 | #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (no freeze)
119 | #       SGD:
120 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (freeze)
121 | #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (no freeze)
122 | #       Note: UnFreeze_Epoch can be adjusted between 100 and 300.
123 | #   (2) Training from the pretrained weights of the backbone network:
124 | #       Adam:
125 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (freeze)
126 | #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (no freeze)
127 | #       SGD:
128 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 120, Freeze_Train = True, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (freeze)
129 | #           Init_Epoch = 0, UnFreeze_Epoch = 120, Freeze_Train = False, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (no freeze)
130 | #       Note: since training starts from the backbone's pretrained weights, which are not necessarily suited to semantic segmentation, more training is needed to escape local optima.
131 | #             UnFreeze_Epoch can be adjusted between 120 and 300.
132 | #             Adam converges faster than SGD, so UnFreeze_Epoch can in theory be smaller, but more epochs are still recommended.
133 | #   (3) Setting batch_size:
134 | #       As large as your GPU can handle. Running out of memory has nothing to do with dataset size; on an OOM or CUDA out of memory error, reduce batch_size.
135 | #       Because resnet50 contains BatchNormalization layers,
136 | #       batch_size cannot be 1 when the backbone is resnet50.
137 | #       Normally Freeze_batch_size should be 1-2x Unfreeze_batch_size. Too large a gap is not recommended, because it affects the automatic learning-rate adjustment.
138 | #----------------------------------------------------------------------------------------------------------------------------#
139 | #------------------------------------------------------------------#
140 | #   Freeze-phase training parameters
141 | #   The model's backbone is frozen at this point; the feature-extraction network does not change
142 | #   Uses less GPU memory; only fine-tunes the network
143 | #   Init_Epoch          The epoch training currently starts from; it may be larger than Freeze_Epoch, e.g.
144 | #                       Init_Epoch = 60, Freeze_Epoch = 50, UnFreeze_Epoch = 100
145 | #                       skips the freeze phase, starts directly from epoch 60, and adjusts the learning rate accordingly.
146 | #                       (used when resuming from a checkpoint)
147 | #   Freeze_Epoch        Number of epochs of freeze training
148 | #                       (ignored when Freeze_Train=False)
149 | #   Freeze_batch_size   batch_size for freeze training
150 | #                       (ignored when Freeze_Train=False)
151 | #------------------------------------------------------------------#
152 | Init_Epoch = 0
153 | Freeze_Epoch = 10
154 | Freeze_batch_size = 1
155 | #------------------------------------------------------------------#
156 | #   Unfreeze-phase training parameters
157 | #   The model's backbone is no longer frozen; the feature-extraction network changes
158 | #   Uses more GPU memory; all of the network's parameters change
159 | #   UnFreeze_Epoch          Total number of training epochs
160 | #   Unfreeze_batch_size     batch_size after unfreezing
161 | #------------------------------------------------------------------#
162 | UnFreeze_Epoch = 20
163 | Unfreeze_batch_size = 1
164 | #------------------------------------------------------------------#
165 | #   Freeze_Train    Whether to do freeze training
166 | #                   Default: freeze the backbone first, then unfreeze for the rest of training.
167 | #------------------------------------------------------------------#
168 | Freeze_Train = True
169 |
170 | #------------------------------------------------------------------#
171 | #   Other training parameters: learning rate, optimizer and learning-rate decay
172 | #------------------------------------------------------------------#
173 | #------------------------------------------------------------------#
174 | #   Init_lr     The model's maximum learning rate
175 | #               Init_lr=1e-4 is recommended with the Adam optimizer
176 | #               Init_lr=1e-2 is recommended with the SGD optimizer
177 | #   Min_lr      The model's minimum learning rate; defaults to 0.01x the maximum learning rate
178 | #------------------------------------------------------------------#
179 | Init_lr = 1e-4
180 | Min_lr = Init_lr * 0.01
181 | #------------------------------------------------------------------#
182 | #   optimizer_type  The optimizer to use; options are adam, sgd
183 | #                   Init_lr=1e-4 is recommended with the Adam optimizer
184 | #                   Init_lr=1e-2 is recommended with the SGD optimizer
185 | #   momentum        The momentum parameter used inside the optimizer
186 | #   weight_decay    Weight decay; can prevent overfitting
187 | #                   weight_decay causes problems with adam; set it to 0 when using adam.
188 | #------------------------------------------------------------------#
189 | optimizer_type = "adam"
190 | momentum = 0.9
191 | weight_decay = 0
192 | #------------------------------------------------------------------#
193 | #   lr_decay_type   Learning-rate decay schedule; options are 'step' and 'cos'
194 | #------------------------------------------------------------------#
195 | lr_decay_type = 'cos'
196 | #------------------------------------------------------------------#
197 | #   save_period     Save the weights every save_period epochs
198 | #------------------------------------------------------------------#
199 | save_period = 5
200 | #------------------------------------------------------------------#
201 | #   save_dir        Folder where weights and log files are saved
202 | #------------------------------------------------------------------#
203 | save_dir = 'logs'
204 |
205 | #------------------------------#
206 | #   Dataset path
207 | #------------------------------#
208 | VOCdevkit_path = 'NanKai'
209 | #------------------------------------------------------------------#
210 | #   Suggested settings:
211 | #       Few classes (several): set to True
212 | #       Many classes (more than ten) and a large batch_size (above 10): set to True
213 | #       Many classes (more than ten) and a small batch_size (below 10): set to False
214 | #------------------------------------------------------------------#
215 | dice_loss = True
216 | #------------------------------------------------------------------#
217 | #   Whether to use focal loss to counter positive/negative sample imbalance (this script passes a per-class list rather than a bool)
218 | #------------------------------------------------------------------#
219 | focal_loss = [1,3,3,3,3]
220 | #------------------------------------------------------------------#
221 | #   Whether to give different classes different loss weights; balanced by default.
222 | #   If set, use a numpy array whose length equals num_classes.
223 | #   e.g.:
224 | # num_classes = 3
225 | # cls_weights = np.array([1, 2, 3], np.float32)
226 | #------------------------------------------------------------------#
227 | cls_weights = np.ones([num_classes], np.float32)
228 | #------------------------------------------------------------------#
229 | #   num_workers     Sets whether to use multi-threaded data loading; 1 disables it
230 | #                   Enabling it speeds up data loading but uses more memory
231 | #                   In keras, enabling multi-threading sometimes made things much slower
232 | #                   Enable multi-threading only when IO is the bottleneck, i.e. the GPU computes far faster than images are read.
233 | #------------------------------------------------------------------#
234 | num_workers = 4
235 |
236 | #------------------------------------------------------#
237 | #   Set the GPUs to use
238 | #------------------------------------------------------#
239 | ngpus_per_node = torch.cuda.device_count()
240 | if distributed:
241 | dist.init_process_group(backend="nccl")
242 | local_rank = int(os.environ["LOCAL_RANK"])
243 | rank = int(os.environ["RANK"])
244 | device = torch.device("cuda", local_rank)
245 | if local_rank == 0:
246 | print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
247 | print("Gpu Device Count : ", ngpus_per_node)
248 | else:
249 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
250 | local_rank = 0
251 |
252 | #----------------------------------------------------#
253 | #   Download pretrained weights
254 | #----------------------------------------------------#
255 | if pretrained:
256 | if distributed:
257 | if local_rank == 0:
258 | download_weights(backbone)
259 | dist.barrier()
260 | else:
261 | download_weights(backbone)
262 |
263 | model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train()
264 | if not pretrained:
265 | weights_init(model)
266 | if model_path != '':
267 | #------------------------------------------------------#
268 | #   See the README for the weight file; download it from the Baidu network drive
269 | #------------------------------------------------------#
270 | if local_rank == 0:
271 | print('Load weights {}.'.format(model_path))
272 |
273 | #------------------------------------------------------#
274 | #   Load weights by matching the pretrained weights' keys with the model's keys
275 | #------------------------------------------------------#
276 | model_dict = model.state_dict()
277 | pretrained_dict = torch.load(model_path, map_location = device)
278 | load_key, no_load_key, temp_dict = [], [], {}
279 | for k, v in pretrained_dict.items():
280 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
281 | temp_dict[k] = v
282 | load_key.append(k)
283 | else:
284 | no_load_key.append(k)
285 | model_dict.update(temp_dict)
286 | model.load_state_dict(model_dict)
287 | #------------------------------------------------------#
288 | #   Show the keys that failed to match
289 | #------------------------------------------------------#
290 | if local_rank == 0:
291 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
292 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
293 | print("\n\033[1;33;44mFriendly reminder: it is normal for the head not to load, but an error if the backbone does not load.\033[0m")
294 |
295 | #----------------------#
296 | #   Record the loss
297 | #----------------------#
298 | if local_rank == 0:
299 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
300 | log_dir = os.path.join(save_dir, "loss_" + str(time_str))
301 | loss_history = LossHistory(log_dir, model, input_shape=input_shape, val_loss_flag = False)
302 | else:
303 | loss_history = None
304 |
305 | #------------------------------------------------------------------#
306 | #   torch 1.2 does not support amp; torch 1.7.1 or above is recommended for correct fp16 use
307 | #   which is why torch 1.2 reports "could not be resolved" here
308 | #------------------------------------------------------------------#
309 | if fp16:
310 | from torch.cuda.amp import GradScaler as GradScaler
311 | scaler = GradScaler()
312 | else:
313 | scaler = None
314 |
315 | model_train = model.train()
316 | #----------------------------#
317 | #   Multi-GPU synchronized BatchNorm
318 | #----------------------------#
319 | if sync_bn and ngpus_per_node > 1 and distributed:
320 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
321 | elif sync_bn:
322 | print("Sync_bn is not support in one gpu or not distributed.")
323 |
324 | if Cuda:
325 | if distributed:
326 | #----------------------------#
327 | #   Multi-GPU parallel execution
328 | #----------------------------#
329 | model_train = model_train.cuda(local_rank)
330 | model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True)
331 | else:
332 | model_train = torch.nn.DataParallel(model)
333 | cudnn.benchmark = True
334 | model_train = model_train.cuda()
335 |
336 | #---------------------------#
337 | #   Read the txt files for the dataset
338 | #---------------------------#
339 | with open(os.path.join(VOCdevkit_path, "save/train.txt"),"r") as f:
340 | train_lines = f.readlines()
341 | num_train = len(train_lines)
342 |
343 | if local_rank == 0:
344 | show_config(
345 | num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \
346 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
347 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
348 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train
349 | )
350 | #------------------------------------------------------#
351 | #   Backbone features are generic, so freeze training speeds up training
352 | #   and keeps the weights from being destroyed early in training.
353 | #   Init_Epoch is the starting epoch
354 | #   Freeze_Epoch is the number of freeze-training epochs
355 | #   UnFreeze_Epoch is the total number of training epochs
356 | #   If you hit OOM or insufficient GPU memory, reduce Batch_size
357 | #------------------------------------------------------#
358 | if True:
359 | UnFreeze_flag = False
360 | #------------------------------------#
361 | #   Freeze part of the network for training
362 | #------------------------------------#
363 | if Freeze_Train:
364 | model.freeze_backbone()
365 |
366 | #-------------------------------------------------------------------#
367 | #   If freeze training is disabled, set batch_size directly to Unfreeze_batch_size
368 | #-------------------------------------------------------------------#
369 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
370 |
371 | #-------------------------------------------------------------------#
372 | #   Adapt the learning rate to the current batch_size
373 | #-------------------------------------------------------------------#
374 | nbs = 16
375 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
376 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
377 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
378 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
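# Worked example of the clamping above (added comment): with Freeze_batch_size = 1,
# nbs = 16 and optimizer_type = 'adam', the raw value 1 / 16 * 1e-4 = 6.25e-6 is
# clamped up to lr_limit_min = 1e-4, and Min_lr_fit is likewise clamped up to
# lr_limit_min * 1e-2 = 1e-6.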
379 |
380 | #---------------------------------------#
381 | #   Select the optimizer according to optimizer_type
382 | #---------------------------------------#
383 | optimizer = {
384 | 'adam' : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay),
385 | 'sgd' : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay)
386 | }[optimizer_type]
387 |
388 | #---------------------------------------#
389 | #   Get the learning-rate decay function
390 | #---------------------------------------#
391 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
392 |
393 | #---------------------------------------#
394 | #   Determine the length of each epoch
395 | #---------------------------------------#
396 | epoch_step = num_train // batch_size
397 |
398 | if epoch_step == 0:
399 | raise ValueError("The dataset is too small to train on; please expand it.")
400 |
401 | train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
402 |
403 | if distributed:
404 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,)
405 | batch_size = batch_size // ngpus_per_node
406 | shuffle = False
407 | else:
408 | train_sampler = None
409 | shuffle = True
410 |
411 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
412 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
413 |
414 | #---------------------------------------#
415 | #   Start model training
416 | #---------------------------------------#
417 | for epoch in range(Init_Epoch, UnFreeze_Epoch):
418 | #---------------------------------------#
419 | #   If the model has a frozen part,
420 | #   unfreeze it and update the parameters
421 | #---------------------------------------#
422 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
423 | batch_size = Unfreeze_batch_size
424 |
425 | #-------------------------------------------------------------------#
426 | #   Adapt the learning rate to the current batch_size
427 | #-------------------------------------------------------------------#
428 | nbs = 16
429 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
430 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
431 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
432 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
433 | #---------------------------------------#
434 | #   Get the learning-rate decay function
435 | #---------------------------------------#
436 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
437 |
438 | model.unfreeze_backbone()
439 |
440 | epoch_step = num_train // batch_size
441 |
442 | if epoch_step == 0:
443 | raise ValueError("The dataset is too small to train on; please expand it.")
444 |
445 | if distributed:
446 | batch_size = batch_size // ngpus_per_node
447 |
448 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
449 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
450 |
451 | UnFreeze_flag = True
452 |
453 | if distributed:
454 | train_sampler.set_epoch(epoch)
455 |
456 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
457 |
458 | fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)
459 |
460 | if distributed:
461 | dist.barrier()
462 |
463 | if local_rank == 0:
464 | loss_history.writer.close()
--------------------------------------------------------------------------------
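Both training scripts use lr_decay_type = 'cos' via get_lr_scheduler from nets/unet_training.py, which is not included in this excerpt. Below is a rough sketch of the kind of cosine annealing such a scheduler computes, mapping an epoch index to a learning rate between the clamped maximum and minimum; the repo's actual implementation may add warmup or other details.

```python
import math

def cosine_lr(lr_max, lr_min, total_epochs, epoch):
    # Plain cosine annealing from lr_max at epoch 0 down to lr_min at the last epoch.
    cos_out = 0.5 * (1 + math.cos(math.pi * epoch / max(total_epochs - 1, 1)))
    return lr_min + (lr_max - lr_min) * cos_out

# With the clamped adam values from train_medical.py (1e-4 down to 1e-6 over 20 epochs):
for epoch in (0, 10, 19):
    print(epoch, cosine_lr(1e-4, 1e-6, 20, epoch))
```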
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 |
4 | import numpy as np
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import torch.distributed as dist
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 |
11 | from nets.unet import Unet
12 | from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init
13 | from utils.callbacks import LossHistory, EvalCallback
14 | from utils.dataloader import UnetDataset, unet_dataset_collate
15 | from utils.utils import download_weights, show_config
16 | from utils.utils_fit import fit_one_epoch
17 |
18 | '''
19 | When training your own semantic segmentation model, be sure to pay attention to the following points:
20 | 1. Before training, carefully check that your data format meets the requirements. This repo requires a VOC-format dataset consisting of input images and labels.
21 |    Input images are .jpg files and need no fixed size; they are resized automatically before training.
22 |    Grayscale images are converted to RGB automatically for training; no manual changes are needed.
23 |    If the input images have a suffix other than jpg, batch-convert them to jpg before training.
24 |
25 |    Labels are .png images and need no fixed size; they are resized automatically before training.
26 |    Many datasets downloaded from the internet do not have the correct label format and need further processing. Be careful: the value of each pixel in a label is the class that pixel belongs to.
27 |    Datasets commonly found online use two classes, with background pixels at 0 and target pixels at 255. Such a dataset will run, but its predictions will be useless!
28 |    It must be changed so that background pixels have value 0 and target pixels have value 1.
29 |    If the format is wrong, see: https://github.com/bubbliiiing/segmentation-format-fix
30 |
31 | 2. The loss value is used to judge convergence; what matters is the trend, i.e. the validation loss keeps decreasing. If the validation loss basically stops changing, the model has essentially converged.
32 |    The absolute size of the loss has no particular meaning; large or small depends only on how the loss is computed, and being close to 0 is not inherently better. If you want the loss to look nicer, you can divide by 10000 inside the corresponding loss function.
33 |    Loss values produced during training are saved in the loss_%Y_%m_%d_%H_%M_%S folder under the logs folder.
34 |
35 | 3. Trained weight files are saved in the logs folder. Each training epoch (Epoch) contains several training steps (Step), and each step performs one gradient-descent update.
36 |    Nothing is saved after only a few steps; make sure the concepts of Epoch and Step are clear.
37 | '''
38 | if __name__ == "__main__":
39 | #---------------------------------#
40 | #   Cuda            Whether to use CUDA
41 | #                   Set to False if you have no GPU
42 | #---------------------------------#
43 | Cuda = True
44 | #---------------------------------------------------------------------#
45 | #   distributed     Whether to use single-machine multi-GPU distributed training
46 | #                   The terminal commands are only supported on Ubuntu. CUDA_VISIBLE_DEVICES selects the GPUs under Ubuntu.
47 | #                   Windows defaults to DP mode using all GPUs; DDP is not supported.
48 | #   DP mode:
49 | #       Set distributed = False
50 | #       Run: CUDA_VISIBLE_DEVICES=0,1 python train.py
51 | #   DDP mode:
52 | #       Set distributed = True
53 | #       Run: CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
54 | #---------------------------------------------------------------------#
55 | distributed = False
56 | #---------------------------------------------------------------------#
57 | #   sync_bn     Whether to use sync_bn; available in multi-GPU DDP mode
58 | #---------------------------------------------------------------------#
59 | sync_bn = False
60 | #---------------------------------------------------------------------#
61 | #   fp16        Whether to use mixed-precision training
62 | #               Roughly halves GPU memory usage; requires PyTorch 1.7.1 or above
63 | #---------------------------------------------------------------------#
64 | fp16 = False
65 | #-----------------------------------------------------#
66 | #   num_classes     Must be changed when training on your own dataset
67 | #                   Number of classes you need + 1, e.g. 2+1
68 | #-----------------------------------------------------#
69 | num_classes = 2
70 | #-----------------------------------------------------#
71 | #   Backbone network selection
72 | # vgg
73 | # resnet50
74 | #-----------------------------------------------------#
75 | backbone = "vgg"
76 | #----------------------------------------------------------------------------------------------------------------------------#
77 | #   pretrained      Whether to use the backbone's pretrained weights. These are the backbone's weights, so they are loaded when the model is constructed.
78 | #                   If model_path is set, the backbone weights do not need to be loaded and the value of pretrained is meaningless.
79 | #                   If model_path is not set and pretrained = True, only the backbone is loaded before training starts.
80 | #                   If model_path is not set, pretrained = False and Freeze_Train = False, training starts from scratch with no backbone-freezing phase.
81 | #----------------------------------------------------------------------------------------------------------------------------#
82 | pretrained = False
83 | #----------------------------------------------------------------------------------------------------------------------------#
84 | #   See the README for how to download the weight file; it is available from the network drive. A model's pretrained weights are universal across datasets, because the features are universal.
85 | #   The important part of the pretrained weights is the backbone feature-extraction network, which is used for feature extraction.
86 | #   Pretrained weights are necessary in 99% of cases. Without them the backbone weights are too random, feature extraction is ineffective, and training results will be poor.
87 | #   Dimension-mismatch warnings when training on your own dataset are normal: what is being predicted differs, so the dimensions naturally do not match.
88 | #
89 | #   If training was interrupted, you can set model_path to a weight file in the logs folder to load partially trained weights back in.
90 | #   At the same time, adjust the freeze-phase or unfreeze-phase parameters below to keep the epochs continuous.
91 | #
92 | #   When model_path = '', no weights for the whole model are loaded.
93 | #
94 | #   Here the weights of the whole model are used, so they are loaded in train.py; pretrained does not affect the loading here.
95 | #   To start from the backbone's pretrained weights, set model_path = '' and pretrained = True; only the backbone is then loaded.
96 | #   To train from scratch, set model_path = '', pretrained = False and Freeze_Train = False; training then starts from scratch with no backbone-freezing phase.
97 | #
98 | #   Generally speaking, training from scratch gives poor results because the weights are too random and feature extraction is ineffective; starting from 0 is very, very, very strongly discouraged!
99 | #   If you must start from scratch, look into the ImageNet dataset: first train a classification model to obtain backbone weights; that backbone is shared with this model, so you can train on top of it.
100 | #----------------------------------------------------------------------------------------------------------------------------#
101 | model_path = "model_data/unet_vgg_voc.pth"
102 | #-----------------------------------------------------#
103 | #   input_shape     Size of the input image; must be a multiple of 32
104 | #-----------------------------------------------------#
105 | input_shape = [512, 512]
106 |
107 | #----------------------------------------------------------------------------------------------------------------------------#
108 | #   Training is split into two phases: a freeze phase and an unfreeze phase. The freeze phase exists to accommodate users with limited machine performance.
109 | #   Freeze training uses less GPU memory; with a very weak GPU you can set Freeze_Epoch equal to UnFreeze_Epoch to do freeze training only.
110 | #
111 | #   Some suggested parameter settings are provided here; adjust them flexibly according to your own needs:
112 | #   (1) Training from the pretrained weights of the whole model:
113 | #       Adam:
114 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (freeze)
115 | #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (no freeze)
116 | #       SGD:
117 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (freeze)
118 | #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (no freeze)
119 | #       Note: UnFreeze_Epoch can be adjusted between 100 and 300.
120 | #   (2) Training from the pretrained weights of the backbone network:
121 | #       Adam:
122 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (freeze)
123 | #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'adam', Init_lr = 1e-4, weight_decay = 0. (no freeze)
124 | #       SGD:
125 | #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 120, Freeze_Train = True, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (freeze)
126 | #           Init_Epoch = 0, UnFreeze_Epoch = 120, Freeze_Train = False, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 1e-4. (no freeze)
127 | #       Note: since training starts from the backbone's pretrained weights, which are not necessarily suited to semantic segmentation, more training is needed to escape local optima.
128 | #             UnFreeze_Epoch can be adjusted between 120 and 300.
129 | #             Adam converges faster than SGD, so UnFreeze_Epoch can in theory be smaller, but more epochs are still recommended.
130 | #   (3) Setting batch_size:
131 | #       As large as your GPU can handle. Running out of memory has nothing to do with dataset size; on an OOM or CUDA out of memory error, reduce batch_size.
132 | #       Because resnet50 contains BatchNormalization layers,
133 | #       batch_size cannot be 1 when the backbone is resnet50.
134 | #       Normally Freeze_batch_size should be 1-2x Unfreeze_batch_size. Too large a gap is not recommended, because it affects the automatic learning-rate adjustment.
135 | #----------------------------------------------------------------------------------------------------------------------------#
136 | #------------------------------------------------------------------#
137 | #   Freeze-phase training parameters
138 | #   The model's backbone is frozen at this point; the feature-extraction network does not change
139 | #   Uses less GPU memory; only fine-tunes the network
140 | #   Init_Epoch          The epoch training currently starts from; it may be larger than Freeze_Epoch, e.g.
141 | #                       Init_Epoch = 60, Freeze_Epoch = 50, UnFreeze_Epoch = 100
142 | #                       skips the freeze phase, starts directly from epoch 60, and adjusts the learning rate accordingly.
143 | #                       (used when resuming from a checkpoint)
144 | #   Freeze_Epoch        Number of epochs of freeze training
145 | #                       (ignored when Freeze_Train=False)
146 | #   Freeze_batch_size   batch_size for freeze training
147 | #                       (ignored when Freeze_Train=False)
148 | #------------------------------------------------------------------#
149 | Init_Epoch = 0
150 | Freeze_Epoch = 100
151 | Freeze_batch_size = 2
152 | #------------------------------------------------------------------#
153 | #   Unfreeze-phase training parameters
154 | #   The model's backbone is no longer frozen; the feature-extraction network changes
155 | #   Uses more GPU memory; all of the network's parameters change
156 | #   UnFreeze_Epoch          Total number of training epochs
157 | #   Unfreeze_batch_size     batch_size after unfreezing
158 | #------------------------------------------------------------------#
159 | UnFreeze_Epoch = 300
160 | Unfreeze_batch_size = 2
161 | #------------------------------------------------------------------#
162 | #   Freeze_Train    Whether to do freeze training
163 | #                   Default: freeze the backbone first, then unfreeze for the rest of training.
164 | #------------------------------------------------------------------#
165 | Freeze_Train = True
166 |
167 | #------------------------------------------------------------------#
168 | #   Other training parameters: learning rate, optimizer and learning-rate decay
169 | #------------------------------------------------------------------#
170 | #------------------------------------------------------------------#
171 | #   Init_lr     The model's maximum learning rate
172 | #               Init_lr=1e-4 is recommended with the Adam optimizer
173 | #               Init_lr=1e-2 is recommended with the SGD optimizer
174 | #   Min_lr      The model's minimum learning rate; defaults to 0.01x the maximum learning rate
175 | #------------------------------------------------------------------#
176 | Init_lr = 1e-4
177 | Min_lr = Init_lr * 0.01
178 | #------------------------------------------------------------------#
179 | #   optimizer_type  The optimizer to use; options are adam, sgd
180 | #                   Init_lr=1e-4 is recommended with the Adam optimizer
181 | #                   Init_lr=1e-2 is recommended with the SGD optimizer
182 | #   momentum        The momentum parameter used inside the optimizer
183 | #   weight_decay    Weight decay; can prevent overfitting
184 | #                   weight_decay causes problems with adam; set it to 0 when using adam.
185 | #------------------------------------------------------------------#
186 | optimizer_type = "adam"
187 | momentum = 0.9
188 | weight_decay = 0
189 | #------------------------------------------------------------------#
190 | #   lr_decay_type   Learning-rate decay schedule; options are 'step' and 'cos'
191 | #------------------------------------------------------------------#
192 | lr_decay_type = 'cos'
193 | #------------------------------------------------------------------#
194 | #   save_period     Save the weights every save_period epochs
195 | #------------------------------------------------------------------#
196 | save_period = 5
197 | #------------------------------------------------------------------#
198 | #   save_dir        Folder where weights and log files are saved
199 | #------------------------------------------------------------------#
200 | save_dir = 'logs'
201 | #------------------------------------------------------------------#
202 | #   eval_flag       Whether to evaluate during training; the evaluation target is the validation set
203 | #   eval_period     How many epochs between evaluations; frequent evaluation is not recommended
204 | #                   Evaluation takes considerable time, and evaluating frequently makes training very slow
205 | #   The mIoU obtained here will differ from the one obtained with get_miou.py, for two reasons:
206 | #   (1) The mIoU obtained here is the validation set's mIoU.
207 | #   (2) The evaluation settings here are conservative, to speed up evaluation.
208 | #------------------------------------------------------------------#
209 | eval_flag = True
210 | eval_period = 5
211 |
212 | #------------------------------#
213 | #   Dataset path
214 | #------------------------------#
215 | VOCdevkit_path = 'Nankai'
216 | #------------------------------------------------------------------#
217 | #   Suggested settings:
218 | #       Few classes (several): set to True
219 | #       Many classes (more than ten) and a large batch_size (above 10): set to True
220 | #       Many classes (more than ten) and a small batch_size (below 10): set to False
221 | #------------------------------------------------------------------#
222 | dice_loss = False
223 | #------------------------------------------------------------------#
224 | #   Whether to use focal loss to counter positive/negative sample imbalance
225 | #------------------------------------------------------------------#
226 | focal_loss = False
227 | #------------------------------------------------------------------#
228 | #   Whether to give different classes different loss weights; balanced by default.
229 | #   If set, use a numpy array whose length equals num_classes.
230 | #   e.g.:
231 | # num_classes = 3
232 | # cls_weights = np.array([1, 2, 3], np.float32)
233 | #------------------------------------------------------------------#
234 | cls_weights = np.ones([num_classes], np.float32)
235 | #------------------------------------------------------------------#
236 | #   num_workers     Sets whether to use multi-threaded data loading; 1 disables it
237 | #                   Enabling it speeds up data loading but uses more memory
238 | #                   In keras, enabling multi-threading sometimes made things much slower
239 | #                   Enable multi-threading only when IO is the bottleneck, i.e. the GPU computes far faster than images are read.
240 | #------------------------------------------------------------------#
241 | num_workers = 4
242 |
243 | #------------------------------------------------------#
244 | #   Set the GPUs to use
245 | #------------------------------------------------------#
246 | ngpus_per_node = torch.cuda.device_count()
247 | if distributed:
248 | dist.init_process_group(backend="nccl")
249 | local_rank = int(os.environ["LOCAL_RANK"])
250 | rank = int(os.environ["RANK"])
251 | device = torch.device("cuda", local_rank)
252 | if local_rank == 0:
253 | print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
254 | print("Gpu Device Count : ", ngpus_per_node)
255 | else:
256 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
257 | local_rank = 0
258 |
259 | #----------------------------------------------------#
260 | #   Download pretrained weights
261 | #----------------------------------------------------#
262 | if pretrained:
263 | if distributed:
264 | if local_rank == 0:
265 | download_weights(backbone)
266 | dist.barrier()
267 | else:
268 | download_weights(backbone)
269 |
270 | model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train()
271 | if not pretrained:
272 | weights_init(model)
273 | if model_path != '':
274 | #------------------------------------------------------#
275 | #   See the README for the weight file; download it from the Baidu network drive
276 | #------------------------------------------------------#
277 | if local_rank == 0:
278 | print('Load weights {}.'.format(model_path))
279 |
280 | #------------------------------------------------------#
281 | #   Load weights by matching the pretrained weights' keys with the model's keys
282 | #------------------------------------------------------#
283 | model_dict = model.state_dict()
284 | pretrained_dict = torch.load(model_path, map_location = device)
285 | load_key, no_load_key, temp_dict = [], [], {}
286 | for k, v in pretrained_dict.items():
287 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
288 | temp_dict[k] = v
289 | load_key.append(k)
290 | else:
291 | no_load_key.append(k)
292 | model_dict.update(temp_dict)
293 | model.load_state_dict(model_dict)
294 | #------------------------------------------------------#
295 | #   Show the keys that failed to match
296 | #------------------------------------------------------#
297 | if local_rank == 0:
298 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
299 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
300 | print("\n\033[1;33;44mFriendly reminder: it is normal for the head not to load, but an error if the backbone does not load.\033[0m")
301 |
302 | #----------------------#
303 | #   Record the loss
304 | #----------------------#
305 | if local_rank == 0:
306 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
307 | log_dir = os.path.join(save_dir, "loss_" + str(time_str))
308 | loss_history = LossHistory(log_dir, model, input_shape=input_shape)
309 | else:
310 | loss_history = None
311 |
312 | #------------------------------------------------------------------#
313 | #   torch 1.2 does not support amp; torch 1.7.1 or above is recommended for correct fp16 use
314 | #   which is why torch 1.2 reports "could not be resolved" here
315 | #------------------------------------------------------------------#
316 | if fp16:
317 | from torch.cuda.amp import GradScaler as GradScaler
318 | scaler = GradScaler()
319 | else:
320 | scaler = None
321 |
322 | model_train = model.train()
323 | #----------------------------#
324 | #   Multi-GPU synchronized BatchNorm
325 | #----------------------------#
326 | if sync_bn and ngpus_per_node > 1 and distributed:
327 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
328 | elif sync_bn:
329 | print("Sync_bn is not support in one gpu or not distributed.")
330 |
331 | if Cuda:
332 | if distributed:
333 | #----------------------------#
334 | #   Multi-GPU parallel execution
335 | #----------------------------#
336 | model_train = model_train.cuda(local_rank)
337 | model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True)
338 | else:
339 | model_train = torch.nn.DataParallel(model)
340 | cudnn.benchmark = True
341 | model_train = model_train.cuda()
342 |
343 | #---------------------------#
344 | #   Read the txt files for the dataset
345 | #---------------------------#
346 | with open(os.path.join(VOCdevkit_path, "save/train.txt"),"r") as f:
347 | train_lines = f.readlines()
348 | with open(os.path.join(VOCdevkit_path, "save/val.txt"),"r") as f:
349 | val_lines = f.readlines()
350 | num_train = len(train_lines)
351 | num_val = len(val_lines)
352 |
353 | if local_rank == 0:
354 | show_config(
355 | num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \
356 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
357 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
358 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
359 | )
360 | #------------------------------------------------------#
361 | #   Backbone features are generic, so freeze training speeds up training
362 | #   and keeps the weights from being destroyed early in training.
363 | #   Init_Epoch is the starting epoch
364 | #   Freeze_Epoch is the number of freeze-training epochs
365 | #   UnFreeze_Epoch is the total number of training epochs
366 | #   If you hit OOM or insufficient GPU memory, reduce Batch_size
367 | #------------------------------------------------------#
368 | if True:
369 | UnFreeze_flag = False
370 | #------------------------------------#
371 | #   Freeze part of the network for training
372 | #------------------------------------#
373 | if Freeze_Train:
374 | model.freeze_backbone()
375 |
376 | #-------------------------------------------------------------------#
377 | #   If freeze training is disabled, set batch_size directly to Unfreeze_batch_size
378 | #-------------------------------------------------------------------#
379 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
380 |
381 | #-------------------------------------------------------------------#
382 | #   Adapt the learning rate to the current batch_size
383 | #-------------------------------------------------------------------#
384 | nbs = 16
385 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
386 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
387 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
388 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
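# Added note: with Freeze_batch_size = 2 the raw value is 2 / 16 * 1e-4 = 1.25e-5,
# which max(..., lr_limit_min) raises to 1e-4 for adam (the min and max limits coincide);
# Min_lr_fit is clamped to 1e-6 the same way.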
389 |
390 | #---------------------------------------#
391 | #   Select the optimizer according to optimizer_type
392 | #---------------------------------------#
393 | optimizer = {
394 | 'adam' : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay),
395 | 'sgd' : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay)
396 | }[optimizer_type]
397 |
398 | #---------------------------------------#
399 | #   Get the learning-rate decay function
400 | #---------------------------------------#
401 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
402 |
403 | #---------------------------------------#
404 | #   Determine the length of each epoch
405 | #---------------------------------------#
406 | epoch_step = num_train // batch_size
407 | epoch_step_val = num_val // batch_size
408 |
409 | if epoch_step == 0 or epoch_step_val == 0:
410 | raise ValueError("The dataset is too small to train on; please expand it.")
411 |
412 | train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
413 | val_dataset = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path)
414 |
415 | if distributed:
416 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,)
417 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,)
418 | batch_size = batch_size // ngpus_per_node
419 | shuffle = False
420 | else:
421 | train_sampler = None
422 | val_sampler = None
423 | shuffle = True
424 |
425 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
426 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
427 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
428 | drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler)
429 |
430 | #----------------------#
431 | #   Record the eval mIoU curve
432 | #----------------------#
433 | if local_rank == 0:
434 | eval_callback = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \
435 | eval_flag=eval_flag, period=eval_period)
436 | else:
437 | eval_callback = None
438 |
439 | #---------------------------------------#
440 | #   Start model training
441 | #---------------------------------------#
442 | for epoch in range(Init_Epoch, UnFreeze_Epoch):
443 | #---------------------------------------#
444 | #   If the model has a frozen part,
445 | #   unfreeze it and update the parameters
446 | #---------------------------------------#
447 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
448 | batch_size = Unfreeze_batch_size
449 |
450 | #-------------------------------------------------------------------#
451 | #   Adapt the learning rate to the current batch_size
452 | #-------------------------------------------------------------------#
453 | nbs = 16
454 | lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
455 | lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
456 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
457 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
458 | #---------------------------------------#
459 | #   Get the learning-rate decay function
460 | #---------------------------------------#
461 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
462 |
463 | model.unfreeze_backbone()
464 |
465 | epoch_step = num_train // batch_size
466 | epoch_step_val = num_val // batch_size
467 |
468 | if epoch_step == 0 or epoch_step_val == 0:
469 | raise ValueError("The dataset is too small to train on; please expand it.")
470 |
471 | if distributed:
472 | batch_size = batch_size // ngpus_per_node
473 |
474 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
475 | drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
476 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
477 | drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler)
478 |
479 | UnFreeze_flag = True
480 |
481 | if distributed:
482 | train_sampler.set_epoch(epoch)
483 |
484 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
485 |
486 | fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch,
487 | epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)
488 |
489 | if distributed:
490 | dist.barrier()
491 |
492 | if local_rank == 0:
493 | loss_history.writer.close()
494 |
--------------------------------------------------------------------------------
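The focal_loss flag above relies on the loss implementations in nets/unet_training.py, which are not part of this excerpt. Below is a generic sketch of the focal-loss idea the comments refer to (down-weighting pixels the model already classifies confidently, so rare lesion pixels contribute more to the loss), assuming logits of shape (N, C, H, W) and integer labels; the repo's own implementation may differ in its details.

```python
import torch
import torch.nn.functional as F

def focal_loss_sketch(logits, target, alpha=0.5, gamma=2.0):
    """Generic multi-class focal loss for segmentation (illustrative sketch only).

    logits: (N, C, H, W) raw network outputs; target: (N, H, W) integer class ids.
    """
    logpt = -F.cross_entropy(logits, target, reduction='none')  # log p_t per pixel
    pt = torch.exp(logpt)                                       # p_t per pixel
    # (1 - pt)^gamma shrinks the loss of pixels the model already classifies well.
    loss = -alpha * (1 - pt) ** gamma * logpt
    return loss.mean()

# Toy usage with random tensors matching num_classes = 2 from train.py.
logits = torch.randn(1, 2, 8, 8)
target = torch.randint(0, 2, (1, 8, 8))
print(focal_loss_sketch(logits, target).item())
```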