├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── pytorch-unet.iml └── vcs.xml ├── LICENSE ├── README.en.md ├── README.md ├── data.py ├── data ├── JPEGImages │ └── 000799.png ├── SegmentationClass │ └── 000799.png ├── image │ ├── 000799.json │ └── 000799.png └── make_mask_data.py ├── evaluation └── get_evaluation.py ├── net.py ├── params └── README.md ├── result ├── README.md └── result.png ├── test.py ├── train.py ├── train_image ├── 0.png └── README.md └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 29 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pytorch-unet.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 仅供学习,禁止用于任何违法行为上! -------------------------------------------------------------------------------- /README.en.md: -------------------------------------------------------------------------------- 1 | # pytorch-unet 2 | 3 | #### Description 4 | Build your own U-Net with PyTorch and train it on your own dataset. 5 | 6 | #### Software Architecture 7 | Software architecture description 8 | 9 | #### Installation 10 | 11 | 1. xxxx 12 | 2. xxxx 13 | 3. xxxx 14 | 15 | #### Instructions 16 | 17 | 1. xxxx 18 | 2. xxxx 19 | 3. xxxx 20 | 21 | #### Contribution 22 | 23 | 1. Fork the repository 24 | 2. Create Feat_xxx branch 25 | 3. Commit your code 26 | 4. Create Pull Request 27 | 28 | 29 | #### Gitee Feature 30 | 31 | 1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md 32 | 2. Gitee blog [blog.gitee.com](https://blog.gitee.com) 33 | 3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore) 34 | 4. The most valuable open source project [GVP](https://gitee.com/gvp) 35 | 5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help) 36 | 6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-unet 2 | 3 | #### 介绍 4 | pytorch搭建自己的unet网络,训练自己的数据集。 5 | #### 软件架构 6 | PyTorch 7 | 8 | 9 | #### 安装教程 10 | 11 | 1.
下载最新版本的pytorch就可以 12 | 13 | #### 使用说明 14 | 15 | 1. 数据集原图存放地址:data/JPEGImages mask存放地址:data/SegmentationClass 16 | 2. 直接运行train.py,其中train_image文件夹存储的是训练过程中的效果图 17 | 3. params文件夹保存权重 18 | 4. 测试test.py,用来测试图片,测试结果存储在result文件夹中 19 | 20 | 21 | #### 视频地址 22 | B站:https://www.bilibili.com/video/BV11341127iK?spm_id_from=333.999.0.0 23 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from utils import * 7 | from torchvision import transforms 8 | 9 | transform = transforms.Compose([ 10 | transforms.ToTensor() 11 | ]) 12 | 13 | 14 | class MyDataset(Dataset): 15 | def __init__(self, path): 16 | self.path = path 17 | self.name = os.listdir(os.path.join(path, 'SegmentationClass')) 18 | 19 | def __len__(self): 20 | return len(self.name) 21 | 22 | def __getitem__(self, index): 23 | segment_name = self.name[index] # xx.png 24 | segment_path = os.path.join(self.path, 'SegmentationClass', segment_name) 25 | image_path = os.path.join(self.path, 'JPEGImages', segment_name) 26 | segment_image = keep_image_size_open(segment_path) 27 | image = keep_image_size_open_rgb(image_path) 28 | return transform(image), torch.Tensor(np.array(segment_image)) 29 | 30 | 31 | if __name__ == '__main__': 32 | from torch.nn.functional import one_hot 33 | data = MyDataset('data') 34 | print(data[0][0].shape) 35 | print(data[0][1].shape) 36 | out=one_hot(data[0][1].long()) 37 | print(out.shape) 38 | -------------------------------------------------------------------------------- /data/JPEGImages/000799.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/data/JPEGImages/000799.png 
-------------------------------------------------------------------------------- /data/SegmentationClass/000799.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/data/SegmentationClass/000799.png -------------------------------------------------------------------------------- /data/image/000799.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/data/image/000799.png -------------------------------------------------------------------------------- /data/make_mask_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ==================板块功能描述==================== 3 | @Time :2022/4/9 15:34 4 | @Author : qiaofengsheng 5 | @File :make_mask_data.py 6 | @Software :PyCharm 7 | @description: 8 | ================================================ 9 | ''' 10 | import os 11 | 12 | import cv2 13 | import numpy as np 14 | from PIL import Image, ImageDraw 15 | import json 16 | 17 | CLASS_NAMES = ['horse', 'person'] 18 | 19 | 20 | def make_mask(image_dir, save_dir): 21 | data = os.listdir(image_dir) 22 | temp_data = [] 23 | for i in data: 24 | if i.split('.')[1] == 'json': 25 | temp_data.append(i) 26 | else: 27 | continue 28 | for js in temp_data: 29 | json_data = json.load(open(os.path.join(image_dir, js), 'r')) 30 | shapes_ = json_data['shapes'] 31 | mask = Image.new('P', Image.open(os.path.join(image_dir, js.replace('json', 'png'))).size) 32 | for shape_ in shapes_: 33 | label = shape_['label'] 34 | points = shape_['points'] 35 | points = tuple(tuple(i) for i in points) 36 | mask_draw = ImageDraw.Draw(mask) 37 | mask_draw.polygon(points, fill=CLASS_NAMES.index(label) + 1) 38 | mask.save(os.path.join(save_dir, js.replace('json', 'png'))) 39 | 40 | 41 | def 
vis_label(img): 42 | img=Image.open(img) 43 | img=np.array(img) 44 | print(set(img.reshape(-1).tolist())) 45 | 46 | 47 | 48 | if __name__ == '__main__': 49 | # make_mask('image', 'SegmentationClass') 50 | vis_label('SegmentationClass/000799.png') 51 | # img=Image.open('SegmentationClass/000019.png') 52 | # print(np.array(img).shape) 53 | # out=np.array(img).reshape(-1) 54 | # print(set(out.tolist())) 55 | -------------------------------------------------------------------------------- /evaluation/get_evaluation.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | import cv2 5 | 6 | def keep_image_size_open_label(path, size=(256, 256)): 7 | img = Image.open(path) 8 | temp = max(img.size) 9 | mask = Image.new('P', (temp, temp)) 10 | mask.paste(img, (0, 0)) 11 | mask = mask.resize(size) 12 | mask = np.array(mask) 13 | mask[mask!=255]=0 14 | mask[mask==255]=1 15 | mask = Image.fromarray(mask) 16 | return mask 17 | 18 | def keep_image_size_open_predict(path, size=(256, 256)): 19 | img = Image.open(path) 20 | temp = max(img.size) 21 | mask = Image.new('P', (temp, temp)) 22 | mask.paste(img, (0, 0)) 23 | mask = mask.resize(size) 24 | mask = np.array(mask) 25 | mask = Image.fromarray(mask) 26 | return mask 27 | 28 | def compute_iou(seg_pred, seg_gt, num_classes): 29 | ious = [] 30 | for c in range(num_classes): 31 | pred_inds = seg_pred == c 32 | target_inds = seg_gt == c 33 | intersection = np.logical_and(pred_inds, target_inds).sum() 34 | union = np.logical_or(pred_inds, target_inds).sum() 35 | if union == 0: 36 | ious.append(float('nan')) 37 | else: 38 | ious.append(float(intersection) / float(union)) 39 | return ious 40 | 41 | def compute_miou(seg_preds, seg_gts, num_classes): 42 | ious = [] 43 | for i in range(len(seg_preds)): 44 | ious.append(compute_iou(seg_preds[i], seg_gts[i], num_classes)) 45 | ious = np.array(ious, dtype=np.float32) 46 | miou = 
np.nanmean(ious, axis=0) 47 | return miou 48 | 49 | if __name__ == '__main__': 50 | from PIL import Image 51 | import os 52 | 53 | label_path = "data/val/SegmentationClass" # 标签的文件夹位置 54 | 55 | predict_path = "data/val/predict" # 预测结果的文件夹位置 56 | 57 | res_miou = [] 58 | for pred_im in os.listdir(predict_path): 59 | label = keep_image_size_open_label(os.path.join(label_path,pred_im)) 60 | pred = keep_image_size_open_predict(os.path.join(predict_path,pred_im)) 61 | l, p = np.array(label).astype(int), np.array(pred).astype(int) 62 | print(set(l.reshape(-1).tolist()),set(p.reshape(-1).tolist())) 63 | miou = compute_miou(p,l,2) 64 | res_miou.append(miou) 65 | print(np.array(res_miou).mean(axis=0)) 66 | 67 | 68 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv_Block(nn.Module): 6 | def __init__(self,in_channel,out_channel): 7 | super(Conv_Block, self).__init__() 8 | self.layer=nn.Sequential( 9 | nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False), 10 | nn.BatchNorm2d(out_channel), 11 | nn.Dropout2d(0.3), 12 | nn.LeakyReLU(), 13 | nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False), 14 | nn.BatchNorm2d(out_channel), 15 | nn.Dropout2d(0.3), 16 | nn.LeakyReLU() 17 | ) 18 | def forward(self,x): 19 | return self.layer(x) 20 | 21 | 22 | class DownSample(nn.Module): 23 | def __init__(self,channel): 24 | super(DownSample, self).__init__() 25 | self.layer=nn.Sequential( 26 | nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False), 27 | nn.BatchNorm2d(channel), 28 | nn.LeakyReLU() 29 | ) 30 | def forward(self,x): 31 | return self.layer(x) 32 | 33 | 34 | class UpSample(nn.Module): 35 | def __init__(self,channel): 36 | super(UpSample, self).__init__() 37 | self.layer=nn.Conv2d(channel,channel//2,1,1) 
38 | def forward(self,x,feature_map): 39 | up=F.interpolate(x,scale_factor=2,mode='nearest') 40 | out=self.layer(up) 41 | return torch.cat((out,feature_map),dim=1) 42 | 43 | 44 | class UNet(nn.Module): 45 | def __init__(self,num_classes): 46 | super(UNet, self).__init__() 47 | self.c1=Conv_Block(3,64) 48 | self.d1=DownSample(64) 49 | self.c2=Conv_Block(64,128) 50 | self.d2=DownSample(128) 51 | self.c3=Conv_Block(128,256) 52 | self.d3=DownSample(256) 53 | self.c4=Conv_Block(256,512) 54 | self.d4=DownSample(512) 55 | self.c5=Conv_Block(512,1024) 56 | self.u1=UpSample(1024) 57 | self.c6=Conv_Block(1024,512) 58 | self.u2 = UpSample(512) 59 | self.c7 = Conv_Block(512, 256) 60 | self.u3 = UpSample(256) 61 | self.c8 = Conv_Block(256, 128) 62 | self.u4 = UpSample(128) 63 | self.c9 = Conv_Block(128, 64) 64 | self.out=nn.Conv2d(64,num_classes,3,1,1) 65 | 66 | def forward(self,x): 67 | R1=self.c1(x) 68 | R2=self.c2(self.d1(R1)) 69 | R3 = self.c3(self.d2(R2)) 70 | R4 = self.c4(self.d3(R3)) 71 | R5 = self.c5(self.d4(R4)) 72 | O1=self.c6(self.u1(R5,R4)) 73 | O2 = self.c7(self.u2(O1, R3)) 74 | O3 = self.c8(self.u3(O2, R2)) 75 | O4 = self.c9(self.u4(O3, R1)) 76 | 77 | return self.out(O4) 78 | 79 | if __name__ == '__main__': 80 | x=torch.randn(2,3,256,256) 81 | net=UNet() 82 | print(net(x).shape) -------------------------------------------------------------------------------- /params/README.md: -------------------------------------------------------------------------------- 1 | 存放权重文件 -------------------------------------------------------------------------------- /result/README.md: -------------------------------------------------------------------------------- 1 | 存放结果文件 2 | -------------------------------------------------------------------------------- /result/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/result/result.png 
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from net import * 8 | from utils import * 9 | from data import * 10 | from torchvision.utils import save_image 11 | from PIL import Image 12 | net=UNet(3).cuda() 13 | 14 | weights='params/unet.pth' 15 | if os.path.exists(weights): 16 | net.load_state_dict(torch.load(weights)) 17 | print('successfully') 18 | else: 19 | print('no loading') 20 | 21 | _input=input('please input JPEGImages path:') 22 | 23 | img=keep_image_size_open_rgb(_input) 24 | img_data=transform(img).cuda() 25 | img_data=torch.unsqueeze(img_data,dim=0) 26 | net.eval() 27 | out=net(img_data) 28 | out=torch.argmax(out,dim=1) 29 | out=torch.squeeze(out,dim=0) 30 | out=out.unsqueeze(dim=0) 31 | print(set((out).reshape(-1).tolist())) 32 | out=(out).permute((1,2,0)).cpu().detach().numpy() 33 | cv2.imwrite('result/result.png',out) 34 | cv2.imshow('out',out*255.0) 35 | cv2.waitKey(0) 36 | 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tqdm 4 | from torch import nn, optim 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from data import * 8 | from net import * 9 | from torchvision.utils import save_image 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | weight_path = 'params/unet.pth' 13 | data_path = r'data' 14 | save_path = 'train_image' 15 | if __name__ == '__main__': 16 | num_classes = 2 + 1 # +1是背景也为一类 17 | data_loader = DataLoader(MyDataset(data_path), batch_size=1, shuffle=True) 18 | net = UNet(num_classes).to(device) 19 | if os.path.exists(weight_path): 20 | net.load_state_dict(torch.load(weight_path)) 21 | print('successful load weight!') 22 | 
else: 23 | print('not successful load weight') 24 | 25 | opt = optim.Adam(net.parameters()) 26 | loss_fun = nn.CrossEntropyLoss() 27 | 28 | epoch = 1 29 | while epoch < 200: 30 | for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)): 31 | image, segment_image = image.to(device), segment_image.to(device) 32 | out_image = net(image) 33 | train_loss = loss_fun(out_image, segment_image.long()) 34 | opt.zero_grad() 35 | train_loss.backward() 36 | opt.step() 37 | 38 | if i % 1 == 0: 39 | print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}') 40 | 41 | _image = image[0] 42 | _segment_image = torch.unsqueeze(segment_image[0], 0) * 255 43 | _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255 44 | 45 | img = torch.stack([_segment_image, _out_image], dim=0) 46 | save_image(img, f'{save_path}/{i}.png') 47 | if epoch % 20 == 0: 48 | torch.save(net.state_dict(), weight_path) 49 | print('save successfully!') 50 | epoch += 1 51 | -------------------------------------------------------------------------------- /train_image/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengsheng/pytorch-UNet/d4d649eb357fc25896c3fd15080e507e13f17571/train_image/0.png -------------------------------------------------------------------------------- /train_image/README.md: -------------------------------------------------------------------------------- 1 | 训练时效果图 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | 4 | def keep_image_size_open(path, size=(256, 256)): 5 | img = Image.open(path) 6 | temp = max(img.size) 7 | mask = Image.new('P', (temp, temp)) 8 | mask.paste(img, (0, 0)) 9 | mask = mask.resize(size) 10 | return mask 11 | def keep_image_size_open_rgb(path, size=(256, 256)): 12 | img = Image.open(path) 13 | temp = max(img.size) 14 | 
mask = Image.new('RGB', (temp, temp)) 15 | mask.paste(img, (0, 0)) 16 | mask = mask.resize(size) 17 | return mask 18 | --------------------------------------------------------------------------------