├── README.md ├── dataset.py ├── imgs ├── BsiNet.png ├── comparison_results.png └── results.png ├── losses.py ├── models.py ├── preprocess.py ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # BsiNet 2 | 3 | Official Pytorch Code base for "Delineation of agricultural fields using multi-task BsiNet from high-resolution satellite images" 4 | 5 | [Project](https://github.com/long123524/BsiNet-torch) 6 | 7 | ## Introduction 8 | 9 | This paper presents a new multi-task neural network BsiNet to delineate agricultural fields from remote sensing images. BsiNet learns three tasks, i.e., a core task for agricultural field identification and two auxiliary tasks for field boundary prediction and distance estimation, corresponding to mask, boundary, and distance tasks, respectively. 10 | 11 |
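The three task outputs come directly from the network's forward pass. A minimal usage sketch (the `num_classes=2` setting matches `build_model` in train.py and test.py; patch sizes should be divisible by 32 so the skip connections line up):

```python
import torch
from models import BsiNet

model = BsiNet(num_classes=2).eval()   # as built in train.py / test.py
x = torch.randn(1, 3, 256, 256)        # one 3-band 256x256 image patch

with torch.no_grad():
    mask_out, contour_out, dist_out = model(x)

# mask_out:    [1, 2, 256, 256] log-probabilities for the field-mask task
# contour_out: [1, 2, 256, 256] log-probabilities for the boundary task
# dist_out:    [1, 1, 256, 256] sigmoid-activated distance map
```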

12 | <div align="center"><img src="imgs/BsiNet.png" alt="BsiNet architecture"/></div>
13 | 
14 | <div align="center"><img src="imgs/results.png" alt="Delineation results"/></div>
15 | 
16 | <div align="center"><img src="imgs/comparison_results.png" alt="Comparison with other methods"/></div>
22 | 23 | 24 | ## Using the code: 25 | 26 | The code has been tested with Python 3.7.0 and CUDA >= 11.0. 27 | 28 | - Clone this repository: 29 | ```bash 30 | git clone https://github.com/long123524/BsiNet-torch 31 | cd BsiNet-torch 32 | ``` 33 | 34 | Install the following dependencies using conda or pip: 35 | 36 | ``` 37 | PyTorch 38 | TensorboardX 39 | OpenCV 40 | numpy 41 | tqdm 42 | ``` 43 | 44 | ## Preprocessing 45 | Use preprocess.py to generate the contour (boundary) and distance maps from the mask labels; a minimal Python sketch for producing the `.mat` distance files read by dataset.py is given at the end of this README. 46 | 47 | ## Data Format 48 | 49 | Make sure the files are organized in the following structure: 50 | 51 | ``` 52 | inputs 53 | └── 54 | ├── image 55 | │ ├── 001.tif 56 | │ ├── 002.tif 57 | │ ├── 003.tif 58 | │ ├── ... 59 | │ 60 | └── mask 61 | │ ├── 001.tif 62 | │ ├── 002.tif 63 | │ ├── 003.tif 64 | │ ├── ... 65 | └── contour 66 | │ ├── 001.tif 67 | │ ├── 002.tif 68 | │ ├── 003.tif 69 | │ ├── ... 70 | └── dist_contour 71 | │ ├── 001.tif 72 | │ ├── 002.tif 73 | │ ├── 003.tif 74 | │ ├── ... 75 | ``` 76 | 77 | The test and validation datasets follow the same structure as above. 78 | 79 | ## Training and testing 80 | 81 | 1. Train the model: 82 | ``` 83 | python train.py --train_path ./fields/image --save_path ./model --model_type 'bsinet' --distance_type 'dist_contour' 84 | ``` 85 | 2. Evaluate: 86 | ``` 87 | python test.py --model_file ./model/150.pt --save_path ./save --model_type 'bsinet' --distance_type 'dist_contour' --test_path ./test_image 88 | ``` 89 | 90 | If you have any questions, you can contact us: Jiang Long (hnzzyxlj@163.com) and Mengmeng Li (mli@fzu.edu.cn). 91 | 92 | ## GF dataset 93 | A GF-2 image (1 m resolution) is provided for scientific use: https://pan.baidu.com/s/1isg9jD9AlE9EeTqa3Fqrrg, password: bzfd 94 | Google Drive: https://drive.google.com/file/d/1JZtRSxX5PaT3JCzvCLq2Jrt0CBXqZj7c/view?usp=drive_link 95 | A corresponding partial field label is provided for scientific study: https://drive.google.com/file/d/19OrVPkb0MkoaUvaax_9uvnJgSr_dcSSW/view?usp=sharing 96 | 97 | ## A pretrained weight 98 | A weight pretrained on a Xinjiang GF-2 image is provided: https://pan.baidu.com/s/1asAMj4_ZrIQeJiewP2LpqA, password: rz8k 99 | Google Drive: https://drive.google.com/drive/folders/121T8FjiyEsIbfyLUbrBXYCg75PIzCzRX?usp=sharing 100 | 101 | ### Acknowledgements: 102 | 103 | This code base uses certain code blocks and helper functions from Psi-Net. 104 | 105 | ### Citation: 106 | If you find this work useful or interesting, please consider citing the following references. 
107 | ``` 108 | Citation 1: 109 | {Authors: Long Jiang (龙江), Li Mengmeng* (李蒙蒙), Wang Xiaoqin (汪小钦), et al; 110 | Institute: The Academy of Digital China (Fujian), Fuzhou University, 111 | Article Title: Delineation of agricultural fields using multi-task BsiNet from high-resolution satellite images, 112 | Publication: International Journal of Applied Earth Observation and Geoinformation, 113 | Year: 2022, 114 | Volume: 112, 115 | Page: 102871, 116 | DOI: 10.1016/j.jag.2022.102871 117 | } 118 | Citation 2: 119 | {Authors: Li Mengmeng* (李蒙蒙), Long Jiang (龙江), et al; 120 | Institute: The Academy of Digital China (Fujian), Fuzhou University, 121 | Article Title: Using a semantic edge-aware multi-task neural network to delineate agricultural parcels from remote sensing images, 122 | Publication: ISPRS Journal of Photogrammetry and Remote Sensing, 123 | Year: 2023, 124 | Volume: 200, 125 | Page: 24-40, 126 | DOI: 10.1016/j.isprsjprs.2023.04.019 127 | } 128 | Citation 3: 129 | {Authors: Long Jiang (龙江), Zhao Hang (赵航), Li Mengmeng* (李蒙蒙), et al; 130 | Institute: The Academy of Digital China (Fujian), Fuzhou University; Chinese Academy of Sciences, 131 | Article Title: Integrating Segment Anything Model derived boundary prior and high-level semantics for cropland extraction from high-resolution remote sensing images, 132 | Publication: IEEE Geoscience and Remote Sensing Letters, 133 | Year: 2024, 134 | Volume: 21, 135 | Page: 1-5, 136 | DOI: 10.1109/LGRS.2024.3454263 137 | } 138 | ... 139 | ``` 140 | ### A large cropland dataset collected from VHR images: 141 | It will be accessible at https://github.com/NanNanmei/HBGNet; more details can be found in a recent collaborative paper, "A large-scale VHR parcel dataset and a novel hierarchical semantic boundary-guided network for agricultural parcel delineation" (https://www.sciencedirect.com/science/article/pii/S0924271625000395). 142 | ### A parcel vectorization model: 143 | More details can be found in a recent collaborative paper, "Extracting vectorized agricultural parcels from high-resolution satellite images using a Point-Line-Region interactive multitask model", published in Computers and Electronics in Agriculture. Code is available at https://github.com/mengmengli01/PLR-Net-demo/tree/main. 
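### Example: preparing contour and distance inputs
The loaders in dataset.py expect 0/255 boundary rasters in `contour/` and, for `--distance_type 'dist_contour'`, `.mat` distance files (key `"D2"`) in `dist_contour/`. The snippet below is a minimal sketch of one way to produce them; it assumes single-band 0/255 mask tiles, ignores georeferencing (see preprocess.py for a GDAL-based variant that preserves it), and uses a Euclidean distance transform as a stand-in for the quasi-Euclidean transform (e.g., MATLAB's bwdist) mentioned in dataset.py.
```python
import os
import cv2
import numpy as np
from scipy import io as sio
from scipy.ndimage import distance_transform_edt

mask_dir, contour_dir, dist_dir = "inputs/mask", "inputs/contour", "inputs/dist_contour"
os.makedirs(contour_dir, exist_ok=True)
os.makedirs(dist_dir, exist_ok=True)

for name in os.listdir(mask_dir):
    mask = cv2.imread(os.path.join(mask_dir, name), 0)   # 0/255 field mask
    contour = cv2.Canny(mask, 100, 200)                   # 0/255 boundary map, as in preprocess.py
    cv2.imwrite(os.path.join(contour_dir, name), contour)

    # Distance of each field pixel to the background, normalised to [0, 1];
    # stored under the key "D2" that dataset.load_distance reads.
    dist = distance_transform_edt(mask > 0).astype(np.float32)
    dist /= (dist.max() + 1e-8)
    sio.savemat(os.path.join(dist_dir, name.replace(".tif", ".mat")), {"D2": dist})
```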
144 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file implements the data reading 3 | "dist_mask" is obtained by applying a Euclidean distance transform to the mask 4 | "dist_contour" is obtained by applying a quasi-Euclidean distance transform to the mask 5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | import cv2 10 | from PIL import Image, ImageFile 11 | 12 | # from skimage import io   # unused; would be shadowed by "from scipy import io" below 13 | # import imageio           # unused 14 | from torch.utils.data import Dataset 15 | from torchvision import transforms 16 | from scipy import io 17 | import os 18 | from osgeo import gdal 19 | 20 | ### Reading and saving of remote sensing images (Keep coordinate information) 21 | def readTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0): 22 | dataset = gdal.Open(fileName) 23 | if dataset is None: 24 | print(fileName + " could not be opened") 25 | # number of columns of the raster 26 | width = dataset.RasterXSize 27 | # number of rows of the raster 28 | height = dataset.RasterYSize 29 | # number of bands 30 | bands = dataset.RasterCount 31 | # read the data 32 | if(data_width == 0 and data_height == 0): 33 | data_width = width 34 | data_height = height 35 | data = dataset.ReadAsArray(xoff, yoff, data_width, data_height) 36 | # get the geotransform (affine) parameters 37 | geotrans = dataset.GetGeoTransform() 38 | # get the projection 39 | proj = dataset.GetProjection() 40 | return width, height, bands, data, geotrans, proj 41 | 42 | 43 | # save a remote sensing image 44 | def writeTiff(im_data, im_geotrans, im_proj, path): 45 | if 'int8' in im_data.dtype.name: 46 | datatype = gdal.GDT_Byte 47 | elif 'int16' in im_data.dtype.name: 48 | datatype = gdal.GDT_UInt16 49 | else: 50 | datatype = gdal.GDT_Float32 51 | if len(im_data.shape) == 3: 52 | im_bands, im_height, im_width = im_data.shape 53 | else: 54 | im_bands, (im_height, im_width) = 1, im_data.shape 55 | # create the output file 56 | driver = gdal.GetDriverByName("GTiff") 57 | dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype) 58 | if dataset is not None: 59 | dataset.SetGeoTransform(im_geotrans) # write the geotransform parameters 60 | dataset.SetProjection(im_proj) # write the projection 61 | if im_bands == 1: 62 | dataset.GetRasterBand(1).WriteArray(im_data) 63 | else: 64 | for i in range(im_bands): 65 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) 66 | del dataset 67 | 68 | 69 | 70 | class DatasetImageMaskContourDist(Dataset): 71 | 72 | def __init__(self, dir, file_names, distance_type): 73 | 74 | self.file_names = file_names 75 | self.distance_type = distance_type 76 | self.dir = dir 77 | 78 | def __len__(self): 79 | 80 | return len(self.file_names) 81 | 82 | def __getitem__(self, idx): 83 | 84 | img_file_name = self.file_names[idx] 85 | image = load_image(os.path.join(self.dir,img_file_name+'.tif')) 86 | mask = load_mask(os.path.join(self.dir,img_file_name+'.tif')) 87 | contour = load_contour(os.path.join(self.dir,img_file_name+'.tif')) 88 | dist = load_distance(os.path.join(self.dir,img_file_name+'.tif'), self.distance_type) 89 | 90 | return img_file_name, image, mask, contour, dist 91 | 92 | 93 | def load_image(path): 94 | 95 | img = Image.open(path) 96 | data_transforms = transforms.Compose( 97 | [ 98 | # transforms.Resize(256), 99 | transforms.ToTensor(), 100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 101 | 102 | ] 103 | ) 104 | img = data_transforms(img) 105 | 106 | return img 107 | 108 | 109 | def load_mask(path): 110 | mask = cv2.imread(path.replace("image", "mask").replace("tif", "tif"), 0) 111 | # im_width, 
im_height, im_bands, mask, im_geotrans, im_proj = readTif(path.replace("image", "mask").replace("tif", "tif")) 112 | ###mask = mask/225. 113 | mask[mask == 255] = 1 114 | mask[mask == 0] = 0 115 | 116 | return torch.from_numpy(np.expand_dims(mask, 0)).long() 117 | 118 | 119 | def load_contour(path): 120 | 121 | contour = cv2.imread(path.replace("image", "contour").replace("tif", "tif"), 0) 122 | ###contour = contour/255. 123 | contour[contour ==255] = 1 124 | contour[contour == 0] = 0 125 | 126 | 127 | return torch.from_numpy(np.expand_dims(contour, 0)).long() 128 | 129 | 130 | def load_distance(path, distance_type): 131 | 132 | if distance_type == "dist_mask": 133 | path = path.replace("image", "dist_mask").replace("tif", "mat") 134 | 135 | dist = io.loadmat(path)["D2"] 136 | 137 | if distance_type == "dist_contour": 138 | path = path.replace("image", "dist_contour").replace("tif", "mat") 139 | dist = io.loadmat(path)["D2"] 140 | 141 | if distance_type == "dist_contour_tif": 142 | dist = cv2.imread(path.replace("image", "dist_contour_tif").replace("tif", "tif"), 0) 143 | dist = dist/255. 144 | 145 | return torch.from_numpy(np.expand_dims(dist, 0)).float() 146 | -------------------------------------------------------------------------------- /imgs/BsiNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/long123524/BsiNet-torch/dfe2b98d6ab04c2afe446769787e3476030a9b58/imgs/BsiNet.png -------------------------------------------------------------------------------- /imgs/comparison_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/long123524/BsiNet-torch/dfe2b98d6ab04c2afe446769787e3476030a9b58/imgs/comparison_results.png -------------------------------------------------------------------------------- /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/long123524/BsiNet-torch/dfe2b98d6ab04c2afe446769787e3476030a9b58/imgs/results.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | """Calculating the loss 2 | You can build the loss function of BsiNet by combining multiple losses 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def dice_loss(prediction, target): 12 | """Calculating the dice loss 13 | Args: 14 | prediction = predicted image 15 | target = Targeted image 16 | Output: 17 | dice_loss""" 18 | 19 | smooth = 1.0 20 | 21 | i_flat = prediction.view(-1) 22 | t_flat = target.view(-1) 23 | 24 | intersection = (i_flat * t_flat).sum() 25 | 26 | return 1 - ((2. 
* intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth)) 27 | 28 | 29 | def calc_loss(prediction, target, bce_weight=0.5): 30 | """Calculating the loss and metrics 31 | Args: 32 | prediction = predicted image 33 | target = Targeted image 34 | metrics = Metrics printed 35 | bce_weight = 0.5 (default) 36 | Output: 37 | loss : dice loss of the epoch """ 38 | bce = F.binary_cross_entropy_with_logits(prediction, target) 39 | prediction = torch.sigmoid(prediction) 40 | dice = dice_loss(prediction, target) 41 | 42 | loss = bce * bce_weight + dice * (1 - bce_weight) 43 | 44 | return loss 45 | 46 | 47 | 48 | 49 | 50 | class log_cosh_dice_loss(nn.Module): 51 | def __init__(self, num_classes=1, smooth=1, alpha=0.7): 52 | super(log_cosh_dice_loss, self).__init__() 53 | self.smooth = smooth 54 | self.alpha = alpha 55 | self.num_classes = num_classes 56 | 57 | def forward(self, outputs, targets): 58 | x = self.dice_loss(outputs, targets) 59 | return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0) 60 | 61 | def dice_loss(self, y_pred, y_true): 62 | """[function to compute dice loss] 63 | Args: 64 | y_true ([float32]): [ground truth image] 65 | y_pred ([float32]): [predicted image] 66 | Returns: 67 | [float32]: [loss value] 68 | """ 69 | smooth = 1. 70 | y_true = torch.flatten(y_true) 71 | y_pred = torch.flatten(y_pred) 72 | intersection = torch.sum((y_true * y_pred)) 73 | coeff = (2. * intersection + smooth) / (torch.sum(y_true) + torch.sum(y_pred) + smooth) 74 | return (1. - coeff) 75 | 76 | 77 | def focal_loss(predict, label, alpha=0.6, beta=2): 78 | probs = torch.sigmoid(predict) 79 | # 交叉熵Loss 80 | ce_loss = nn.BCELoss() 81 | ce_loss = ce_loss(probs,label) 82 | alpha_ = torch.ones_like(predict) * alpha 83 | # 正label 为alpha, 负label为1-alpha 84 | alpha_ = torch.where(label > 0, alpha_, 1.0 - alpha_) 85 | probs_ = torch.where(label > 0, probs, 1.0 - probs) 86 | # loss weight matrix 87 | loss_matrix = alpha_ * torch.pow((1.0 - probs_), beta) 88 | # 最终loss 矩阵,为对应的权重与loss值相乘,控制预测越不准的产生更大的loss 89 | loss = loss_matrix * ce_loss 90 | loss = torch.sum(loss) 91 | return loss 92 | 93 | 94 | 95 | class Loss: 96 | def __init__(self, dice_weight=0.0, class_weights=None, num_classes=1, device=None): 97 | self.device = device 98 | if class_weights is not None: 99 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to( 100 | self.device 101 | ) 102 | else: 103 | nll_weight = None 104 | self.nll_loss = nn.NLLLoss2d(weight=nll_weight) 105 | self.dice_weight = dice_weight 106 | self.num_classes = num_classes 107 | 108 | def __call__(self, outputs, targets): 109 | loss = self.nll_loss(outputs, targets) 110 | if self.dice_weight: 111 | eps = 1e-7 112 | cls_weight = self.dice_weight / self.num_classes 113 | for cls in range(self.num_classes): 114 | dice_target = (targets == cls).float() 115 | dice_output = outputs[:, cls].exp() 116 | intersection = (dice_output * dice_target).sum() 117 | # union without intersection 118 | uwi = dice_output.sum() + dice_target.sum() + eps 119 | loss += (1 - intersection / uwi) * cls_weight 120 | loss /= (1 + self.dice_weight) 121 | return loss 122 | 123 | 124 | class LossMulti: 125 | def __init__( 126 | self, jaccard_weight=0.0, class_weights=None, num_classes=1, device=None 127 | ): 128 | self.device = device 129 | if class_weights is not None: 130 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to( 131 | self.device 132 | ) 133 | else: 134 | nll_weight = None 135 | 136 | self.nll_loss = nn.NLLLoss(weight=nll_weight) 137 | self.jaccard_weight = 
jaccard_weight 138 | self.num_classes = num_classes 139 | 140 | def __call__(self, outputs, targets): 141 | 142 | targets = targets.squeeze(1) 143 | 144 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 145 | 146 | if self.jaccard_weight: 147 | eps = 1e-7 # originally 1e-7 148 | for cls in range(self.num_classes): 149 | jaccard_target = (targets == cls).float() 150 | jaccard_output = outputs[:, cls].exp() 151 | intersection = (jaccard_output * jaccard_target).sum() 152 | 153 | union = jaccard_output.sum() + jaccard_target.sum() 154 | loss -= ( 155 | torch.log((intersection + eps) / (union - intersection + eps)) 156 | * self.jaccard_weight 157 | ) 158 | return loss 159 | 160 | 161 | class LossBsiNet: 162 | def __init__(self, weights=[1, 1, 1]): 163 | self.criterion1 = LossMulti(num_classes=2) # mask loss 164 | self.criterion2 = LossMulti(num_classes=2) # contour loss 165 | self.criterion3 = nn.MSELoss() # distance loss 166 | self.weights = weights 167 | 168 | def __call__(self, outputs1, outputs2, outputs3, targets1, targets2, targets3): 169 | # 170 | criterion = ( 171 | self.weights[0] * self.criterion1(outputs1, targets1) 172 | + self.weights[1] * self.criterion2(outputs2, targets2) 173 | + self.weights[2] * self.criterion3(outputs3, targets3) 174 | ) 175 | 176 | return criterion 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """Model construction 2 | 1. We offer two versions of BsiNet, one concise and one clearer 3 | 2. The clearer version is intended to be easier to understand and modify 4 | 3. You can use the attention mechanisms we provide to build a new multi-task model 5 | 4. 
You can also add your own module or change the location of the attention mechanism to build a better model 6 | """ 7 | 8 | 9 | from torch import nn 10 | import torch 11 | from torch.nn import functional as F 12 | from torch.nn.parameter import Parameter 13 | 14 | 15 | def conv3x3(in_, out): 16 | return nn.Conv2d(in_, out, 3, padding=1) 17 | 18 | 19 | class Conv3BN(nn.Module): 20 | def __init__(self, in_: int, out: int, bn=False): 21 | super().__init__() 22 | self.conv = conv3x3(in_, out) 23 | self.bn = nn.BatchNorm2d(out) if bn else None 24 | self.activation = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | if self.bn is not None: 29 | x = self.bn(x) 30 | x = self.activation(x) 31 | return x 32 | 33 | 34 | class NetModule(nn.Module): 35 | def __init__(self, in_: int, out: int): 36 | super().__init__() 37 | self.l1 = Conv3BN(in_, out) 38 | self.l2 = Conv3BN(out, out) 39 | 40 | def forward(self, x): 41 | x = self.l1(x) 42 | x = self.l2(x) 43 | return x 44 | 45 | 46 | #SE注意力机制 47 | class SELayer(nn.Module): 48 | def __init__(self, channel, reduction=16): 49 | super(SELayer, self).__init__() 50 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 51 | self.fc = nn.Sequential( 52 | nn.Linear(channel, channel // reduction, bias=False), 53 | nn.ReLU(inplace=True), 54 | nn.Linear(channel // reduction, channel, bias=False), 55 | nn.Sigmoid() 56 | ) 57 | 58 | def forward(self, x): 59 | b, c, _, _ = x.size() 60 | y = self.avg_pool(x).view(b, c) 61 | y = self.fc(y).view(b, c, 1, 1) 62 | return x * y.expand_as(x) 63 | 64 | 65 | 66 | class SpatialGroupEnhance(nn.Module): 67 | def __init__(self, groups = 64): 68 | super(SpatialGroupEnhance, self).__init__() 69 | self.groups = groups 70 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 71 | self.weight = Parameter(torch.zeros(1, groups, 1, 1)) 72 | self.bias = Parameter(torch.ones(1, groups, 1, 1)) 73 | self.sig = nn.Sigmoid() 74 | 75 | def forward(self, x): # (b, c, h, w) 76 | b, c, h, w = x.size() 77 | x = x.view(b * self.groups, -1, h, w) 78 | xn = x * self.avg_pool(x) 79 | xn = xn.sum(dim=1, keepdim=True) 80 | t = xn.view(b * self.groups, -1) 81 | t = t - t.mean(dim=1, keepdim=True) 82 | std = t.std(dim=1, keepdim=True) + 1e-5 83 | t = t / std 84 | t = t.view(b, self.groups, h, w) 85 | t = t * self.weight + self.bias 86 | t = t.view(b * self.groups, 1, h, w) 87 | x = x * self.sig(t) 88 | x = x.view(b, c, h, w) 89 | return x 90 | 91 | ######CBAM注意力 92 | class ChannelAttention(nn.Module): 93 | def __init__(self, in_planes, ratio=16): 94 | super(ChannelAttention, self).__init__() 95 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 96 | self.max_pool = nn.AdaptiveMaxPool2d(1) 97 | 98 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 99 | self.relu1 = nn.ReLU() 100 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 101 | 102 | self.sigmoid = nn.Sigmoid() 103 | 104 | def forward(self, x): 105 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 106 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 107 | out = avg_out + max_out 108 | return self.sigmoid(out) 109 | 110 | class SpatialAttention(nn.Module): 111 | def __init__(self, kernel_size=7): 112 | super(SpatialAttention, self).__init__() 113 | 114 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 115 | padding = 3 if kernel_size == 7 else 1 116 | 117 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 118 | self.sigmoid = nn.Sigmoid() 119 | 120 | def forward(self, x): 121 | avg_out = torch.mean(x, dim=1, 
keepdim=True) 122 | max_out, _ = torch.max(x, dim=1, keepdim=True) 123 | x = torch.cat([avg_out, max_out], dim=1) 124 | x = self.conv1(x) 125 | return self.sigmoid(x) 126 | 127 | 128 | 129 | #scce注意力模块 130 | class cSE(nn.Module): # noqa: N801 131 | """ 132 | The channel-wise SE (Squeeze and Excitation) block from the 133 | `Squeeze-and-Excitation Networks`__ paper. 134 | Adapted from 135 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65939 136 | and 137 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178 138 | Shape: 139 | - Input: (batch, channels, height, width) 140 | - Output: (batch, channels, height, width) (same shape as input) 141 | __ https://arxiv.org/abs/1709.01507 142 | """ 143 | 144 | def __init__(self, in_channels: int, r: int = 16): 145 | """ 146 | Args: 147 | in_channels: The number of channels 148 | in the feature map of the input. 149 | r: The reduction ratio of the intermediate channels. 150 | Default: 16. 151 | """ 152 | super().__init__() 153 | self.linear1 = nn.Linear(in_channels, in_channels // r) 154 | self.linear2 = nn.Linear(in_channels // r, in_channels) 155 | 156 | def forward(self, x: torch.Tensor): 157 | """Forward call.""" 158 | input_x = x 159 | 160 | x = x.view(*(x.shape[:-2]), -1).mean(-1) 161 | x = F.relu(self.linear1(x), inplace=True) 162 | x = self.linear2(x) 163 | x = x.unsqueeze(-1).unsqueeze(-1) 164 | x = torch.sigmoid(x) 165 | 166 | x = torch.mul(input_x, x) 167 | return x 168 | 169 | 170 | class sSE(nn.Module): # noqa: N801 171 | """ 172 | The sSE (Channel Squeeze and Spatial Excitation) block from the 173 | `Concurrent Spatial and Channel ‘Squeeze & Excitation’ 174 | in Fully Convolutional Networks`__ paper. 175 | Adapted from 176 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178 177 | Shape: 178 | - Input: (batch, channels, height, width) 179 | - Output: (batch, channels, height, width) (same shape as input) 180 | __ https://arxiv.org/abs/1803.02579 181 | """ 182 | 183 | def __init__(self, in_channels: int): 184 | """ 185 | Args: 186 | in_channels: The number of channels 187 | in the feature map of the input. 188 | """ 189 | super().__init__() 190 | self.conv = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1) 191 | 192 | def forward(self, x: torch.Tensor): 193 | """Forward call.""" 194 | input_x = x 195 | 196 | x = self.conv(x) 197 | x = torch.sigmoid(x) 198 | 199 | x = torch.mul(input_x, x) 200 | return x 201 | 202 | 203 | class scSE(nn.Module): # noqa: N801 204 | """ 205 | The scSE (Concurrent Spatial and Channel Squeeze and Channel Excitation) 206 | block from the `Concurrent Spatial and Channel ‘Squeeze & Excitation’ 207 | in Fully Convolutional Networks`__ paper. 208 | Adapted from 209 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178 210 | Shape: 211 | - Input: (batch, channels, height, width) 212 | - Output: (batch, channels, height, width) (same shape as input) 213 | __ https://arxiv.org/abs/1803.02579 214 | """ 215 | 216 | def __init__(self, in_channels: int, r: int = 16): 217 | """ 218 | Args: 219 | in_channels: The number of channels 220 | in the feature map of the input. 221 | r: The reduction ratio of the intermediate channels. 222 | Default: 16. 
223 | """ 224 | super().__init__() 225 | self.cse_block = cSE(in_channels, r) 226 | self.sse_block = sSE(in_channels) 227 | 228 | def forward(self, x: torch.Tensor): 229 | """Forward call.""" 230 | cse = self.cse_block(x) 231 | sse = self.sse_block(x) 232 | x = torch.add(cse, sse) 233 | return x 234 | 235 | 236 | ##This is a concise version of the BsiNet whose modules are better packaged 237 | 238 | class BsiNet(nn.Module): 239 | 240 | output_downscaled = 1 241 | module = NetModule 242 | 243 | def __init__( 244 | self, 245 | input_channels: int = 3, 246 | filters_base: int = 32, 247 | down_filter_factors=(1, 2, 4, 8, 16), 248 | up_filter_factors=(1, 2, 4, 8, 16), 249 | bottom_s=4, 250 | num_classes=1, 251 | add_output=True, 252 | ): 253 | super().__init__() 254 | self.num_classes = num_classes 255 | assert len(down_filter_factors) == len(up_filter_factors) 256 | assert down_filter_factors[-1] == up_filter_factors[-1] 257 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 258 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 259 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 260 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 261 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 262 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 263 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 264 | self.up.append( 265 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]) 266 | ) 267 | 268 | pool = nn.MaxPool2d(2, 2) 269 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 270 | upsample = nn.Upsample(scale_factor=2) 271 | upsample_bottom = nn.Upsample(scale_factor=bottom_s) 272 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 273 | self.downsamplers[-1] = pool_bottom 274 | self.upsamplers = [upsample] * len(self.up) 275 | self.upsamplers[-1] = upsample_bottom 276 | self.add_output = add_output 277 | self.sge = SpatialGroupEnhance(32) 278 | 279 | if add_output: 280 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 281 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 282 | self.conv_final3 = nn.Conv2d(up_filter_sizes[0], 1, 1) 283 | 284 | def forward(self, x): 285 | xs = [] 286 | for downsample, down in zip(self.downsamplers, self.down): 287 | x_in = x if downsample is None else downsample(xs[-1]) 288 | x_out = down(x_in) 289 | xs.append(x_out) 290 | 291 | for x_skip, upsample, up in reversed( 292 | list(zip(xs[:-1], self.upsamplers, self.up)) 293 | ): 294 | 295 | x_out2 = upsample(x_out) 296 | x_out= (torch.cat([x_out2, x_skip], 1)) 297 | x_out = up(x_out) 298 | 299 | if self.add_output: 300 | 301 | x_out = self.sge(x_out) 302 | 303 | x_out1 = self.conv_final1(x_out) 304 | x_out2 = self.conv_final2(x_out) 305 | x_out3 = self.conv_final3(x_out) 306 | if self.num_classes > 1: 307 | x_out1 = F.log_softmax(x_out1,dim=1) 308 | x_out2 = F.log_softmax(x_out2,dim=1) 309 | x_out3 = torch.sigmoid(x_out3) 310 | 311 | return [x_out1, x_out2, x_out3] 312 | 313 | 314 | 315 | 316 | ##This is a clearer BsiNet which shows a clearer building process 317 | 318 | class BsiNet_2(nn.Module): 319 | def __init__( 320 | self, 321 | input_channels: int = 3, 322 | filters_base: int = 32, 323 | num_classes=1, 324 | add_output=True, 325 | ): 326 | super().__init__() 327 | self.num_classes = num_classes 328 | self.add_output = add_output 329 | self.conv1 = NetModule(input_channels, 32) 330 | self.conv2 = NetModule(32, 64) 331 | self.conv3 = NetModule(64, 128) 332 | self.conv4 = 
NetModule(128, 256) 333 | self.conv5 = NetModule(256, 512) 334 | 335 | self.conv6 = NetModule(768, 256) 336 | self.conv7 = NetModule(384, 128) 337 | self.conv8 = NetModule(192, 64) 338 | self.conv9 = NetModule(96, 32) 339 | 340 | self.pool1 = nn.MaxPool2d(2, 2) 341 | self.pool2 = nn.MaxPool2d(4, 4) 342 | self.upsample1 = nn.Upsample(scale_factor=2) 343 | self.upsample2 = nn.Upsample(scale_factor=4) 344 | self.sge = SpatialGroupEnhance(32) 345 | if add_output: 346 | self.conv_final1 = nn.Conv2d(filters_base, num_classes, 1) 347 | self.conv_final2 = nn.Conv2d(filters_base, num_classes, 1) 348 | self.conv_final3 = nn.Conv2d(filters_base, 1, 1) 349 | 350 | def forward(self, x): 351 | x1 = self.conv1(x) 352 | 353 | x2 = self.conv2(x1) 354 | x2 = self.pool1(x2) 355 | 356 | x3 = self.conv3(x2) 357 | x3 = self.pool1(x3) 358 | 359 | x4 = self.conv4(x3) 360 | x4 = self.pool1(x4) 361 | 362 | x5 = self.conv5(x4) 363 | x5 = self.pool2(x5) 364 | 365 | x_6 = self.upsample2(x5) 366 | x6 = self.conv6(torch.cat([x_6, x4], 1)) 367 | x6 = self.upsample1(x6) 368 | 369 | x7 = self.conv7(torch.cat([x6, x3], 1)) 370 | x7 = self.upsample1(x7) 371 | 372 | x8 = self.conv8(torch.cat([x7, x2], 1)) 373 | x8 = self.upsample1(x8) 374 | 375 | x9 = self.conv9(torch.cat([x8, x1], 1)) 376 | x_out = self.sge(x9) 377 | 378 | if self.add_output: 379 | 380 | x_out1 = self.conv_final1(x_out) 381 | x_out2 = self.conv_final2(x_out) 382 | x_out3 = self.conv_final3(x_out) 383 | if self.num_classes > 1: 384 | x_out1 = F.log_softmax(x_out1, dim=1) 385 | x_out2 = F.log_softmax(x_out2, dim=1) 386 | x_out3 = torch.sigmoid(x_out3) 387 | 388 | return [x_out1, x_out2, x_out3] 389 | 390 | 391 | 392 | 393 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | ## Example: A simple example to obtain distsance map and boundary map 2 | import numpy as np 3 | import os 4 | import cv2 5 | from osgeo import gdal 6 | import scipy.ndimage as sn 7 | 8 | def read_img(filename): 9 | dataset=gdal.Open(filename) 10 | 11 | im_width = dataset.RasterXSize 12 | im_height = dataset.RasterYSize 13 | 14 | im_geotrans = dataset.GetGeoTransform() 15 | im_proj = dataset.GetProjection() 16 | im_data = dataset.ReadAsArray(0,0,im_width,im_height) 17 | 18 | del dataset 19 | return im_proj, im_geotrans, im_width, im_height, im_data 20 | 21 | 22 | def write_img(filename, im_proj, im_geotrans, im_data): 23 | if 'int8' in im_data.dtype.name: 24 | datatype = gdal.GDT_Byte 25 | elif 'int16' in im_data.dtype.name: 26 | datatype = gdal.GDT_UInt16 27 | else: 28 | datatype = gdal.GDT_Float32 29 | 30 | if len(im_data.shape) == 3: 31 | im_bands, im_height, im_width = im_data.shape 32 | else: 33 | im_bands, (im_height, im_width) = 1,im_data.shape 34 | 35 | driver = gdal.GetDriverByName("GTiff") 36 | dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) 37 | 38 | dataset.SetGeoTransform(im_geotrans) 39 | dataset.SetProjection(im_proj) 40 | 41 | if im_bands == 1: 42 | dataset.GetRasterBand(1).WriteArray(im_data) 43 | else: 44 | for i in range(im_bands): 45 | dataset.GetRasterBand(i+1).WriteArray(im_data[i]) 46 | 47 | del dataset 48 | 49 | 50 | 51 | maskRoot = r"C:\Users\hnzzy\Desktop\mask" 52 | distRoot = r"C:\Users\hnzzy\Desktop\dist" 53 | boundaryRoot = r"C:\Users\hnzzy\Desktop\boundary" 54 | 55 | for imgPath in os.listdir(maskRoot): 56 | input_path = os.path.join(maskRoot, imgPath) 57 | boundaryOutPath = os.path.join(boundaryRoot, 
imgPath) 58 | distOutPath = os.path.join(distRoot, imgPath) 59 | im_proj, im_geotrans, im_width, im_height, im_data = read_img(input_path) 60 | result = cv2.distanceTransform(src=im_data, distanceType=cv2.DIST_L2, maskSize=3) 61 | min_value = np.min(result) 62 | max_value = np.max(result) 63 | scaled_image = ((result - min_value) / (max_value - min_value)) * 255 64 | result = scaled_image.astype(np.uint8) 65 | # result = result.astype(np.uint8) 66 | write_img(distOutPath, im_proj, im_geotrans, result) 67 | ##distance map(you can also use bwdist function in Matlab to obtain distance map) 68 | ###boundary(you can also use bwperim function in Matlab to obtain boundary map) 69 | boundary = cv2.Canny(im_data, 100, 200) 70 | ## dilation 71 | # kernel = np.ones((3, 3), np.uint8) 72 | # boundary = cv2.dilate(boundary, kernel, iterations=1) 73 | write_img(boundaryOutPath, im_proj, im_geotrans, boundary) 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.data import DataLoader 4 | from dataset import DatasetImageMaskContourDist 5 | from models import BsiNet 6 | from tqdm import tqdm 7 | import numpy as np 8 | import cv2 9 | from utils import create_validation_arg_parser 10 | from torch import nn 11 | 12 | def build_model(model_type): 13 | 14 | if model_type == "bsinet": 15 | model = BsiNet(num_classes=2) 16 | 17 | return model 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | args = create_validation_arg_parser().parse_args() 23 | 24 | args.model_file = './bsi/150.pt' 25 | args.save_path = './save' 26 | args.model_type = 'bsinet' 27 | args.distance_type = 'dist_contour' 28 | args.test_path = './test' 29 | 30 | 31 | test_path = args.test_path + '/' + 'image' 32 | model_file = args.model_file 33 | save_path = args.save_path 34 | model_type = args.model_type 35 | 36 | cuda_no = args.cuda_no 37 | CUDA_SELECT = "cuda:{}".format(cuda_no) 38 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 39 | 40 | img_name = [] 41 | for img_file in os.listdir(test_path): 42 | img_name.append(img_file[:-4]) 43 | valLoader = DataLoader(DatasetImageMaskContourDist(test_path, img_name,args.distance_type)) 44 | 45 | if not os.path.exists(save_path): 46 | os.mkdir(save_path) 47 | 48 | model = build_model(model_type) 49 | model = nn.DataParallel(model) 50 | model = model.to(device) 51 | model.load_state_dict(torch.load(model_file)) 52 | model.eval() 53 | 54 | for i, (img_file_name, inputs, targets1, targets2, targets3) in enumerate( 55 | tqdm(valLoader) 56 | ): 57 | 58 | inputs = inputs.to(device) 59 | outputs1, outputs2, outputs3 = model(inputs) 60 | 61 | ## TTA 62 | # outputs4, outputs5, outputs6 = model(torch.flip(inputs, [-1])) 63 | # predict_2 = torch.flip(outputs4, [-1]) 64 | # outputs7, outputs8, outputs9 = model(torch.flip(inputs, [-2])) 65 | # predict_3 = torch.flip(outputs7, [-2]) 66 | # outputs10, outputs11, outputs12 = model(torch.flip(inputs, [-1, -2])) 67 | # predict_4 = torch.flip(outputs10, [-1, -2]) 68 | # predict_list = outputs1 + predict_2 + predict_3 + predict_4 69 | # pred1 = predict_list/4.0 70 | 71 | outputs1 = outputs1.detach().cpu().numpy().squeeze() 72 | 73 | res = np.zeros((256, 256)) 74 | indices = np.argmax(outputs1, axis=0) 75 | res[indices == 1] = 255 76 | res[indices == 0] = 0 77 | res = np.array(res, dtype='uint8') # 转变为8字节型 78 | output_path = os.path.join( 79 | 
save_path, img_file_name[0]+'.tif' 80 | ) 81 | cv2.imwrite(output_path, res) 82 | 83 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import random 5 | import torch 6 | from dataset import DatasetImageMaskContourDist 7 | from losses import LossBsiNet 8 | from models import BsiNet 9 | from tensorboardX import SummaryWriter 10 | from torch import nn 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | from utils import visualize, create_train_arg_parser,evaluate 14 | # from torchsummary import summary 15 | from sklearn.model_selection import train_test_split 16 | 17 | def define_loss(loss_type, weights=[1, 1, 1]): 18 | 19 | if loss_type == "bsinet": 20 | criterion = LossBsiNet(weights) 21 | 22 | return criterion 23 | 24 | 25 | def build_model(model_type): 26 | 27 | if model_type == "bsinet": 28 | model = BsiNet(num_classes=2) 29 | 30 | return model 31 | 32 | 33 | def train_model(model, targets, model_type, criterion, optimizer): 34 | 35 | if model_type == "bsinet": 36 | 37 | optimizer.zero_grad() 38 | 39 | with torch.set_grad_enabled(True): 40 | outputs = model(inputs) 41 | loss = criterion( 42 | outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2] 43 | ) 44 | loss.backward() 45 | optimizer.step() 46 | 47 | return loss 48 | 49 | 50 | if __name__ == "__main__": 51 | 52 | args = create_train_arg_parser().parse_args() 53 | 54 | args.distance_type = 'dist_contour' 55 | # args.pretrained_model_path = './best_merge_model_article/85.pt' 56 | 57 | args.train_path = './train/image/' 58 | # args.val_path = './XJ_goole/test/image/' 59 | args.model_type = 'bsinet' 60 | args.save_path = './model' 61 | 62 | CUDA_SELECT = "cuda:{}".format(args.cuda_no) 63 | log_path = args.save_path + "/summary" 64 | writer = SummaryWriter(log_dir=log_path) 65 | 66 | logging.basicConfig( 67 | filename="".format(args.object_type), 68 | filemode="a", 69 | format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", 70 | datefmt="%Y-%m-%d %H:%M", 71 | level=logging.INFO, 72 | ) 73 | logging.info("") 74 | 75 | # train_file_names = glob.glob(os.path.join(args.train_path, "*.tif")) 76 | # random.shuffle(train_file_names) 77 | # val_file_names = glob.glob(os.path.join(args.val_path, "*.tif")) 78 | 79 | train_file_names = glob.glob(os.path.join(args.train_path, "*.tif")) 80 | random.shuffle(train_file_names) 81 | 82 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_file_names] 83 | train_file, val_file = train_test_split(img_ids, test_size=0.2, random_state=41) 84 | 85 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 86 | print(device) 87 | model = build_model(args.model_type) 88 | 89 | if torch.cuda.device_count() > 0: #本来是0 90 | print("Let's use", torch.cuda.device_count(), "GPUs!") 91 | model = nn.DataParallel(model) 92 | 93 | model = model.to(device) 94 | # summary(model, input_size=(3, 256, 256)) 95 | 96 | epoch_start = "0" 97 | if args.use_pretrained: 98 | print("Loading Model {}".format(os.path.basename(args.pretrained_model_path))) 99 | model.load_state_dict(torch.load(args.pretrained_model_path)) #加了False 100 | epoch_start = os.path.basename(args.pretrained_model_path).split(".")[0] 101 | print(epoch_start) 102 | print('train',args.use_pretrained) 103 | trainLoader = DataLoader( 104 | DatasetImageMaskContourDist(args.train_path,train_file, 
args.distance_type), 105 | batch_size=args.batch_size,drop_last=False, shuffle=True 106 | ) 107 | devLoader = DataLoader( 108 | DatasetImageMaskContourDist(args.train_path,val_file, args.distance_type),drop_last=True, 109 | ) 110 | displayLoader = DataLoader( 111 | DatasetImageMaskContourDist(args.train_path,val_file, args.distance_type), 112 | batch_size=args.val_batch_size,drop_last=True, shuffle=True 113 | ) 114 | 115 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 116 | # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 117 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(1e10), eta_min=1e-5) 118 | # scheduler = optim.lr_scheduler.StepLR(optimizer, 50, 0.1) #新加的 119 | criterion = define_loss(args.model_type) 120 | 121 | 122 | for epoch in tqdm( 123 | range(int(epoch_start) + 1, int(epoch_start) + 1 + args.num_epochs) 124 | ): 125 | 126 | global_step = epoch * len(trainLoader) 127 | running_loss = 0.0 128 | 129 | for i, (img_file_name, inputs, targets1, targets2,targets3) in enumerate( 130 | tqdm(trainLoader) 131 | ): 132 | 133 | model.train() 134 | 135 | inputs = inputs.to(device) 136 | targets1 = targets1.to(device) 137 | targets2 = targets2.to(device) 138 | targets3 = targets3.to(device) 139 | 140 | targets = [targets1, targets2,targets3] 141 | 142 | 143 | loss = train_model(model, targets, args.model_type, criterion, optimizer) 144 | 145 | writer.add_scalar("loss", loss.item(), epoch) 146 | 147 | running_loss += loss.item() * inputs.size(0) 148 | scheduler.step() 149 | 150 | epoch_loss = running_loss / len(train_file_names) 151 | print(epoch_loss) 152 | 153 | if epoch % 1 == 0: 154 | 155 | dev_loss, dev_time = evaluate(device, epoch, model, devLoader, writer) 156 | writer.add_scalar("loss_valid", dev_loss, epoch) 157 | visualize(device, epoch, model, displayLoader, writer, args.val_batch_size) 158 | print("Global Loss:{} Val Loss:{}".format(epoch_loss, dev_loss)) 159 | else: 160 | print("Global Loss:{} ".format(epoch_loss)) 161 | 162 | logging.info("epoch:{} train_loss:{} ".format(epoch, epoch_loss)) 163 | if epoch % 5 == 0: 164 | torch.save( 165 | model.state_dict(), os.path.join(args.save_path, str(epoch) + ".pt") 166 | ) 167 | 168 | 169 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torchvision 5 | from torch.nn import functional as F 6 | import time 7 | import argparse 8 | 9 | 10 | def evaluate(device, epoch, model, data_loader, writer): 11 | model.eval() 12 | losses = [] 13 | start = time.perf_counter() 14 | with torch.no_grad(): 15 | 16 | for iter, data in enumerate(tqdm(data_loader)): 17 | 18 | _, inputs, targets, _,_ = data 19 | inputs = inputs.to(device) 20 | targets = targets.to(device) 21 | outputs = model(inputs) 22 | loss = F.nll_loss(outputs[0], targets.squeeze(1)) 23 | losses.append(loss.item()) 24 | 25 | writer.add_scalar("Dev_Loss", np.mean(losses), epoch) 26 | 27 | return np.mean(losses), time.perf_counter() - start 28 | 29 | 30 | def visualize(device, epoch, model, data_loader, writer, val_batch_size, train=True): 31 | def save_image(image, tag, val_batch_size): 32 | image -= image.min() 33 | image /= image.max() 34 | grid = torchvision.utils.make_grid( 35 | image, nrow=int(np.sqrt(val_batch_size)), pad_value=0, padding=25 36 | ) 37 | writer.add_image(tag, grid, epoch) 38 | 39 | model.eval() 40 | with torch.no_grad(): 
41 | for iter, data in enumerate(tqdm(data_loader)): 42 | _, inputs, targets, _,_ = data 43 | 44 | inputs = inputs.to(device) 45 | 46 | targets = targets.to(device) 47 | outputs = model(inputs) 48 | 49 | output_mask = outputs[0].detach().cpu().numpy() 50 | output_final = np.argmax(output_mask, axis=1).astype(float) 51 | output_final = torch.from_numpy(output_final).unsqueeze(1) 52 | 53 | if train: # train is a bool flag, not the string "True" 54 | save_image(targets.float(), "Target_train",val_batch_size) 55 | save_image(output_final, "Prediction_train",val_batch_size) 56 | else: 57 | save_image(targets.float(), "Target", val_batch_size) 58 | save_image(output_final, "Prediction", val_batch_size) 59 | 60 | break 61 | 62 | 63 | def create_train_arg_parser(): 64 | 65 | parser = argparse.ArgumentParser(description="train setup for segmentation") 66 | parser.add_argument("--train_path", type=str, help="path to img tif files") 67 | parser.add_argument("--val_path", type=str, help="path to img tif files") 68 | parser.add_argument( 69 | "--model_type", 70 | type=str, 71 | help="select model type: bsinet", 72 | ) 73 | parser.add_argument("--object_type", type=str, help="Dataset.") 74 | parser.add_argument( 75 | "--distance_type", 76 | type=str, 77 | default="dist_contour", 78 | help="select distance transform type - dist_mask,dist_contour,dist_contour_tif", 79 | ) 80 | parser.add_argument("--batch_size", type=int, default=4, help="train batch size") 81 | parser.add_argument( 82 | "--val_batch_size", type=int, default=4, help="validation batch size" 83 | ) 84 | parser.add_argument("--num_epochs", type=int, default=150, help="number of epochs") 85 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number") 86 | parser.add_argument( 87 | "--use_pretrained", type=bool, default=False, help="Load pretrained checkpoint." 88 | ) 89 | parser.add_argument( 90 | "--pretrained_model_path", 91 | type=str, 92 | default=None, 93 | help="If use_pretrained is true, provide checkpoint.", 94 | ) 95 | parser.add_argument("--save_path", type=str, help="Model save path.") 96 | 97 | return parser 98 | 99 | 100 | def create_validation_arg_parser(): 101 | 102 | parser = argparse.ArgumentParser(description="test setup for segmentation") 103 | parser.add_argument( 104 | "--model_type", 105 | type=str, 106 | help="select model type: bsinet", 107 | ) 108 | parser.add_argument("--test_path", type=str, help="path to img tif files") 109 | parser.add_argument("--model_file", type=str, help="path to the trained model checkpoint (.pt)") 110 | parser.add_argument("--save_path", type=str, help="results save path.") 111 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number") 112 | 113 | return parser 114 | 115 | 116 | --------------------------------------------------------------------------------