├── README.md
├── data_loader.py
├── data_set
│   └── Urben_pre
│       ├── annotation
│       │   ├── 0.png
│       │   ├── 1.png
│       │   ├── 10.png
│       │   ├── 29.png
│       │   └── 31.png
│       ├── test
│       │   ├── 29.png
│       │   └── 31.png
│       ├── train
│       │   ├── 0.png
│       │   └── 1.png
│       └── valid
│           ├── 1.png
│           └── 10.png
├── dataset.py
├── enhance_image.py
├── evaluation.py
├── img
│   ├── AttU-Net.png
│   └── U-Net.png
├── main.py
├── models
│   └── train.pth
├── network.py
├── result
│   └── res.csv
├── solver.py
├── test
│   ├── test_data.py
│   ├── test_one_data.py
│   └── test_train.py
└── util
    └── data_preprocess.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Water_extraction
2 | Water body extraction from high-resolution urban remote sensing images
3 | # Data source
4 | We use the land-cover domain adaptive semantic segmentation (LoveDA) dataset created by Junjue Wang, Zhuo Zheng, and other researchers at Wuhan University. The dataset was built to explore how deep transfer learning can advance urban- and national-scale land-cover mapping, and it supports both land-cover semantic segmentation and unsupervised domain adaptation (UDA) tasks.
5 | ![image](https://user-images.githubusercontent.com/48434805/178443289-ade0da93-b4d1-4c86-a527-90eeaba017c2.png)
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | from torchvision import transforms as T
4 | from PIL import Image
5 | 
6 | class ImageFolder(data.Dataset):
7 |     def __init__(self, root, image_size=224, mode='train', augmentation_prob=0):
8 |         """Initializes image paths and the preprocessing module."""
9 |         self.root = root
10 | 
11 |         # GT : Ground Truth; annotations live next to the train/valid/test folders
12 |         if root[-6:] == "train/" or root[-6:] == "valid/":
13 |             self.GT_paths = root[:-6] + 'annotation/'
14 |         elif root[-5:] == "test/":
15 |             self.GT_paths = root[:-5] + 'annotation/'
16 |         self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
17 |         self.image_size = image_size
18 |         self.mode = mode
19 |         self.RotationDegree = [0, 90, 180, 270]  # kept for optional augmentation; unused when augmentation_prob=0
20 |         self.augmentation_prob = augmentation_prob
21 |         print("image count in {} path :{}".format(self.mode, len(self.image_paths)))
22 | 
23 |     def __getitem__(self, index):
24 |         """Reads an image from a file, preprocesses it, and returns it (paired with its GT in train/valid mode)."""
25 |         image_path = self.image_paths[index]
26 |         if self.mode == 'train' or self.mode == 'valid':
27 |             filename = image_path.split('/')[-1][:-len(".png")]
28 |             GT_path = self.GT_paths + filename + '.png'
29 | 
30 |             image = Image.open(image_path)
31 |             GT = Image.open(GT_path)
32 | 
33 |             Transform = T.Compose([T.ToTensor()])
34 |             image = Transform(image)
35 |             GT = Transform(GT)
36 | 
37 |             Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
38 |             image = Norm_(image)
39 |             return image, GT
40 |         elif self.mode == 'test':
41 |             image = Image.open(image_path)
42 |             Transform = T.Compose([T.ToTensor()])
43 |             image = Transform(image)
44 | 
45 |             Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
46 |             image = Norm_(image)
47 |             return [image_path, image]
48 | 
49 |     def __len__(self):
50 |         """Returns the total number of images."""
51 |         return len(self.image_paths)
52 | 
53 | def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train', augmentation_prob=0):
54 |     """Builds and returns a DataLoader for the given split."""
55 |     dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode, augmentation_prob=augmentation_prob)
56 |     if mode == 'train' or mode == 'valid':
57 |         data_loader = data.DataLoader(dataset=dataset,
58 |                                       batch_size=batch_size,
59 |                                       shuffle=True,
60 |                                       num_workers=num_workers)
61 |         return data_loader
62 |     elif mode == 'test':
63 |         data_loader = data.DataLoader(dataset=dataset,
64 |                                       batch_size=batch_size,
65 |                                       num_workers=num_workers)
66 |         return data_loader
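A quick usage sketch for get_loader (illustration only, not repository code; the path and sizes mirror the defaults in main.py):

from data_loader import get_loader

train_loader = get_loader(image_path='./data_set/Urben_pre/train/',
                          image_size=256,
                          batch_size=8,
                          num_workers=2,
                          mode='train',
                          augmentation_prob=0)
for images, GT in train_loader:
    print(images.shape, GT.shape)  # e.g. torch.Size([8, 3, 256, 256]) and torch.Size([8, 1, 256, 256])
    break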
--------------------------------------------------------------------------------
/data_set/Urben_pre/annotation/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/annotation/0.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/annotation/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/annotation/1.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/annotation/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/annotation/10.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/annotation/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/annotation/29.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/annotation/31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/annotation/31.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/test/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/test/29.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/test/31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/test/31.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/train/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/train/0.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/train/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/train/1.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/valid/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/valid/1.png
--------------------------------------------------------------------------------
/data_set/Urben_pre/valid/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/data_set/Urben_pre/valid/10.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | from PIL import Image
4 | 
5 | def split_train_test(data, train_ratio, valid_ratio):
6 |     # Fix the random seed so every run produces the same split
7 |     rng = random.Random(12345)
8 |     rng.shuffle(data)
9 |     train_set_size = int(len(data) * train_ratio)
10 |     valid_set_size = int(len(data) * valid_ratio)
11 |     # the remainder (len(data) - train_set_size - valid_set_size) becomes the test set
12 |     train_datas = data[:train_set_size]
13 |     valid_datas = data[train_set_size:train_set_size + valid_set_size]
14 |     test_datas = data[train_set_size + valid_set_size:]
15 |     return train_datas, valid_datas, test_datas
16 | 
17 | 
18 | def get_dataset(original_path, save_path):
19 |     lists = os.listdir(original_path)
20 |     print(len(lists))
21 |     train_datas, valid_datas, test_datas = split_train_test(lists, 0.7, 0.2)
22 |     print(len(train_datas), len(valid_datas), len(test_datas))
23 |     for item in train_datas:
24 |         train_path = original_path + item
25 |         save_train_path = save_path + 'train/' + item
26 |         image = Image.open(train_path)
27 |         image.save(save_train_path)
28 |     for item in valid_datas:
29 |         valid_path = original_path + item
30 |         save_valid_path = save_path + 'valid/' + item
31 |         image = Image.open(valid_path)
32 |         image.save(save_valid_path)
33 |     for item in test_datas:  # was iterating valid_datas, which duplicated the validation split into test/
34 |         test_path = original_path + item
35 |         save_test_path = save_path + 'test/' + item
36 |         image = Image.open(test_path)
37 |         image.save(save_test_path)
38 | 
39 |     print("end!")
40 | 
41 | 
42 | # get_dataset("E:/高分辨率/many/", "E:/高分辨率/new_splict/")
--------------------------------------------------------------------------------
/enhance_image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from PIL import Image
4 | import matplotlib.pyplot as plt
5 | 
6 | # Sample augmentation: when there are too few negative samples, they can be augmented here.
7 | def get_enhance_image(path, rotation_degree, reduction_ratio):
8 |     # Rotate and rescale the image; note this introduces some black borders.
9 |     # path: source image path; rotation_degree: rotation angle; reduction_ratio: scaling factor
10 |     img = cv2.imread(path)
11 | 
12 |     h, w = img.shape[:2]
13 |     center = (w // 2, h // 2)
14 | 
15 |     # Rotate counter-clockwise by rotation_degree about the center, scaled by reduction_ratio
16 |     M_1 = cv2.getRotationMatrix2D(center, rotation_degree, reduction_ratio)
17 |     rotated_1 = cv2.warpAffine(img, M_1, (w, h))
18 | 
19 |     plt.imshow(rotated_1)
20 |     plt.show()
21 | 
22 |     # cv2.imshow('rotated_45.jpg', rotated_1)
23 |     temp_path = path.split('.')[0]
24 |     rear = path.split('.')[1]
25 |     new_path = temp_path + '_' + str(rotation_degree) + '_' + str(reduction_ratio).replace('.', 'd') + '.' + rear
26 |     print(new_path)
27 |     # cv2.imwrite(new_path, rotated_1)
28 | 
29 | # get_enhance_image('D:/Workspace/Pycharm/Unet/data_set/Urben_original/train\images_png/1379.png', 120, 0.8)
30 | 
31 | def get_rotation_image(path, ro):
32 |     # Rotation-only augmentation with three angles: 90, 180, 270
33 |     img = Image.open(path)
34 |     temp_path = path.split('.')[0]
35 |     rear = path.split('.')[1]
36 |     # new_path = temp_path + '_' + ro + '.' + rear
37 |     # print(new_path)
38 |     if ro == '90':
39 |         new_img = img.transpose(Image.ROTATE_90)  # rotate the image by 90 degrees
40 |         plt.imshow(new_img)
41 |         plt.show()
42 |         # new_img.save(new_path)
43 |     elif ro == '180':
44 |         new_img = img.transpose(Image.ROTATE_180)  # rotate the image by 180 degrees
45 |         # new_img.save(new_path)
46 |     elif ro == '270':
47 |         new_img = img.transpose(Image.ROTATE_270)  # rotate the image by 270 degrees
48 |         # new_img.save(new_path)
49 | 
50 | # commented out: this module-level call ran on every import and pointed at a local absolute path
51 | # get_rotation_image('D:/Workspace/Pycharm/Unet/data_set/Urben_original/train\images_png/1379.png', "90")
52 | 
53 | def get_all_enhance(path):
54 |     lists = os.listdir(path)
55 |     for item in lists:
56 |         temp_path = path + item
57 |         print(temp_path)
58 |         get_rotation_image(temp_path, '180')
59 | 
60 | # get_all_enhance("D:/HAAR/data_set/negative_enhance_dataset/")
61 | 
62 | # path = "D:/HAAR/test/1/106.jpg"
63 | # get_rotation_image(path, "90")
64 | # get_enhance_image(path, 45, 0.5)
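A hypothetical batch version of get_rotation_image that saves instead of plotting, since the save calls above are commented out; src_dir is an assumed example path, and segmentation masks in annotation/ must be rotated the same way as their images:

import os
from PIL import Image

def augment_dir(src_dir, angle=Image.ROTATE_180, suffix='_180'):
    for name in os.listdir(src_dir):
        src = os.path.join(src_dir, name)
        stem, ext = os.path.splitext(src)  # safer than path.split('.') when a path contains extra dots
        Image.open(src).transpose(angle).save(stem + suffix + ext)  # e.g. 1379.png -> 1379_180.png

# augment_dir('./data_set/Urben_pre/train/')  # remember to run the same pass on annotation/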
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class DiceLoss(nn.Module):
6 |     def __init__(self):
7 |         super(DiceLoss, self).__init__()
8 |         self.epsilon = 1e-5
9 | 
10 |     def forward(self, predict, target):
11 |         assert predict.size() == target.size(), "the size of predict and target must be equal."
12 |         num = predict.size(0)
13 | 
14 |         pre = torch.sigmoid(predict).view(num, -1)
15 |         tar = target.view(num, -1)
16 |         intersection = (pre * tar).sum(-1).sum()  # the element-wise product of prediction and label acts as the intersection
17 |         union = (pre + tar).sum(-1).sum()
18 | 
19 |         score = 1 - 2 * (intersection + self.epsilon) / (union + self.epsilon)
20 | 
21 |         return score
22 | 
23 | 
24 | # loss = DiceLoss()
25 | # predict = torch.tensor([[1, 0, 1], [1, 1, 0]])
26 | # target = torch.tensor([[1, 0, 0], [0, 1, 1]])
27 | # score = loss(predict, target)
28 | # print(score)
29 | 
30 | # SR : Segmentation Result
31 | # GT : Ground Truth
32 | 
33 | def get_accuracy(SR, GT, threshold=0.5):
34 |     SR = SR > threshold
35 |     GT = GT == torch.max(GT)
36 |     corr = torch.sum(SR == GT)
37 |     tensor_size = SR.size(0) * SR.size(1) * SR.size(2) * SR.size(3)
38 |     acc = float(corr) / float(tensor_size)
39 | 
40 |     return acc
41 | 
42 | def get_sensitivity(SR, GT, threshold=0.5):
43 |     # Sensitivity == Recall
44 |     SR = SR > threshold
45 |     GT = GT == torch.max(GT)
46 | 
47 |     # TP : True Positive
48 |     # FN : False Negative
49 |     Inter = SR * GT
50 |     TP = torch.sum(Inter)
51 |     TP_FN = torch.sum(GT)
52 | 
53 |     SE = float(TP) / (float(TP_FN) + 1e-6)
54 | 
55 |     return SE
56 | 
57 | def get_specificity(SR, GT, threshold=0.5):
58 |     SR = SR > threshold
59 |     GT = GT == torch.max(GT)
60 | 
61 |     # TN : True Negative
62 |     # FP : False Positive
63 |     Inter = SR * GT
64 |     TP = torch.sum(Inter)
65 |     FP = torch.sum(SR) - TP
66 | 
67 |     FN = torch.sum(GT) - TP
68 |     TN = SR.size(0) * SR.size(1) * SR.size(2) * SR.size(3) - FN - TP - FP
69 | 
70 |     SP = float(TN) / (float(TN + FP) + 1e-6)
71 | 
72 |     return SP
73 | 
74 | def get_precision(SR, GT, threshold=0.5):
75 |     SR = SR > threshold
76 |     GT = GT == torch.max(GT)
77 | 
78 |     # TP : True Positive
79 |     # FP : False Positive
80 |     Inter = SR * GT
81 |     TP = torch.sum(Inter)
82 |     TP_FP = torch.sum(SR)
83 | 
84 |     PC = float(TP) / (float(TP_FP) + 1e-6)
85 | 
86 |     return PC
87 | 
88 | def get_F1(SR, GT, threshold=0.5):
89 |     # Sensitivity == Recall
90 |     SE = get_sensitivity(SR, GT, threshold=threshold)
91 |     PC = get_precision(SR, GT, threshold=threshold)
92 | 
93 |     F1 = 2 * SE * PC / (SE + PC + 1e-6)
94 | 
95 |     return F1
96 | 
97 | def get_JS(SR, GT, threshold=0.5):
98 |     # JS : Jaccard similarity
99 |     SR = SR > threshold
100 |     GT = GT == torch.max(GT)
101 |     Inter = SR * GT
102 |     Inter = torch.sum(Inter)
103 |     Union = torch.sum(SR) + torch.sum(GT) - Inter
104 |     JS = float(Inter) / (float(Union) + 1e-6)
105 | 
106 |     return JS
107 | 
108 | def get_DC(SR, GT, threshold=0.5):
109 |     # DC : Dice Coefficient
110 |     SR = SR > threshold
111 |     GT = GT == torch.max(GT)
112 |     Inter = SR * GT
113 |     Inter = torch.sum(Inter)
114 |     DC = float(2 * Inter) / (float(torch.sum(SR) + torch.sum(GT)) + 1e-6)
115 | 
116 |     return DC
117 | 
118 | 
119 | def get_iou(SR, GT, threshold=0.5):
120 |     SR = SR > threshold
121 |     GT = GT == torch.max(GT)
122 | 
123 |     # IoU = TP / (TP + FP + FN)
124 |     Inter = SR * GT
125 |     TP = torch.sum(Inter)
126 |     FP = torch.sum(SR) - TP
127 | 
128 |     FN = torch.sum(GT) - TP
129 | 
130 |     iou = float(TP) / (float(TP + FN + FP) + 1e-6)
131 | 
132 |     return iou
133 | 
134 | 
135 | def get_FWiou(SR, GT, threshold=0.5):
136 |     SR = SR > threshold
137 |     GT = GT == torch.max(GT)
138 | 
139 |     # FWIoU weights the IoU of the water class by its pixel frequency
140 |     Inter = SR * GT
141 |     TP = torch.sum(Inter)
142 |     FP = torch.sum(SR) - TP
143 | 
144 |     FN = torch.sum(GT) - TP
145 |     TN = SR.size(0) * SR.size(1) * SR.size(2) * SR.size(3) - FN - TP - FP
146 | 
147 |     a = float(TP + FN) / float(TP + FP + FN + TN + 1e-6)
148 |     b = float(TP) / float(TP + FP + FN + 1e-6)
149 |     fwiou = a * b
150 | 
151 |     return fwiou
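A small self-contained sanity check for the metrics above (toy tensors, not repository code); inputs are 4-D because get_accuracy and friends index SR.size(0) through SR.size(3):

import torch
from evaluation import get_accuracy, get_iou

SR = torch.tensor([[[[0.9, 0.2], [0.8, 0.1]]]])  # predicted probabilities, shape (1, 1, 2, 2)
GT = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])  # binary ground truth

print(get_accuracy(SR, GT))  # 3 of 4 pixels agree -> 0.75
print(get_iou(SR, GT))       # TP=1, FP=1, FN=0 -> 1/2, about 0.5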
--------------------------------------------------------------------------------
/img/AttU-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/img/AttU-Net.png
--------------------------------------------------------------------------------
/img/U-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/img/U-Net.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from solver import Solver
4 | from data_loader import get_loader
5 | from torch.backends import cudnn
6 | import random
7 | 
8 | def main(config):
9 |     cudnn.benchmark = True
10 |     if config.model_type not in ['U_Net','R2U_Net','AttU_Net','R2AttU_Net']:
11 |         print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
12 |         print('Your input for model_type was %s'%config.model_type)
13 |         return
14 | 
15 |     # Create directories if they do not exist
16 |     if not os.path.exists(config.model_path):
17 |         os.makedirs(config.model_path)
18 |     if not os.path.exists(config.result_path):
19 |         os.makedirs(config.result_path)
20 |     config.result_path = os.path.join(config.result_path,config.model_type)
21 |     if not os.path.exists(config.result_path):
22 |         os.makedirs(config.result_path)
23 | 
24 |     print(config)
25 |     # Train on the train/valid splits, or run inference on the test split
26 |     if config.mode == 'train':
27 |         train_loader = get_loader(image_path=config.train_path,
28 |                                   image_size=config.image_size,
29 |                                   batch_size=config.batch_size,
30 |                                   num_workers=config.num_workers,
31 |                                   mode='train',
32 |                                   augmentation_prob=config.augmentation_prob)
33 |         valid_loader = get_loader(image_path=config.valid_path,
34 |                                   image_size=config.image_size,
35 |                                   batch_size=config.batch_size,
36 |                                   num_workers=config.num_workers,
37 |                                   mode='valid',
38 |                                   augmentation_prob=0.)
39 |         solver = Solver(config, train_loader=train_loader, valid_loader=valid_loader)
40 |         solver.train()
41 |     elif config.mode == 'test':
42 |         test_loader = get_loader(image_path=config.test_path,
43 |                                  image_size=config.image_size,
44 |                                  batch_size=config.batch_size,
45 |                                  num_workers=config.num_workers,
46 |                                  mode='test',
47 |                                  augmentation_prob=0.)
48 | solver = Solver(config, test_loader=test_loader) 49 | solver.test() 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | 55 | 56 | # model hyper-parameters 57 | parser.add_argument('--image_size', type=int, default=256) 58 | parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net') 59 | 60 | # training hyper-parameters 61 | parser.add_argument('--img_ch', type=int, default=3) 62 | parser.add_argument('--output_ch', type=int, default=1) 63 | parser.add_argument('--num_epochs', type=int, default=5)#100 64 | parser.add_argument('--num_epochs_decay', type=int, default=70)#70 65 | parser.add_argument('--batch_size', type=int, default=8) 66 | parser.add_argument('--num_workers', type=int, default=8) 67 | parser.add_argument('--lr', type=float, default=0.0002) #0.0002 68 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam 69 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam 70 | parser.add_argument('--augmentation_prob', type=float, default=0) #0.4 71 | 72 | parser.add_argument('--log_step', type=int, default=2) 73 | parser.add_argument('--val_step', type=int, default=2) 74 | 75 | # misc 76 | parser.add_argument('--mode', type=str, default='test') 77 | parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/AttU_Net/')#U_Net 78 | parser.add_argument('--model_path', type=str, default='./models') 79 | parser.add_argument('--train_path', type=str, default='./data_set/Urben_pre/train/') 80 | parser.add_argument('--valid_path', type=str, default='./data_set/Urben_pre/valid/') 81 | parser.add_argument('--test_path', type=str, default='/home/program/Unet/data_set/test_image/test/') 82 | parser.add_argument('--result_path', type=str, default='./result/') 83 | 84 | parser.add_argument('--cuda_idx', type=int, default=1) 85 | 86 | config = parser.parse_args() 87 | main(config) 88 | -------------------------------------------------------------------------------- /models/train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/models/train.pth -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | def init_weights(net, init_type='normal', gain=0.02): 7 | def init_func(m): 8 | classname = m.__class__.__name__ 9 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 10 | if init_type == 'normal': 11 | init.normal_(m.weight.data, 0.0, gain) 12 | elif init_type == 'xavier': 13 | init.xavier_normal_(m.weight.data, gain=gain) 14 | elif init_type == 'kaiming': 15 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 16 | elif init_type == 'orthogonal': 17 | init.orthogonal_(m.weight.data, gain=gain) 18 | else: 19 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 20 | if hasattr(m, 'bias') and m.bias is not None: 21 | init.constant_(m.bias.data, 0.0) 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.normal_(m.weight.data, 1.0, gain) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | print('initialize network with %s' % init_type) 27 | net.apply(init_func) 28 | 29 | class conv_block(nn.Module): 30 | def 
__init__(self,ch_in,ch_out): 31 | super(conv_block,self).__init__() 32 | self.conv = nn.Sequential( 33 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 34 | nn.BatchNorm2d(ch_out), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 37 | nn.BatchNorm2d(ch_out), 38 | nn.ReLU(inplace=True) 39 | ) 40 | 41 | 42 | def forward(self,x): 43 | x = self.conv(x) 44 | return x 45 | 46 | class up_conv(nn.Module): 47 | def __init__(self,ch_in,ch_out): 48 | super(up_conv,self).__init__() 49 | self.up = nn.Sequential( 50 | nn.Upsample(scale_factor=2), 51 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 52 | nn.BatchNorm2d(ch_out), 53 | nn.ReLU(inplace=True) 54 | ) 55 | 56 | def forward(self,x): 57 | x = self.up(x) 58 | return x 59 | 60 | class Attention_block(nn.Module): 61 | def __init__(self,F_g,F_l,F_int): 62 | super(Attention_block,self).__init__() 63 | self.W_g = nn.Sequential( 64 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 65 | nn.BatchNorm2d(F_int) 66 | ) 67 | 68 | self.W_x = nn.Sequential( 69 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 70 | nn.BatchNorm2d(F_int) 71 | ) 72 | 73 | self.psi = nn.Sequential( 74 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 75 | nn.BatchNorm2d(1), 76 | nn.Sigmoid() 77 | ) 78 | 79 | self.relu = nn.ReLU(inplace=True) 80 | 81 | def forward(self,g,x): 82 | g1 = self.W_g(g) 83 | x1 = self.W_x(x) 84 | psi = self.relu(g1+x1) 85 | psi = self.psi(psi) 86 | 87 | return x*psi 88 | 89 | class U_Net(nn.Module): 90 | def __init__(self,img_ch=3,output_ch=1): 91 | super(U_Net,self).__init__() 92 | 93 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 94 | 95 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 96 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 97 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 98 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 99 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 100 | 101 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 102 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 103 | 104 | self.Up4 = up_conv(ch_in=512,ch_out=256) 105 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 106 | 107 | self.Up3 = up_conv(ch_in=256,ch_out=128) 108 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 109 | 110 | self.Up2 = up_conv(ch_in=128,ch_out=64) 111 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 112 | 113 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 114 | 115 | 116 | def forward(self,x): 117 | # encoding path 118 | x1 = self.Conv1(x) 119 | 120 | x2 = self.Maxpool(x1) 121 | x2 = self.Conv2(x2) 122 | 123 | x3 = self.Maxpool(x2) 124 | x3 = self.Conv3(x3) 125 | 126 | x4 = self.Maxpool(x3) 127 | x4 = self.Conv4(x4) 128 | 129 | x5 = self.Maxpool(x4) 130 | x5 = self.Conv5(x5) 131 | 132 | # decoding + concat path 133 | d5 = self.Up5(x5) 134 | d5 = torch.cat((x4,d5),dim=1) 135 | 136 | d5 = self.Up_conv5(d5) 137 | 138 | d4 = self.Up4(d5) 139 | d4 = torch.cat((x3,d4),dim=1) 140 | d4 = self.Up_conv4(d4) 141 | 142 | d3 = self.Up3(d4) 143 | d3 = torch.cat((x2,d3),dim=1) 144 | d3 = self.Up_conv3(d3) 145 | 146 | d2 = self.Up2(d3) 147 | d2 = torch.cat((x1,d2),dim=1) 148 | d2 = self.Up_conv2(d2) 149 | 150 | d1 = self.Conv_1x1(d2) 151 | 152 | return d1 153 | 154 | class AttU_Net(nn.Module): 155 | def __init__(self,img_ch=3,output_ch=1): 156 | super(AttU_Net,self).__init__() 157 | 158 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 159 | 160 | self.Conv1 = 
conv_block(ch_in=img_ch,ch_out=64)
161 |         self.Conv2 = conv_block(ch_in=64,ch_out=128)
162 |         self.Conv3 = conv_block(ch_in=128,ch_out=256)
163 |         self.Conv4 = conv_block(ch_in=256,ch_out=512)
164 |         self.Conv5 = conv_block(ch_in=512,ch_out=1024)
165 | 
166 |         self.Up5 = up_conv(ch_in=1024,ch_out=512)
167 |         self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
168 |         self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
169 | 
170 |         self.Up4 = up_conv(ch_in=512,ch_out=256)
171 |         self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
172 |         self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
173 | 
174 |         self.Up3 = up_conv(ch_in=256,ch_out=128)
175 |         self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
176 |         self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
177 | 
178 |         self.Up2 = up_conv(ch_in=128,ch_out=64)
179 |         self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
180 |         self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
181 | 
182 |         self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
183 | 
184 | 
185 |     def forward(self,x):
186 |         # encoding path
187 |         x1 = self.Conv1(x)
188 | 
189 |         x2 = self.Maxpool(x1)
190 |         x2 = self.Conv2(x2)
191 | 
192 |         x3 = self.Maxpool(x2)
193 |         x3 = self.Conv3(x3)
194 | 
195 |         x4 = self.Maxpool(x3)
196 |         x4 = self.Conv4(x4)
197 | 
198 |         x5 = self.Maxpool(x4)
199 |         x5 = self.Conv5(x5)
200 | 
201 |         # decoding + concat path
202 |         d5 = self.Up5(x5)
203 |         x4 = self.Att5(g=d5,x=x4)
204 |         d5 = torch.cat((x4,d5),dim=1)
205 |         d5 = self.Up_conv5(d5)
206 | 
207 |         d4 = self.Up4(d5)
208 |         x3 = self.Att4(g=d4,x=x3)
209 |         d4 = torch.cat((x3,d4),dim=1)
210 |         d4 = self.Up_conv4(d4)
211 | 
212 |         d3 = self.Up3(d4)
213 |         x2 = self.Att3(g=d3,x=x2)
214 |         d3 = torch.cat((x2,d3),dim=1)
215 |         d3 = self.Up_conv3(d3)
216 | 
217 |         d2 = self.Up2(d3)
218 |         x1 = self.Att2(g=d2,x=x1)
219 |         d2 = torch.cat((x1,d2),dim=1)
220 |         d2 = self.Up_conv2(d2)
221 | 
222 |         d1 = self.Conv_1x1(d2)
223 | 
224 |         return d1
225 | 
226 | 
227 | 
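A quick shape sanity check for the two models above (illustrative snippet, not repository code); H and W are 256 because the encoder max-pools by 2 four times, so both must be divisible by 16:

import torch
from network import U_Net, AttU_Net

x = torch.randn(1, 3, 256, 256)  # one dummy 256x256 RGB tile
for Net in (U_Net, AttU_Net):
    net = Net(img_ch=3, output_ch=1)
    with torch.no_grad():
        logits = net(x)  # raw logits; apply torch.sigmoid for probabilities
    print(Net.__name__, logits.shape)  # -> torch.Size([1, 1, 256, 256])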
--------------------------------------------------------------------------------
/result/res.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QTazimi/Water_extraction/bf366527263d123c7ae6b2c13a72f3e56e25ff18/result/res.csv
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import time
4 | import datetime
5 | import torch
6 | import torchvision
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | from evaluation import *
10 | from network import U_Net, AttU_Net
11 | from torch import optim, nn
12 | import cv2
13 | import matplotlib.pyplot as plt
14 | import csv
15 | 
16 | 
17 | class Solver(object):
18 |     def __init__(self, config, train_loader=None, valid_loader=None, test_loader=None):
19 | 
20 |         # Data loader
21 |         self.train_loader = train_loader
22 |         self.valid_loader = valid_loader
23 |         self.test_loader = test_loader
24 | 
25 |         # Models
26 |         self.unet = None
27 |         self.optimizer = None
28 |         self.img_ch = config.img_ch
29 |         self.output_ch = config.output_ch
30 |         self.criterion = DiceLoss()  # alternatives: nn.BCEWithLogitsLoss(), torch.nn.BCELoss()
31 |         self.augmentation_prob = config.augmentation_prob
32 | 
33 |         # Hyper-parameters
34 |         self.lr = config.lr
35 |         self.beta1 = config.beta1
36 |         self.beta2 = config.beta2
37 | 
38 |         # Training settings
39 |         self.num_epochs = config.num_epochs
40 |         self.num_epochs_decay = config.num_epochs_decay
41 |         self.batch_size = config.batch_size
42 | 
43 |         # Step size
44 |         self.log_step = config.log_step
45 |         self.val_step = config.val_step
46 | 
47 |         # Path
48 |         self.model_path = config.model_path
49 |         self.result_path = config.result_path
50 |         self.mode = config.mode
51 | 
52 |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53 |         self.model_type = config.model_type
54 |         self.t = config.t
55 |         self.build_model()
56 | 
57 |     def build_model(self):
58 |         """Build the network and its optimizer."""
59 |         if self.model_type == 'U_Net':
60 |             self.unet = U_Net(img_ch=3, output_ch=1)
61 |         elif self.model_type == 'AttU_Net':
62 |             self.unet = AttU_Net(img_ch=3, output_ch=1)
63 | 
64 | 
65 |         self.optimizer = optim.Adam(list(self.unet.parameters()),
66 |                                     self.lr, [self.beta1, self.beta2])
67 |         self.unet.to(self.device)
68 |         self.unet = nn.DataParallel(self.unet)
69 |         # self.print_network(self.unet, self.model_type)
70 | 
71 |     def print_network(self, model, name):
72 |         """Print out the network information."""
73 |         num_params = 0
74 |         for p in model.parameters():
75 |             num_params += p.numel()
76 |         print(model)
77 |         print(name)
78 |         print("The number of parameters: {}".format(num_params))
79 | 
80 | 
81 |     # def to_data(self, x):
82 |     #     """Convert variable to tensor."""
83 |     #     if torch.cuda.is_available():
84 |     #         x = x.cpu()
85 |     #     return x.data
86 | 
87 |     def reset_grad(self):
88 |         """Zero the gradient buffers."""
89 |         self.unet.zero_grad()
90 | 
91 | 
92 |     def train(self):
93 |         """Train the segmentation network."""
94 | 
95 |         #====================================== Training ===========================================#
96 |         #===========================================================================================#
97 | 
98 |         unet_path = os.path.join(self.model_path, '%s-%d-%.4f-%d-%.4f.pth' % (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay, self.augmentation_prob))
99 | 
100 |         # U-Net Train
101 |         if os.path.isfile(unet_path):
102 |             # Load the pretrained model
103 |             self.unet.load_state_dict(torch.load(unet_path))
104 |             print('%s is Successfully Loaded from %s' % (self.model_type, unet_path))
105 |         else:
106 |             f = open(os.path.join(self.result_path, 'train.csv'), 'a', encoding='utf-8', newline='')
107 |             wr = csv.writer(f)
108 |             wr.writerow(["model_type", "Epoch", "Total_Epoch", "loss", "acc", "SE", "SP", "PC", "F1", "JS", "DC", "IOU"])
109 |             f.close()
110 |             f = open(os.path.join(self.result_path, 'valid.csv'), 'a', encoding='utf-8', newline='')
111 |             wr = csv.writer(f)
112 |             wr.writerow(["model_type", "Epoch", "Total_Epoch", "acc", "SE", "SP", "PC", "F1", "JS", "DC", "IOU"])
113 |             f.close()
114 |             # Train the model
115 |             lr = self.lr
116 |             best_unet_score = 0.
117 |             for epoch in range(self.num_epochs):
118 | 
119 |                 self.unet.train(True)
120 |                 epoch_loss = 0
121 | 
122 |                 acc = 0.    # Accuracy
123 |                 SE = 0.     # Sensitivity (Recall)
124 |                 SP = 0.     # Specificity
125 |                 PC = 0.     # Precision
126 |                 F1 = 0.     # F1 Score
127 |                 JS = 0.     # Jaccard Similarity
128 |                 DC = 0.     # Dice Coefficient
129 |                 IOU = 0.    # Intersection-over-Union, IoU
130 |                 # FWIOU = 0.
#Frequency Weighted Intersection-over-Union, FWIoU 131 | length = 0 132 | 133 | for i, (images, GT) in enumerate(self.train_loader): 134 | # GT : Ground Truth 135 | 136 | images = images.to(self.device, dtype=torch.float32) 137 | GT = GT.to(self.device, dtype=torch.float32) 138 | # SR : Segmentation Result 139 | SR = self.unet(images) 140 | 141 | SR_flat = SR.view(SR.size(0), -1)#sigmoid 142 | 143 | GT_flat = GT.view(GT.size(0), -1) 144 | 145 | # GT_flat = GT_flat.float() 146 | loss = self.criterion(SR_flat, GT_flat) 147 | epoch_loss += loss.item() 148 | 149 | # Backprop + optimize 150 | self.reset_grad() 151 | loss.backward() 152 | self.optimizer.step() 153 | 154 | SR_probs = F.sigmoid(SR) 155 | acc += get_accuracy(SR_probs, GT) 156 | SE += get_sensitivity(SR_probs, GT) 157 | SP += get_specificity(SR_probs, GT) 158 | PC += get_precision(SR_probs, GT) 159 | F1 += get_F1(SR_probs, GT) 160 | JS += get_JS(SR_probs, GT) 161 | DC += get_DC(SR_probs, GT) 162 | IOU += get_iou(SR_probs, GT) 163 | # FWIOU += get_FWiou(SR_probs, GT) 164 | length += 1#images.size(0) 165 | 166 | acc = acc/length 167 | SE = SE/length 168 | SP = SP/length 169 | PC = PC/length 170 | F1 = F1/length 171 | JS = JS/length 172 | DC = DC/length 173 | IOU = IOU/length 174 | # FWIOU = FWIOU/length 175 | # Print the log info 176 | print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, IOU: %.4f' % ( 177 | epoch+1, self.num_epochs, \ 178 | epoch_loss / length,\ 179 | acc,SE,SP,PC,F1,JS,DC,IOU)) 180 | 181 | f = open(os.path.join(self.result_path, 'train.csv'), 'a', encoding='utf-8', newline='') 182 | wr = csv.writer(f) 183 | wr.writerow([self.model_type, epoch+1, self.num_epochs, 184 | epoch_loss / length, acc, SE, SP, PC, F1, JS, DC,IOU]) 185 | f.close() 186 | 187 | # decay_rate = 0.7 188 | # lr = self.lr * np.power(decay_rate, epoch) 189 | # print(lr) 190 | # Decay learning rate 191 | if (epoch+1) > (self.num_epochs - self.num_epochs_decay): 192 | lr -= (self.lr / float(self.num_epochs_decay)) 193 | # decay_rate = 0.7 194 | # lr = self.lr * np.power(decay_rate, (epoch + 1 - self.num_epochs_decay) / 20) 195 | for param_group in self.optimizer.param_groups: 196 | param_group['lr'] = lr 197 | print('Decay learning rate to lr: {}.'.format(lr)) 198 | 199 | 200 | #===================================== Validation ====================================# 201 | self.unet.train(False) 202 | self.unet.eval() 203 | 204 | acc = 0. # Accuracy 205 | SE = 0. # Sensitivity (Recall) 206 | SP = 0. # Specificity 207 | PC = 0. # Precision 208 | F1 = 0. # F1 Score 209 | JS = 0. # Jaccard Similarity 210 | DC = 0. # Dice Coefficient 211 | IOU = 0. # Intersection-over-Union, IoU 212 | # FWIOU = 0. 
#Frequency Weighted Intersection-over-Union, FWIoU 213 | length = 0 214 | for i, (images, GT) in enumerate(self.valid_loader): 215 | 216 | images = images.to(self.device) 217 | GT = GT.to(self.device) 218 | SR = F.sigmoid(self.unet(images)) 219 | acc += get_accuracy(SR,GT) 220 | SE += get_sensitivity(SR,GT) 221 | SP += get_specificity(SR,GT) 222 | PC += get_precision(SR,GT) 223 | F1 += get_F1(SR,GT) 224 | JS += get_JS(SR,GT) 225 | DC += get_DC(SR,GT) 226 | IOU += get_iou(SR,GT) 227 | # FWIOU += get_FWiou(SR,GT) 228 | 229 | 230 | length += 1 #images.size(0) 231 | acc = acc/length 232 | SE = SE/length 233 | SP = SP/length 234 | PC = PC/length 235 | F1 = F1/length 236 | JS = JS/length 237 | DC = DC/length 238 | IOU = IOU / length 239 | # FWIOU = FWIOU/length 240 | unet_score = JS + DC 241 | 242 | print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, IOU: %.4f'%(acc,SE,SP,PC,F1,JS,DC,IOU)) 243 | f = open(os.path.join(self.result_path, 'valid.csv'), 'a', encoding='utf-8', newline='') 244 | wr = csv.writer(f) 245 | wr.writerow([self.model_type, epoch + 1, self.num_epochs, 246 | acc, SE, SP, PC, F1, JS, DC, IOU]) 247 | f.close() 248 | ''' 249 | torchvision.utils.save_image(images.data.cpu(), 250 | os.path.join(self.result_path, 251 | '%s_valid_%d_image.png'%(self.model_type,epoch+1))) 252 | torchvision.utils.save_image(SR.data.cpu(), 253 | os.path.join(self.result_path, 254 | '%s_valid_%d_SR.png'%(self.model_type,epoch+1))) 255 | torchvision.utils.save_image(GT.data.cpu(), 256 | os.path.join(self.result_path, 257 | '%s_valid_%d_GT.png'%(self.model_type,epoch+1))) 258 | ''' 259 | 260 | # print('qsave:', unet_score, best_unet_score) 261 | # Save Best U-Net model 262 | if unet_score > best_unet_score: 263 | best_unet_score = unet_score 264 | best_epoch = epoch 265 | best_unet = self.unet.state_dict() 266 | print('Best %s model score : %.4f'%(self.model_type,best_unet_score)) 267 | torch.save(best_unet, unet_path) 268 | 269 | 270 | def test(self): 271 | # #===================================== Test ====================================# 272 | unet_path = "/home/program/Unet/models/U_Net-350-0.0002-70-0.0000.pth" 273 | save_path = '' 274 | self.build_model() 275 | self.unet.load_state_dict(torch.load(unet_path)) 276 | self.unet.eval() 277 | # acc = 0. # Accuracy 278 | # SE = 0. # Sensitivity (Recall) 279 | # SP = 0. # Specificity 280 | # PC = 0. # Precision 281 | # F1 = 0. # F1 Score 282 | # JS = 0. # Jaccard Similarity 283 | # DC = 0. 
# Dice Coefficient
284 |         # length = 0
285 | 
286 |         for i, (image_path, image) in enumerate(self.test_loader):
287 |             image = image.to(device=self.device, dtype=torch.float32)
288 |             pred = F.sigmoid(self.unet(image))
289 |             pred = np.array(pred.data.cpu())
290 |             pred[pred >= 0.5] = 255
291 |             pred[pred < 0.5] = 0
292 |             for image_path_item, pred_item in zip(image_path, pred):
293 |                 # find the contours of the predicted water mask
294 |                 pred_item = np.array(pred_item, np.uint8)
295 |                 pred_item = pred_item.reshape(pred_item.shape[1], pred_item.shape[2])
296 |                 contours, _ = cv2.findContours(pred_item, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
297 |                 img = cv2.imread(image_path_item, 1)
298 |                 cv2.drawContours(img, contours, -1, (0, 0, 255), 1)
299 | 
300 |                 img = img[:, :, ::-1]
301 |                 img[..., 2] = np.where(pred_item == 255, 200, img[..., 2])
302 | 
303 |                 plt.imshow(img)
304 |                 plt.show()
305 | 
306 | 
307 |             # print(image_path)
308 |             # filename = image_path.split('/')[-1][:-len(".png")]
309 |             # save_path = save_path + filename + '.png'
310 |             # cv2.imwrite(save_path, pred)
311 | 
312 |             # metric computation
313 |             # acc += get_accuracy(SR, GT)
314 |             # SE += get_sensitivity(SR, GT)
315 |             # SP += get_specificity(SR, GT)
316 |             # PC += get_precision(SR, GT)
317 |             # F1 += get_F1(SR, GT)
318 |             # JS += get_JS(SR, GT)
319 |             # DC += get_DC(SR, GT)
320 |             #
321 |             # length += 1  # images.size(0)
322 | 
323 |         # acc = acc / length
324 |         # SE = SE / length
325 |         # SP = SP / length
326 |         # PC = PC / length
327 |         # F1 = F1 / length
328 |         # JS = JS / length
329 |         # DC = DC / length
330 |         # unet_score = JS + DC
331 |         #
332 |         # f = open(os.path.join(self.result_path, 'result.csv'), 'a', encoding='utf-8', newline='')
333 |         # wr = csv.writer(f)
334 |         # wr.writerow([self.model_type, acc, SE, SP, PC, F1, JS, DC])
335 |         # f.close()
336 | 
337 | 
338 | 
339 | 
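A hypothetical programmatic driver equivalent to `python main.py --mode train --model_type U_Net`; the config fields mirror the argparse defaults in main.py (note that main.py additionally appends model_type to result_path and creates the model/result folders first, so those directories are assumed to exist here):

from types import SimpleNamespace
from data_loader import get_loader
from solver import Solver

config = SimpleNamespace(
    image_size=256, t=3, img_ch=3, output_ch=1,
    num_epochs=5, num_epochs_decay=70, batch_size=8, num_workers=8,
    lr=0.0002, beta1=0.5, beta2=0.999, augmentation_prob=0,
    log_step=2, val_step=2, mode='train', model_type='U_Net',
    model_path='./models', result_path='./result/U_Net')

train_loader = get_loader('./data_set/Urben_pre/train/', 256, config.batch_size, mode='train')
valid_loader = get_loader('./data_set/Urben_pre/valid/', 256, config.batch_size, mode='valid')
Solver(config, train_loader=train_loader, valid_loader=valid_loader).train()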
--------------------------------------------------------------------------------
/test/test_data.py:
--------------------------------------------------------------------------------
1 | from data_loader import get_loader
2 | from evaluation import *
3 | import torch
4 | from torch import nn
5 | from network import U_Net
6 | import numpy as np
7 | import torch.nn.functional as F
8 | import matplotlib.pyplot as plt
9 | from PIL import Image
10 | import cv2
11 | import csv
12 | 
13 | 
14 | def test_net(test_loader, save_path=""):
15 |     net = U_Net(img_ch=3, output_ch=1)
16 |     # Pick the device: use CUDA if available, otherwise fall back to CPU
17 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |     # Move the network onto the device
19 |     net = nn.DataParallel(net)
20 |     net.to(device=device)
21 |     # Load the model weights
22 |     # net.load_state_dict(torch.load("/home/program/Unet/models/U_Net-350-0.0002-70-0.0000.pth"))
23 |     net.load_state_dict(torch.load("best_model.pth", map_location=device))
24 |     # Evaluation mode
25 |     net = net.module.to(device)
26 |     net.eval()
27 |     print(len(test_loader))  # was len(train_loader), which referenced a global by accident
28 |     for i, (image_path, image) in enumerate(test_loader):
29 |         # Move the batch onto the device
30 |         image = image.to(device=device, dtype=torch.float32)
31 |         # Predict
32 |         pred = F.sigmoid(net(image))
33 |         # Post-process the result
34 |         pred = np.array(pred.data.cpu()[0])[0]
35 |         pred[pred >= 0.5] = 255
36 |         pred[pred < 0.5] = 0
37 | 
38 |         image_path = image_path[0]
39 |         # Find contours
40 |         pred = np.array(pred, np.uint8)
41 |         contours, _ = cv2.findContours(pred, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
42 |         img = cv2.imread(image_path, 1)
43 |         cv2.drawContours(img, contours, -1, (0, 0, 255), 1)
44 | 
45 |         img = img[:, :, ::-1]
46 |         img[..., 2] = np.where(pred == 255, 200, img[..., 2])
47 | 
48 |         plt.imshow(img)
49 |         plt.show()
50 | 
51 |         filename = image_path.split('/')[-1][:-len(".png")]
52 |         cv2.imwrite(save_path + filename + ".png", pred)
53 | 
54 | 
55 | if __name__ == "__main__":
56 |     test_loader = get_loader(image_path="../data_set/test_image/test/",
57 |                              image_size=256,
58 |                              batch_size=1,
59 |                              num_workers=8,
60 |                              mode='test',
61 |                              augmentation_prob=0)  # one image per batch
62 |     test_net(test_loader, save_path="../data_set/test_image/res/")
--------------------------------------------------------------------------------
/test/test_one_data.py:
--------------------------------------------------------------------------------
1 | from data_loader import get_loader
2 | from evaluation import *
3 | import torch
4 | from torch import nn
5 | from network import U_Net
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from torchvision import transforms as T
9 | import matplotlib.pyplot as plt
10 | from PIL import Image
11 | import cv2
12 | import csv
13 | 
14 | def get_one_image(image_path):
15 |     image = Image.open(image_path)
16 |     Transform = T.Compose([T.ToTensor()])
17 |     image = Transform(image)
18 |     Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
19 |     image = Norm_(image)
20 |     return [image_path, image]
21 | 
22 | 
23 | def test_one_image(test_loader, save_path=""):
24 |     net = U_Net(img_ch=3, output_ch=1)
25 |     # Pick the device: use CUDA if available, otherwise fall back to CPU
26 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27 |     # Move the network onto the device
28 |     net = nn.DataParallel(net)
29 |     net.to(device=device)
30 |     # Load the model weights
31 |     # net.load_state_dict(torch.load("/home/program/Unet/models/U_Net-350-0.0002-70-0.0000.pth"))
32 |     net.load_state_dict(torch.load("best_model.pth", map_location=device))
33 |     # Evaluation mode
34 |     net = net.module.to(device)
35 |     net.eval()
36 | 
37 |     ##########################################
38 |     image_path, image = test_loader[0], test_loader[1]
39 |     image = torch.unsqueeze(image, dim=0)
40 |     # Move the image onto the device
41 |     image = image.to(device=device, dtype=torch.float32)
42 |     # Predict
43 |     pred = F.sigmoid(net(image))
44 |     # Post-process the result
45 |     pred = np.array(pred.data.cpu()[0])[0]
46 |     pred[pred >= 0.5] = 255
47 |     pred[pred < 0.5] = 0
48 | 
49 |     # Find contours
50 |     pred = np.array(pred, np.uint8)
51 |     contours, _ = cv2.findContours(pred, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
52 |     img = cv2.imread(image_path, 1)
53 |     cv2.drawContours(img, contours, -1, (0, 0, 255), 1)
54 | 
55 |     img = img[:, :, ::-1]
56 |     img[..., 2] = np.where(pred == 255, 200, img[..., 2])
57 | 
58 |     plt.imshow(img)
59 |     plt.show()
60 | 
61 |     filename = image_path.split('/')[-1][:-len(".png")]
62 |     save_path = save_path + filename + "_res.png"
63 |     # cv2.imwrite(save_path, pred)
64 |     img = img[:, :, ::-1]
65 |     cv2.imwrite(save_path, img)
66 | 
67 | 
68 | if __name__ == "__main__":
69 |     test_loader = get_one_image("1049.png")
70 |     test_one_image(test_loader, save_path="../data_set/test_image/res/")  # saves the overlay to e.g. ../data_set/test_image/res/1049_res.png
"Total_Epoch", "loss"]) 14 | f.close() 15 | # 定义RMSprop算法 16 | optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9) 17 | # 定义Loss算法 18 | criterion = nn.BCEWithLogitsLoss() 19 | # best_loss统计,初始化为正无穷 20 | best_loss = float('inf') 21 | # 训练epochs次 22 | for epoch in range(epochs): 23 | # 训练模式 24 | net.train(True) 25 | # 按照batch_size开始训练 26 | print('epoch:', epoch) 27 | average_loss = 0 28 | for i, (image, label) in enumerate(train_loader): 29 | # for image, label in train_loader: 30 | optimizer.zero_grad() 31 | # 将数据拷贝到device中 32 | image = image.to(device=device, dtype=torch.float32) 33 | label = label.to(device=device, dtype=torch.float32) 34 | # 使用网络参数,输出预测结果 35 | pred = net(image) 36 | # SR_probs = F.sigmoid(pred) 37 | # print(SR_probs) 38 | # print(label) 39 | # SR_flat = SR_probs.view(SR_probs.size(0), -1) 40 | # 计算loss 41 | loss = criterion(pred, label) 42 | # print('loss:', loss) 43 | average_loss += loss.item() 44 | print('Loss/train', loss.item()) 45 | # 保存loss值最小的网络参数 46 | if loss < best_loss: 47 | best_loss = loss 48 | torch.save(net.state_dict(), 'best_model.pth') 49 | # 更新参数 50 | loss.backward() 51 | optimizer.step() 52 | average_loss = average_loss / (i + 1) 53 | f = open('train.csv', 'a', encoding='utf-8', newline='') 54 | wr = csv.writer(f) 55 | wr.writerow([epoch + 1, epochs, average_loss]) 56 | f.close() 57 | if __name__ == "__main__": 58 | # 选择设备,有cuda用cuda,没有就用cpu 59 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 60 | # 加载网络,图片单通道3,分类为1。 61 | net = U_Net(img_ch=3, output_ch=1) 62 | # 将网络拷贝到deivce中 63 | net.to(device=device) 64 | net = nn.DataParallel(net) 65 | # 指定训练集地址,开始训练 66 | # data_path = "data/train/" 67 | train_loader = get_loader(image_path="../data_set/Urben_pre/train/", 68 | image_size=256, 69 | batch_size=8, 70 | num_workers=8, 71 | mode='train', 72 | augmentation_prob=0) 73 | train_net(net, device, train_loader) -------------------------------------------------------------------------------- /util/data_preprocess.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | 5 | 6 | def find_river_image(masks_path, images_path, masks_save_path, images_save_path): 7 | #Look up remote sensing images containing water areas and modify multi-category labels. 
--------------------------------------------------------------------------------
/util/data_preprocess.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import os
4 | 
5 | 
6 | def find_river_image(masks_path, images_path, masks_save_path, images_save_path):
7 |     # Find the remote sensing images that contain water, and collapse the multi-class
8 |     # labels to a binary water mask (class 4 -> 255, everything else -> 0).
9 |     lists = os.listdir(masks_path)
10 |     for item in lists:
11 |         original_path = masks_path + item
12 |         original_image = Image.open(original_path)
13 |         original_img = np.array(original_image)
14 |         if 4 in original_img:
15 |             print(item)
16 |             original_img[original_img != 4] = 0
17 |             original_img[original_img == 4] = 255
18 |             save_path1 = masks_save_path + item
19 |             img1 = Image.fromarray(original_img)
20 |             img1.save(save_path1)
21 | 
22 |             original_path2 = images_path + item
23 |             save_path2 = images_save_path + item
24 |             original_image2 = Image.open(original_path2)
25 |             original_img2 = np.array(original_image2)
26 |             img2 = Image.fromarray(original_img2)
27 |             img2.save(save_path2)
28 | 
29 |     return True
30 | 
31 | def split_train_image(masks_path, images_path, masks_save_path, images_save_path):
32 |     # Cut each 1024x1024 remote sensing image into sixteen 256x256 tiles,
33 |     # keeping only the tiles whose mask contains water.
34 |     lists = os.listdir(masks_path)
35 |     print(len(lists))
36 |     count = 0
37 |     for item in lists:
38 |         temp_path1 = masks_path + item
39 |         im1 = Image.open(temp_path1)
40 |         temp_path2 = images_path + item
41 |         im2 = Image.open(temp_path2)
42 |         # Prepare to cut the image into 16 tiles
43 |         size = im1.size
44 |         width = int(size[0] // 4)
45 |         height = int(size[1] // 4)
46 |         for j in range(4):
47 |             for i in range(4):
48 |                 box = (width * i, height * j, width * (i + 1), height * (j + 1))
49 |                 region1 = im1.crop(box)
50 |                 region2 = im2.crop(box)
51 |                 img_numpy = np.array(region1)
52 |                 if 255 in img_numpy:
53 |                     print(count)
54 |                     path1 = masks_save_path + str(count) + ".png"
55 |                     path2 = images_save_path + str(count) + ".png"
56 |                     count += 1
57 |                     region1.save(path1)
58 |                     region2.save(path2)
59 |     return True
60 | 
61 | def split_test_image(images_path, images_save_path):
62 |     # Cut each 1024x1024 remote sensing image into sixteen 256x256 tiles
63 |     lists = os.listdir(images_path)
64 |     print(len(lists))
65 |     for item in lists:
66 |         temp_path = images_path + item
67 |         im = Image.open(temp_path)
68 |         # Prepare to cut the image into 16 tiles
69 |         size = im.size
70 |         width = int(size[0] // 4)
71 |         height = int(size[1] // 4)
72 |         for j in range(4):
73 |             for i in range(4):
74 |                 box = (width * i, height * j, width * (i + 1), height * (j + 1))
75 |                 region = im.crop(box)
76 |                 # img_numpy = np.array(region)
77 |                 path = images_save_path + item + "_" + str(j) + str(i) + ".png"
78 |                 region.save(path)
79 |     return True
80 | 
81 | def combination_test_image(images_path, images_save_path):
82 |     # Stitch the 256x256 test tiles back into full images (not implemented yet)
83 |     pass
84 | 
85 | # find_river_image("/home/program/Unet/data_set/Urben_original/val/masks_png/",
86 | #                  "/home/program/Unet/data_set/Urben_original/val/images_png/",
87 | #                  "/home/program/Unet/data_set/Urben_original/valid/masks_png/",
88 | #                  "/home/program/Unet/data_set/Urben_original/valid/images_png/")
89 | 
90 | # split_train_image("/home/program/Unet/data_set/Urben_original/valid/masks_png/",
91 | #                   "/home/program/Unet/data_set/Urben_original/valid/images_png/",
92 | #                   "/home/program/Unet/data_set/Urben_pre/annotation/",
93 | #                   "/home/program/Unet/data_set/Urben_pre/train/")
94 | 
95 | # lists = os.listdir("/home/program/Unet/data_set/Urben_pre/train/")
96 | # print(len(lists))
--------------------------------------------------------------------------------
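Taken together, preprocessing runs in three steps. A hypothetical end-to-end sketch, patterned on the commented calls above; every path below is an example placeholder:

from util.data_preprocess import find_river_image, split_train_image
from dataset import get_dataset

# 1) Keep only image/mask pairs that contain water; binarize the masks (class 4 -> 255).
find_river_image("./LoveDA/Train/Urban/masks_png/",       # raw LoveDA masks (assumed location)
                 "./LoveDA/Train/Urban/images_png/",
                 "./data_set/Urben_original/masks_png/",
                 "./data_set/Urben_original/images_png/")
# 2) Cut every 1024x1024 pair into 256x256 tiles, keeping tiles with water.
split_train_image("./data_set/Urben_original/masks_png/",
                  "./data_set/Urben_original/images_png/",
                  "./data_set/Urben_pre/annotation/",      # binary masks end up here
                  "./data_set/Urben_pre/all_images/")      # assumed staging folder for image tiles
# 3) Split the image tiles 70/20/10 into train/, valid/ and test/ subfolders.
get_dataset("./data_set/Urben_pre/all_images/", "./data_set/Urben_pre/")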