├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pytorch-unet.iml
│   └── vcs.xml
├── LICENSE
├── README.en.md
├── README.md
├── data.py
├── data
│   ├── JPEGImages
│   │   └── 000799.png
│   ├── SegmentationClass
│   │   └── 000799.png
│   ├── image
│   │   ├── 000799.json
│   │   └── 000799.png
│   └── make_mask_data.py
├── evaluation
│   └── get_evaluation.py
├── net.py
├── params
│   └── README.md
├── result
│   ├── README.md
│   └── result.png
├── test.py
├── train.py
├── train_image
│   ├── 0.png
│   └── README.md
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/pytorch-unet.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 仅供学习,禁止用于任何违法行为上!
--------------------------------------------------------------------------------
/README.en.md:
--------------------------------------------------------------------------------
1 | # pytorch-unet
2 |
3 | #### Description
4 | Build your own UNet network with PyTorch and train it on your own dataset.
5 |
6 | #### Software Architecture
7 | PyTorch
8 |
9 | #### Installation
10 |
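11 | 1. Install the latest version of PyTorch (together with torchvision).
12 | 2. Install the other packages imported by the scripts: Pillow, NumPy, OpenCV (`opencv-python`) and tqdm.
13 | 3. For `evaluation/get_evaluation.py`, also install scikit-learn.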
14 |
15 | #### Instructions
16 |
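17 | 1. Put the original images in `data/JPEGImages` and the masks in `data/SegmentationClass`; an image and its mask must share the same file name.
18 | 2. Run `train.py` directly; the `train_image` folder stores preview images produced during training.
19 | 3. The `params` folder stores the trained weights; run `test.py` to segment an image, and the result is saved in the `result` folder.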
20 |
21 | #### Contribution
22 |
23 | 1. Fork the repository
24 | 2. Create Feat_xxx branch
25 | 3. Commit your code
26 | 4. Create Pull Request
27 |
28 |
29 | #### Gitee Feature
30 |
31 | 1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
32 | 2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
33 | 3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
34 | 4. The most valuable open source project [GVP](https://gitee.com/gvp)
35 | 5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
36 | 6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-unet
2 |
3 | #### 介绍
4 | pytorch搭建自己的unet网络,训练自己的数据集。
5 | #### 软件架构
6 | pytorch
7 |
8 |
9 | #### 安装教程
10 |
11 | 1. 下载最新版本的pytorch就可以
12 |
13 | #### 使用说明
14 |
15 | 1. 数据集原图存放地址:data/JPEGImages mask存放地址:data/SegmentationClass
16 | 2. 直接运行train.py,其中train_image文件夹存储的是训练过程中的效果图
17 | 3. params文件夹保存权重
18 | 4. 测试test.py,用来测试图片,测试结果存储在result文件夹中
19 |
20 |
21 | #### 视频地址
22 | B站:https://www.bilibili.com/video/BV11341127iK?spm_id_from=333.999.0.0
23 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset
6 | from utils import *
7 | from torchvision import transforms
8 |
# Image preprocessing pipeline: the only step is ToTensor, which turns a PIL
# image into a tensor for the network.
transform = transforms.Compose([transforms.ToTensor()])
12 |
13 |
class MyDataset(Dataset):
    """Segmentation dataset pairing each mask with its source image.

    Masks are read from ``<path>/SegmentationClass/<name>`` and images from
    ``<path>/JPEGImages/<name>``; an image and its mask share the same file
    name. Both are padded to a square and resized by the ``utils`` helpers.
    """

    def __init__(self, path):
        self.path = path
        # Sort the listing so the index -> sample mapping is deterministic;
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        self.name = sorted(os.listdir(os.path.join(path, 'SegmentationClass')))

    def __len__(self):
        return len(self.name)

    def __getitem__(self, index):
        """Return ``(image_tensor, mask_tensor)`` for sample ``index``.

        The image goes through ``transform`` (ToTensor); the mask is the
        array of palette values of the 'P'-mode mask wrapped in a float
        ``torch.Tensor`` (callers convert it with ``.long()`` for the loss).
        """
        segment_name = self.name[index]  # e.g. xx.png
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        image_path = os.path.join(self.path, 'JPEGImages', segment_name)
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open_rgb(image_path)
        return transform(image), torch.Tensor(np.array(segment_image))
29 |
30 |
if __name__ == '__main__':
    from torch.nn.functional import one_hot

    # Quick sanity check: print the shapes of the first sample and of its
    # one-hot encoded mask.
    dataset = MyDataset('data')
    image_tensor, mask_tensor = dataset[0]
    print(image_tensor.shape)
    print(mask_tensor.shape)
    encoded = one_hot(mask_tensor.long())
    print(encoded.shape)
38 |
--------------------------------------------------------------------------------
/data/JPEGImages/000799.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/data/JPEGImages/000799.png
--------------------------------------------------------------------------------
/data/SegmentationClass/000799.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/data/SegmentationClass/000799.png
--------------------------------------------------------------------------------
/data/image/000799.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/data/image/000799.png
--------------------------------------------------------------------------------
/data/make_mask_data.py:
--------------------------------------------------------------------------------
1 | '''
2 | ==================板块功能描述====================
3 | @Time :2022/4/9 15:34
4 | @Author : qiaofengsheng
5 | @File :make_mask_data.py
6 | @Software :PyCharm
7 | @description:
8 | ================================================
9 | '''
10 | import os
11 |
12 | import cv2
13 | import numpy as np
14 | from PIL import Image, ImageDraw
15 | import json
16 |
# Ordered label names. make_mask fills a shape labelled CLASS_NAMES[k] with
# pixel value k + 1, leaving 0 for the background.
CLASS_NAMES = [
    'horse',
    'person',
]
18 |
19 |
def make_mask(image_dir, save_dir, class_names=None):
    """Rasterize polygon annotations (JSON) into palette ('P') mask images.

    For every ``<stem>.json`` file in ``image_dir``, a mask with the size of
    ``<stem>.png`` (same directory) is created. Each entry of the JSON
    ``shapes`` list is drawn as a polygon from its ``points`` and filled with
    ``class_names.index(label) + 1``, so 0 stays background. The mask is saved
    as ``<stem>.png`` in ``save_dir``.

    Args:
        image_dir: directory holding the ``.json`` annotations and ``.png`` images.
        save_dir: output directory for the masks; created if missing.
        class_names: ordered label list; defaults to the module's ``CLASS_NAMES``.

    Raises:
        ValueError: if a shape's label is not in ``class_names``.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    os.makedirs(save_dir, exist_ok=True)
    for file_name in sorted(os.listdir(image_dir)):
        stem, ext = os.path.splitext(file_name)
        # Compare the real extension: split('.')[1] raised IndexError on names
        # without a dot and misread names containing several dots.
        if ext != '.json':
            continue
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(os.path.join(image_dir, file_name), 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        # Build the image name from the stem; str.replace('json', 'png') also
        # rewrote any 'json' occurring inside the stem itself.
        png_name = stem + '.png'
        with Image.open(os.path.join(image_dir, png_name)) as img:
            size = img.size
        mask = Image.new('P', size)
        mask_draw = ImageDraw.Draw(mask)
        for shape_ in json_data['shapes']:
            points = tuple(tuple(point) for point in shape_['points'])
            mask_draw.polygon(points, fill=class_names.index(shape_['label']) + 1)
        mask.save(os.path.join(save_dir, png_name))
39 |
40 |
def vis_label(img):
    """Print the set of distinct pixel values in the image file at path ``img``.

    Useful to check which class ids a generated mask actually contains.
    """
    # Context manager releases the file handle that Image.open keeps open.
    with Image.open(img) as opened:
        pixels = np.array(opened)
    print(set(pixels.reshape(-1).tolist()))
45 |
46 |
47 |
if __name__ == '__main__':
    # Relative paths: run from the data/ directory (next to image/ and
    # SegmentationClass/). To regenerate the masks from the JSON annotations:
    #     make_mask('image', 'SegmentationClass')
    mask_file = 'SegmentationClass/000799.png'
    vis_label(mask_file)
55 |
--------------------------------------------------------------------------------
/evaluation/get_evaluation.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from sklearn.metrics import confusion_matrix
4 | import cv2
5 |
def keep_image_size_open_label(path, size=(256, 256)):
    """Load a label mask, pad it to a square, resize it and binarize it.

    The image is pasted at the top-left of a square 'P' canvas whose side is
    the image's longest edge (the remainder stays 0), resized to ``size``,
    and then mapped to {0, 1}: pixels equal to 255 become 1, all others 0.

    Returns:
        PIL image built from the binarized uint8 array.
    """
    # Local import: the module only imported PIL inside its __main__ guard, so
    # calling this after ``import get_evaluation`` raised NameError.
    from PIL import Image

    # The context manager closes the file; paste() has loaded the pixels by then.
    with Image.open(path) as img:
        temp = max(img.size)
        mask = Image.new('P', (temp, temp))
        mask.paste(img, (0, 0))
    mask = np.array(mask.resize(size))
    # Presumably the labels mark foreground with 255 (TODO confirm); this is
    # the same 255 -> 1, else -> 0 mapping as before, in a single step.
    mask = (mask == 255).astype(np.uint8)
    return Image.fromarray(mask)
17 |
def keep_image_size_open_predict(path, size=(256, 256)):
    """Load a predicted mask, pad it to a square and resize it to ``size``.

    The image is pasted at the top-left of a square 'P' canvas whose side is
    the image's longest edge (the remainder stays 0), then resized. Unlike
    keep_image_size_open_label, pixel values are kept unchanged.

    Returns:
        PIL image rebuilt from the resized pixel array.
    """
    # Local import: the module only imported PIL inside its __main__ guard, so
    # calling this after ``import get_evaluation`` raised NameError.
    from PIL import Image

    # The context manager closes the file; paste() has loaded the pixels by then.
    with Image.open(path) as img:
        temp = max(img.size)
        mask = Image.new('P', (temp, temp))
        mask.paste(img, (0, 0))
    mask = mask.resize(size)
    return Image.fromarray(np.array(mask))
27 |
def compute_iou(seg_pred, seg_gt, num_classes):
    """Per-class intersection-over-union between two label maps.

    Returns a list with one float per class ``c`` in ``range(num_classes)``:
    |pred == c AND gt == c| / |pred == c OR gt == c|, or NaN when class ``c``
    appears in neither map (empty union).
    """
    def _iou_for(cls):
        pred_mask = seg_pred == cls
        gt_mask = seg_gt == cls
        union = np.logical_or(pred_mask, gt_mask).sum()
        if union == 0:
            return float('nan')
        inter = np.logical_and(pred_mask, gt_mask).sum()
        return float(inter) / float(union)

    return [_iou_for(cls) for cls in range(num_classes)]
40 |
def compute_miou(seg_preds, seg_gts, num_classes):
    """Average the per-class IoU over a sequence of prediction/label pairs.

    ``seg_preds[i]`` is compared with ``seg_gts[i]`` via ``compute_iou``. The
    per-pair results are stacked into a float32 array and averaged per class
    with ``np.nanmean``, so pairs where a class is absent (NaN) are ignored.
    Returns one mean IoU per class.
    """
    per_pair = [
        compute_iou(pred, seg_gts[idx], num_classes)
        for idx, pred in enumerate(seg_preds)
    ]
    return np.nanmean(np.array(per_pair, dtype=np.float32), axis=0)
48 |
if __name__ == '__main__':
    from PIL import Image
    import os

    label_path = "data/val/SegmentationClass"  # folder containing the ground-truth labels

    predict_path = "data/val/predict"  # folder containing the predicted masks

    res_miou = []
    for pred_im in sorted(os.listdir(predict_path)):
        label = keep_image_size_open_label(os.path.join(label_path, pred_im))
        pred = keep_image_size_open_predict(os.path.join(predict_path, pred_im))
        l, p = np.array(label).astype(int), np.array(pred).astype(int)
        print(set(l.reshape(-1).tolist()), set(p.reshape(-1).tolist()))
        # Per-class IoU of the whole image. The previous compute_miou(p, l, 2)
        # iterated over the *rows* of this single 2-D image and averaged
        # row-wise IoUs, which is not the image IoU.
        res_miou.append(compute_iou(p, l, 2))
    # nanmean: a class absent from both label and prediction has NaN IoU and
    # must not turn the whole per-class average into NaN.
    print(np.nanmean(np.array(res_miou, dtype=np.float64), axis=0))
66 |
67 |
68 |
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
class Conv_Block(nn.Module):
    """Two stages of 3x3 conv -> BatchNorm -> Dropout2d(0.3) -> LeakyReLU.

    The first convolution maps ``in_channel`` to ``out_channel`` channels and
    the second keeps ``out_channel``. Both use stride 1, padding 1 with
    reflect padding and no bias, so the spatial size is preserved.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        for source_channels in (in_channel, out_channel):
            stages += [
                nn.Conv2d(source_channels, out_channel, kernel_size=3, stride=1,
                          padding=1, padding_mode='reflect', bias=False),
                nn.BatchNorm2d(out_channel),
                nn.Dropout2d(0.3),
                nn.LeakyReLU(),
            ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)
20 |
21 |
class DownSample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    The channel count is unchanged; an H x W input becomes
    ceil(H/2) x ceil(W/2). The conv uses reflect padding and no bias, and is
    followed by BatchNorm and LeakyReLU.
    """

    def __init__(self, channel):
        super().__init__()
        strided_conv = nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1,
                                 padding_mode='reflect', bias=False)
        self.layer = nn.Sequential(strided_conv, nn.BatchNorm2d(channel), nn.LeakyReLU())

    def forward(self, x):
        return self.layer(x)
32 |
33 |
class UpSample(nn.Module):
    """Upsample ``x`` to the skip connection's size and concatenate them.

    ``x`` is enlarged with nearest-neighbour interpolation, projected from
    ``channel`` to ``channel // 2`` channels by a 1x1 convolution, and
    concatenated with ``feature_map`` along the channel dimension (dim 1).
    """

    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)

    def forward(self, x, feature_map):
        # Interpolate to the skip map's exact spatial size instead of a fixed
        # scale_factor=2. For even sizes this is identical; for odd sizes
        # (where the stride-2 DownSample rounded up) it avoids a shape
        # mismatch in torch.cat, so the UNet accepts arbitrary input sizes.
        up = F.interpolate(x, size=feature_map.shape[-2:], mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)
42 |
43 |
class UNet(nn.Module):
    """U-Net with four down/up levels and skip connections.

    Takes a 3-channel image batch and returns ``num_classes`` output maps
    produced by a final 3x3 convolution over the 64-channel decoder output.
    The attribute names (c1..c9, d1..d4, u1..u4, out) are the state_dict keys
    of saved checkpoints and must not be renamed.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Encoder: conv block then stride-2 downsample; channels double per level.
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)  # bottleneck
        # Decoder: UpSample halves channels and concatenates the skip map
        # (C/2 + C/2 = C channels), then a conv block halves them again.
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out = nn.Conv2d(64, num_classes, 3, 1, 1)

    def forward(self, x):
        skip1 = self.c1(x)
        skip2 = self.c2(self.d1(skip1))
        skip3 = self.c3(self.d2(skip2))
        skip4 = self.c4(self.d3(skip3))
        bottleneck = self.c5(self.d4(skip4))

        decoded = self.c6(self.u1(bottleneck, skip4))
        decoded = self.c7(self.u2(decoded, skip3))
        decoded = self.c8(self.u3(decoded, skip2))
        decoded = self.c9(self.u4(decoded, skip1))
        return self.out(decoded)
78 |
if __name__ == '__main__':
    # Smoke test: a batch of two 256x256 RGB images through a 3-class UNet.
    x = torch.randn(2, 3, 256, 256)
    # UNet requires num_classes; the previous UNet() raised TypeError.
    # 3 matches train.py (2 object classes + background).
    net = UNet(num_classes=3)
    print(net(x).shape)
--------------------------------------------------------------------------------
/params/README.md:
--------------------------------------------------------------------------------
1 | 存放权重文件
--------------------------------------------------------------------------------
/result/README.md:
--------------------------------------------------------------------------------
1 | 存放结果文件
2 |
--------------------------------------------------------------------------------
/result/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/result/result.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from net import *
8 | from utils import *
9 | from data import *
10 | from torchvision.utils import save_image
11 | from PIL import Image
# Run on the GPU when available, otherwise on the CPU (same policy as train.py).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 3  # must match train.py: 2 object classes + background
net = UNet(num_classes).to(device)

weights = 'params/unet.pth'
if os.path.exists(weights):
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(weights, map_location=device))
    print('successfully')
else:
    print('no loading')

_input = input('please input JPEGImages path:')

img = keep_image_size_open_rgb(_input)
img_data = torch.unsqueeze(transform(img).to(device), dim=0)
net.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    out = net(img_data)
# (1, num_classes, H, W) scores -> (H, W) map of predicted class ids.
out = torch.argmax(out, dim=1)[0]
print(set(out.reshape(-1).tolist()))
# cv2.imwrite cannot encode int64 arrays; store the class ids as uint8.
out = out.cpu().numpy().astype(np.uint8)
os.makedirs('result', exist_ok=True)
cv2.imwrite('result/result.png', out)
# Spread the class ids over 0..255 so each class shows as a distinct gray
# (the old float out*255.0 saturated every non-zero class to white).
cv2.imshow('out', out * (255 // max(num_classes - 1, 1)))
cv2.waitKey(0)
36 |
37 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tqdm
4 | from torch import nn, optim
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from data import *
8 | from net import *
9 | from torchvision.utils import save_image
10 |
# Module-level configuration used by the training entry point.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
weight_path = 'params/unet.pth'  # checkpoint: loaded if present, overwritten on save
data_path = 'data'  # root holding SegmentationClass/ and JPEGImages/
save_path = 'train_image'  # folder for the training preview images
if __name__ == '__main__':
    num_classes = 2 + 1  # +1: the background is also a class
    num_epochs = 200  # epochs 1..num_epochs-1 are run, as before
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(os.path.dirname(weight_path) or '.', exist_ok=True)
    data_loader = DataLoader(MyDataset(data_path), batch_size=1, shuffle=True)
    net = UNet(num_classes).to(device)
    if os.path.exists(weight_path):
        # map_location lets weights saved on a GPU load on a CPU-only machine.
        net.load_state_dict(torch.load(weight_path, map_location=device))
        print('successful load weight!')
    else:
        print('not successful load weight')

    opt = optim.Adam(net.parameters())
    loss_fun = nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs):
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):
            image, segment_image = image.to(device), segment_image.to(device)
            out_image = net(image)
            # CrossEntropyLoss needs integer class ids as targets.
            train_loss = loss_fun(out_image, segment_image.long())
            opt.zero_grad()
            train_loss.backward()
            opt.step()

            print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            # Preview: ground-truth mask and predicted class map stacked
            # together; save_image renders every non-zero class as white.
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0).float() * 255
            img = torch.stack([_segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')
        # Save every 20 epochs and always after the last one; previously the
        # final epochs (181-199) were trained but never written to disk.
        if epoch % 20 == 0 or epoch == num_epochs - 1:
            torch.save(net.state_dict(), weight_path)
            print('save successfully!')
51 |
--------------------------------------------------------------------------------
/train_image/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/train_image/0.png
--------------------------------------------------------------------------------
/train_image/README.md:
--------------------------------------------------------------------------------
1 | 训练时效果图
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 |
def keep_image_size_open(path, size=(256, 256)):
    """Open an image as a palette ('P') image padded to a square and resized.

    The image is pasted at the top-left corner of a square canvas whose side
    is the image's longest edge (the remainder stays 0), then resized to
    ``size``. Pillow always resizes 'P' images with nearest-neighbour, so
    mask values (class ids) are not blended.
    """
    # Image.open is lazy and keeps the file open; the context manager closes
    # it once paste() has loaded the pixels.
    with Image.open(path) as img:
        temp = max(img.size)
        mask = Image.new('P', (temp, temp))
        mask.paste(img, (0, 0))
    mask = mask.resize(size)
    return mask
def keep_image_size_open_rgb(path, size=(256, 256)):
    """Open an image as RGB, padded to a square and resized to ``size``.

    The image is pasted at the top-left corner of a black square canvas whose
    side is the image's longest edge, then resized to ``size``.
    """
    # Image.open is lazy and keeps the file open; the context manager closes
    # it once paste() has loaded the pixels.
    with Image.open(path) as img:
        temp = max(img.size)
        mask = Image.new('RGB', (temp, temp))
        mask.paste(img, (0, 0))
    mask = mask.resize(size)
    return mask
18 |
--------------------------------------------------------------------------------