├── README.md
├── dataset.py
├── imgs
│   ├── BsiNet.png
│   ├── comparison_results.png
│   └── results.png
├── losses.py
├── models.py
├── preprocess.py
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # BsiNet
2 |
3 | Official Pytorch Code base for "Delineation of agricultural fields using multi-task BsiNet from high-resolution satellite images"
4 |
5 | [Project](https://github.com/long123524/BsiNet-torch)
6 |
7 | ## Introduction
8 |
9 | This paper presents a new multi-task neural network, BsiNet, to delineate agricultural fields from remote sensing images. BsiNet learns three tasks: a core task for agricultural field identification and two auxiliary tasks for field boundary prediction and distance estimation, corresponding to the mask, boundary, and distance tasks, respectively.
10 |
23 |
24 | ## Using the code:
25 |
26 | The code is stable with Python 3.7.0 and CUDA >= 11.0.
27 |
28 | - Clone this repository:
29 | ```bash
30 | git clone https://github.com/long123524/BsiNet-torch
31 | cd BsiNet-torch
32 | ```
33 |
34 | Install all the dependencies using conda or pip:
35 |
36 | ```
37 | PyTorch
38 | TensorboardX
39 | OpenCV
40 | numpy
41 | tqdm
42 | ```
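
Most of these can be installed with pip; GDAL, SciPy, scikit-image, scikit-learn, Pillow, and imageio are also used by the scripts. A hedged example (versions are not pinned by this repository, and GDAL is often easier to install from conda-forge):

```bash
pip install torch torchvision tensorboardX opencv-python numpy tqdm scipy scikit-image scikit-learn pillow imageio
conda install -c conda-forge gdal
```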
43 |
44 | ## Preprocessing
45 | Use preprocess.py to generate the contour (boundary) and distance maps, as shown below.
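
A minimal way to run it (a hedged example: `maskRoot`, `distRoot`, and `boundaryRoot` are hard-coded at the bottom of preprocess.py, so point them to your own mask folder and output folders first):

```bash
python preprocess.py
```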
46 |
47 | ## Data Format
48 |
49 | Make sure the files are organized in the following structure:
50 |
51 | ```
52 | inputs
53 | └──
54 |     ├── image
55 |     │   ├── 001.tif
56 |     │   ├── 002.tif
57 |     │   ├── 003.tif
58 |     │   └── ...
59 |     ├── mask
60 |     │   ├── 001.tif
61 |     │   ├── 002.tif
62 |     │   ├── 003.tif
63 |     │   └── ...
64 |     ├── contour
65 |     │   ├── 001.tif
66 |     │   ├── 002.tif
67 |     │   ├── 003.tif
68 |     │   └── ...
69 |     └── dist_contour
70 |         ├── 001.tif
71 |         ├── 002.tif
72 |         ├── 003.tif
73 |         └── ...
75 | ```
76 |
77 | The test and validation datasets follow the same structure as above.
78 |
79 | ## Training and testing
80 |
81 | 1. Train the model.
82 | ```
83 | python train.py --train_path ./fields/image --save_path ./model --model_type 'bsinet' --distance_type 'dist_contour'
84 | ```
85 | 2. Evaluate.
86 | ```
87 | python test.py --model_file ./model/150.pt --save_path ./save --model_type 'bsinet' --distance_type 'dist_contour' --test_path ./test_image
88 | ```
89 |
90 | If you have any questions, please contact us: Jiang Long (hnzzyxlj@163.com) or Mengmeng Li (mli@fzu.edu.cn).
91 |
92 | ## GF dataset
93 | A GF-2 image (1 m) is provided for scientific use: https://pan.baidu.com/s/1isg9jD9AlE9EeTqa3Fqrrg (password: bzfd).
94 | Google Drive: https://drive.google.com/file/d/1JZtRSxX5PaT3JCzvCLq2Jrt0CBXqZj7c/view?usp=drive_link
95 | A corresponding partial field label is provided for scientific study: https://drive.google.com/file/d/19OrVPkb0MkoaUvaax_9uvnJgSr_dcSSW/view?usp=sharing
96 |
97 | ## A pretrained weight
98 | A pretrained weight on a Xinjiang GF-2 image is provided: https://pan.baidu.com/s/1asAMj4_ZrIQeJiewP2LpqA (password: rz8k).
99 | Google Drive: https://drive.google.com/drive/folders/121T8FjiyEsIbfyLUbrBXYCg75PIzCzRX?usp=sharing
100 |
101 | ### Acknowledgements:
102 |
103 | This codebase uses certain code blocks and helper functions from Psi-Net.
104 |
105 | ### Citation:
106 | If you find this work useful or interesting, please consider citing the following references.
107 | ```
108 | Citation 1:
109 | {Authors: Long Jiang (龙江), Li Mengmeng* (李蒙蒙), Wang Xiaoqin (汪小钦), et al;
110 | Institute: The Academy of Digital China (Fujian), Fuzhou University,
111 | Article Title: Delineation of agricultural fields using multi-task BsiNet from high-resolution satellite images,
112 | Publication: International Journal of Applied Earth Observation and Geoinformation,
113 | Year: 2022,
114 | Volume: 112,
115 | Page: 102871,
116 | DOI: 10.1016/j.jag.2022.102871
117 | }
118 | Citation 2:
119 | {Authors: Li Mengmeng* (李蒙蒙), Long Jiang (龙江), et al;
120 | Institute: The Academy of Digital China (Fujian), Fuzhou University,
121 | Article Title: Using a semantic edge-aware multi-task neural network to delineate agricultural parcels from remote sensing images,
122 | Publication: ISPRS Journal of Photogrammetry and Remote Sensing,
123 | Year: 2023,
124 | Volume: 200,
125 | Page: 24-40,
126 | DOI: 10.1016/j.isprsjprs.2023.04.019
127 | }
128 | Citation 3:
129 | {Authors: Long Jiang (龙江), Zhao Hang (赵航), Li Mengmeng* (李蒙蒙), et al;
130 | Institute: The Academy of Digital China (Fujian), Fuzhou University; Chinese Academy of Sciences
131 | Article Title: Integrating Segment Anything Model derived boundary prior and high-level semantics for cropland extraction from high-resolution remote sensing images,
132 | Publication: IEEE Geoscience and Remote Sensing Letters,
133 | Year: 2024,
134 | Volume: 21,
135 | Page: 1-5,
136 | DOI: 10.1109/LGRS.2024.3454263
137 | }
138 | ...
139 | ```
140 | ### A large cropland dataset collected from VHR images:
141 | It will be accessible at https://github.com/NanNanmei/HBGNet. More details can be found in a recent collaborative paper, "A large-scale VHR parcel dataset and a novel hierarchical semantic boundary-guided network for agricultural parcel delineation" (https://www.sciencedirect.com/science/article/pii/S0924271625000395).
142 | ### A parcel vectorization model:
143 | More details can be found in a recent collaborative paper, "Extracting vectorized agricultural parcels from high-resolution satellite images using a Point-Line-Region interactive multitask model", published in Computers and Electronics in Agriculture. Code is available at https://github.com/mengmengli01/PLR-Net-demo/tree/main.
144 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | This file handles data loading.
3 | "dist_mask" is obtained by applying a Euclidean distance transform to the mask.
4 | "dist_contour" is obtained by applying a quasi-Euclidean distance transform to the mask (a hedged sketch of how such a map can be generated is given at the end of this file).
5 | """
6 |
7 | import torch
8 | import numpy as np
9 | import cv2
10 | from PIL import Image, ImageFile
11 |
12 | from skimage import io
13 | import imageio
14 | from torch.utils.data import Dataset
15 | from torchvision import transforms
16 | from scipy import io
17 | import os
18 | from osgeo import gdal
19 |
20 | ### Reading and saving of remote sensing images (Keep coordinate information)
21 | def readTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0):
22 | dataset = gdal.Open(fileName)
23 |     if dataset is None:
24 |         print(fileName + " cannot be opened")
25 |     # number of columns in the raster
26 |     width = dataset.RasterXSize
27 |     # number of rows in the raster
28 |     height = dataset.RasterYSize
29 |     # number of bands
30 |     bands = dataset.RasterCount
31 |     # read the data
32 |     if data_width == 0 and data_height == 0:
33 |         data_width = width
34 |         data_height = height
35 |     data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
36 |     # get the affine transform (geotransform)
37 |     geotrans = dataset.GetGeoTransform()
38 |     # get the projection
39 | proj = dataset.GetProjection()
40 | return width, height, bands, data, geotrans, proj
41 |
42 |
43 | # Save a remote sensing image (keep geotransform and projection)
44 | def writeTiff(im_data, im_geotrans, im_proj, path):
45 | if 'int8' in im_data.dtype.name:
46 | datatype = gdal.GDT_Byte
47 | elif 'int16' in im_data.dtype.name:
48 | datatype = gdal.GDT_UInt16
49 | else:
50 | datatype = gdal.GDT_Float32
51 | if len(im_data.shape) == 3:
52 | im_bands, im_height, im_width = im_data.shape
53 | else:
54 | im_bands, (im_height, im_width) = 1, im_data.shape
55 |     # create the output file
56 | driver = gdal.GetDriverByName("GTiff")
57 | dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
58 |     if dataset is not None:
59 |         dataset.SetGeoTransform(im_geotrans)  # write the geotransform parameters
60 |         dataset.SetProjection(im_proj)  # write the projection
61 | if im_bands == 1:
62 | dataset.GetRasterBand(1).WriteArray(im_data)
63 | else:
64 | for i in range(im_bands):
65 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
66 | del dataset
67 |
68 |
69 |
70 | class DatasetImageMaskContourDist(Dataset):
71 |
72 | def __init__(self, dir, file_names, distance_type):
73 |
74 | self.file_names = file_names
75 | self.distance_type = distance_type
76 | self.dir = dir
77 |
78 | def __len__(self):
79 |
80 | return len(self.file_names)
81 |
82 | def __getitem__(self, idx):
83 |
84 | img_file_name = self.file_names[idx]
85 | image = load_image(os.path.join(self.dir,img_file_name+'.tif'))
86 | mask = load_mask(os.path.join(self.dir,img_file_name+'.tif'))
87 | contour = load_contour(os.path.join(self.dir,img_file_name+'.tif'))
88 | dist = load_distance(os.path.join(self.dir,img_file_name+'.tif'), self.distance_type)
89 |
90 | return img_file_name, image, mask, contour, dist
91 |
92 |
93 | def load_image(path):
94 |
95 | img = Image.open(path)
96 | data_transforms = transforms.Compose(
97 | [
98 | # transforms.Resize(256),
99 | transforms.ToTensor(),
100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
101 |
102 | ]
103 | )
104 | img = data_transforms(img)
105 |
106 | return img
107 |
108 |
109 | def load_mask(path):
110 | mask = cv2.imread(path.replace("image", "mask").replace("tif", "tif"), 0)
111 | # im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(path.replace("image", "mask").replace("tif", "tif"))
112 | ###mask = mask/225.
113 | mask[mask == 255] = 1
114 | mask[mask == 0] = 0
115 |
116 | return torch.from_numpy(np.expand_dims(mask, 0)).long()
117 |
118 |
119 | def load_contour(path):
120 |
121 | contour = cv2.imread(path.replace("image", "contour").replace("tif", "tif"), 0)
122 | ###contour = contour/255.
123 | contour[contour ==255] = 1
124 | contour[contour == 0] = 0
125 |
126 |
127 | return torch.from_numpy(np.expand_dims(contour, 0)).long()
128 |
129 |
130 | def load_distance(path, distance_type):
131 |
132 | if distance_type == "dist_mask":
133 | path = path.replace("image", "dist_mask").replace("tif", "mat")
134 |
135 | dist = io.loadmat(path)["D2"]
136 |
137 | if distance_type == "dist_contour":
138 | path = path.replace("image", "dist_contour").replace("tif", "mat")
139 | dist = io.loadmat(path)["D2"]
140 |
141 | if distance_type == "dist_contour_tif":
142 | dist = cv2.imread(path.replace("image", "dist_contour_tif").replace("tif", "tif"), 0)
143 | dist = dist/255.
144 |
145 | return torch.from_numpy(np.expand_dims(dist, 0)).float()
146 |
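# A hedged helper sketch (not part of the original pipeline) showing how a "dist_mask"-style
# distance map compatible with load_distance() above could be generated: a Euclidean distance
# transform of a binary mask, saved to a .mat file under the key "D2".
def make_distance_mat(binary_mask, out_path):
    from scipy import ndimage
    from scipy.io import savemat

    dist = ndimage.distance_transform_edt(binary_mask)  # Euclidean distance transform
    savemat(out_path, {"D2": dist})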
--------------------------------------------------------------------------------
/imgs/BsiNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/long123524/BsiNet-torch/dfe2b98d6ab04c2afe446769787e3476030a9b58/imgs/BsiNet.png
--------------------------------------------------------------------------------
/imgs/comparison_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/long123524/BsiNet-torch/dfe2b98d6ab04c2afe446769787e3476030a9b58/imgs/comparison_results.png
--------------------------------------------------------------------------------
/imgs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/long123524/BsiNet-torch/dfe2b98d6ab04c2afe446769787e3476030a9b58/imgs/results.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | """Loss functions
2 | The loss function of BsiNet can be built by combining several of the losses below
3 | """
4 |
5 | import torch
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def dice_loss(prediction, target):
12 | """Calculating the dice loss
13 | Args:
14 | prediction = predicted image
15 | target = Targeted image
16 | Output:
17 | dice_loss"""
18 |
19 | smooth = 1.0
20 |
21 | i_flat = prediction.view(-1)
22 | t_flat = target.view(-1)
23 |
24 | intersection = (i_flat * t_flat).sum()
25 |
26 | return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))
27 |
28 |
29 | def calc_loss(prediction, target, bce_weight=0.5):
30 |     """Calculating the combined BCE and dice loss
31 |     Args:
32 |         prediction = predicted image
33 |         target = ground-truth image
34 |         bce_weight = weight of the BCE term, 0.5 by default
35 |     Output:
36 |         loss : weighted sum of BCE and dice loss"""
38 | bce = F.binary_cross_entropy_with_logits(prediction, target)
39 | prediction = torch.sigmoid(prediction)
40 | dice = dice_loss(prediction, target)
41 |
42 | loss = bce * bce_weight + dice * (1 - bce_weight)
43 |
44 | return loss
45 |
46 |
47 |
48 |
49 |
50 | class log_cosh_dice_loss(nn.Module):
51 | def __init__(self, num_classes=1, smooth=1, alpha=0.7):
52 | super(log_cosh_dice_loss, self).__init__()
53 | self.smooth = smooth
54 | self.alpha = alpha
55 | self.num_classes = num_classes
56 |
57 | def forward(self, outputs, targets):
58 | x = self.dice_loss(outputs, targets)
59 | return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)
60 |
61 | def dice_loss(self, y_pred, y_true):
62 | """[function to compute dice loss]
63 | Args:
64 | y_true ([float32]): [ground truth image]
65 | y_pred ([float32]): [predicted image]
66 | Returns:
67 | [float32]: [loss value]
68 | """
69 | smooth = 1.
70 | y_true = torch.flatten(y_true)
71 | y_pred = torch.flatten(y_pred)
72 | intersection = torch.sum((y_true * y_pred))
73 | coeff = (2. * intersection + smooth) / (torch.sum(y_true) + torch.sum(y_pred) + smooth)
74 | return (1. - coeff)
75 |
76 |
77 | def focal_loss(predict, label, alpha=0.6, beta=2):
78 | probs = torch.sigmoid(predict)
79 |     # per-element binary cross-entropy
80 |     ce_loss = nn.BCELoss(reduction="none")
81 |     ce_loss = ce_loss(probs, label)
82 |     alpha_ = torch.ones_like(predict) * alpha
83 |     # positive labels get alpha, negative labels get 1 - alpha
84 |     alpha_ = torch.where(label > 0, alpha_, 1.0 - alpha_)
85 |     probs_ = torch.where(label > 0, probs, 1.0 - probs)
86 |     # loss weight matrix
87 |     loss_matrix = alpha_ * torch.pow((1.0 - probs_), beta)
88 |     # final loss: per-element weights times the per-element BCE, so less accurate predictions incur a larger loss
89 | loss = loss_matrix * ce_loss
90 | loss = torch.sum(loss)
91 | return loss
92 |
93 |
94 |
95 | class Loss:
96 | def __init__(self, dice_weight=0.0, class_weights=None, num_classes=1, device=None):
97 | self.device = device
98 | if class_weights is not None:
99 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to(
100 | self.device
101 | )
102 | else:
103 | nll_weight = None
104 |         self.nll_loss = nn.NLLLoss(weight=nll_weight)  # NLLLoss2d is deprecated; NLLLoss handles 2D targets
105 | self.dice_weight = dice_weight
106 | self.num_classes = num_classes
107 |
108 | def __call__(self, outputs, targets):
109 | loss = self.nll_loss(outputs, targets)
110 | if self.dice_weight:
111 | eps = 1e-7
112 | cls_weight = self.dice_weight / self.num_classes
113 | for cls in range(self.num_classes):
114 | dice_target = (targets == cls).float()
115 | dice_output = outputs[:, cls].exp()
116 | intersection = (dice_output * dice_target).sum()
117 | # union without intersection
118 | uwi = dice_output.sum() + dice_target.sum() + eps
119 | loss += (1 - intersection / uwi) * cls_weight
120 | loss /= (1 + self.dice_weight)
121 | return loss
122 |
123 |
124 | class LossMulti:
125 | def __init__(
126 | self, jaccard_weight=0.0, class_weights=None, num_classes=1, device=None
127 | ):
128 | self.device = device
129 | if class_weights is not None:
130 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to(
131 | self.device
132 | )
133 | else:
134 | nll_weight = None
135 |
136 | self.nll_loss = nn.NLLLoss(weight=nll_weight)
137 | self.jaccard_weight = jaccard_weight
138 | self.num_classes = num_classes
139 |
140 | def __call__(self, outputs, targets):
141 |
142 | targets = targets.squeeze(1)
143 |
144 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)
145 |
146 | if self.jaccard_weight:
147 |             eps = 1e-7  # originally 1e-7
148 | for cls in range(self.num_classes):
149 | jaccard_target = (targets == cls).float()
150 | jaccard_output = outputs[:, cls].exp()
151 | intersection = (jaccard_output * jaccard_target).sum()
152 |
153 | union = jaccard_output.sum() + jaccard_target.sum()
154 | loss -= (
155 | torch.log((intersection + eps) / (union - intersection + eps))
156 | * self.jaccard_weight
157 | )
158 | return loss
159 |
160 |
161 | class LossBsiNet:
162 | def __init__(self, weights=[1, 1, 1]):
163 | self.criterion1 = LossMulti(num_classes=2) #mask_loss
164 | self.criterion2 = LossMulti(num_classes=2) #contour_loss
165 | self.criterion3 = nn.MSELoss() ##distance_loss
166 | self.weights = weights
167 |
168 | def __call__(self, outputs1, outputs2, outputs3, targets1, targets2, targets3):
169 | #
170 | criterion = (
171 | self.weights[0] * self.criterion1(outputs1, targets1)
172 | + self.weights[1] * self.criterion2(outputs2, targets2)
173 | + self.weights[2] * self.criterion3(outputs3, targets3)
174 | )
175 |
176 | return criterion
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
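# A hedged usage sketch (mirroring the call in train.py): LossBsiNet compares the three
# BsiNet outputs against the mask, contour, and distance targets, respectively.
#
#     criterion = LossBsiNet(weights=[1, 1, 1])
#     loss = criterion(out_mask, out_contour, out_dist, gt_mask, gt_contour, gt_dist)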
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | """Model construction
2 | 1. We offer two versions of BsiNet: a concise one (BsiNet) and a more explicit one (BsiNet_2)
3 | 2. The explicit version is designed to be easier to understand and modify
4 | 3. You can use the attention mechanisms provided here to build a new multi-task model
5 | 4. You can also add your own modules or change where the attention mechanisms are placed to build a better model
6 | """
7 |
8 |
9 | from torch import nn
10 | import torch
11 | from torch.nn import functional as F
12 | from torch.nn.parameter import Parameter
13 |
14 |
15 | def conv3x3(in_, out):
16 | return nn.Conv2d(in_, out, 3, padding=1)
17 |
18 |
19 | class Conv3BN(nn.Module):
20 | def __init__(self, in_: int, out: int, bn=False):
21 | super().__init__()
22 | self.conv = conv3x3(in_, out)
23 | self.bn = nn.BatchNorm2d(out) if bn else None
24 | self.activation = nn.ReLU(inplace=True)
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | if self.bn is not None:
29 | x = self.bn(x)
30 | x = self.activation(x)
31 | return x
32 |
33 |
34 | class NetModule(nn.Module):
35 | def __init__(self, in_: int, out: int):
36 | super().__init__()
37 | self.l1 = Conv3BN(in_, out)
38 | self.l2 = Conv3BN(out, out)
39 |
40 | def forward(self, x):
41 | x = self.l1(x)
42 | x = self.l2(x)
43 | return x
44 |
45 |
46 | # SE (Squeeze-and-Excitation) attention
47 | class SELayer(nn.Module):
48 | def __init__(self, channel, reduction=16):
49 | super(SELayer, self).__init__()
50 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
51 | self.fc = nn.Sequential(
52 | nn.Linear(channel, channel // reduction, bias=False),
53 | nn.ReLU(inplace=True),
54 | nn.Linear(channel // reduction, channel, bias=False),
55 | nn.Sigmoid()
56 | )
57 |
58 | def forward(self, x):
59 | b, c, _, _ = x.size()
60 | y = self.avg_pool(x).view(b, c)
61 | y = self.fc(y).view(b, c, 1, 1)
62 | return x * y.expand_as(x)
63 |
64 |
65 |
66 | class SpatialGroupEnhance(nn.Module):
67 | def __init__(self, groups = 64):
68 | super(SpatialGroupEnhance, self).__init__()
69 | self.groups = groups
70 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
71 | self.weight = Parameter(torch.zeros(1, groups, 1, 1))
72 | self.bias = Parameter(torch.ones(1, groups, 1, 1))
73 | self.sig = nn.Sigmoid()
74 |
75 | def forward(self, x): # (b, c, h, w)
76 | b, c, h, w = x.size()
77 | x = x.view(b * self.groups, -1, h, w)
78 | xn = x * self.avg_pool(x)
79 | xn = xn.sum(dim=1, keepdim=True)
80 | t = xn.view(b * self.groups, -1)
81 | t = t - t.mean(dim=1, keepdim=True)
82 | std = t.std(dim=1, keepdim=True) + 1e-5
83 | t = t / std
84 | t = t.view(b, self.groups, h, w)
85 | t = t * self.weight + self.bias
86 | t = t.view(b * self.groups, 1, h, w)
87 | x = x * self.sig(t)
88 | x = x.view(b, c, h, w)
89 | return x
90 |
91 | ###### CBAM attention
92 | class ChannelAttention(nn.Module):
93 | def __init__(self, in_planes, ratio=16):
94 | super(ChannelAttention, self).__init__()
95 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
96 | self.max_pool = nn.AdaptiveMaxPool2d(1)
97 |
98 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
99 | self.relu1 = nn.ReLU()
100 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
101 |
102 | self.sigmoid = nn.Sigmoid()
103 |
104 | def forward(self, x):
105 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
106 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
107 | out = avg_out + max_out
108 | return self.sigmoid(out)
109 |
110 | class SpatialAttention(nn.Module):
111 | def __init__(self, kernel_size=7):
112 | super(SpatialAttention, self).__init__()
113 |
114 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
115 | padding = 3 if kernel_size == 7 else 1
116 |
117 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
118 | self.sigmoid = nn.Sigmoid()
119 |
120 | def forward(self, x):
121 | avg_out = torch.mean(x, dim=1, keepdim=True)
122 | max_out, _ = torch.max(x, dim=1, keepdim=True)
123 | x = torch.cat([avg_out, max_out], dim=1)
124 | x = self.conv1(x)
125 | return self.sigmoid(x)
126 |
127 |
128 |
129 | # scSE attention module
130 | class cSE(nn.Module): # noqa: N801
131 | """
132 | The channel-wise SE (Squeeze and Excitation) block from the
133 | `Squeeze-and-Excitation Networks`__ paper.
134 | Adapted from
135 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65939
136 | and
137 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178
138 | Shape:
139 | - Input: (batch, channels, height, width)
140 | - Output: (batch, channels, height, width) (same shape as input)
141 | __ https://arxiv.org/abs/1709.01507
142 | """
143 |
144 | def __init__(self, in_channels: int, r: int = 16):
145 | """
146 | Args:
147 | in_channels: The number of channels
148 | in the feature map of the input.
149 | r: The reduction ratio of the intermediate channels.
150 | Default: 16.
151 | """
152 | super().__init__()
153 | self.linear1 = nn.Linear(in_channels, in_channels // r)
154 | self.linear2 = nn.Linear(in_channels // r, in_channels)
155 |
156 | def forward(self, x: torch.Tensor):
157 | """Forward call."""
158 | input_x = x
159 |
160 | x = x.view(*(x.shape[:-2]), -1).mean(-1)
161 | x = F.relu(self.linear1(x), inplace=True)
162 | x = self.linear2(x)
163 | x = x.unsqueeze(-1).unsqueeze(-1)
164 | x = torch.sigmoid(x)
165 |
166 | x = torch.mul(input_x, x)
167 | return x
168 |
169 |
170 | class sSE(nn.Module): # noqa: N801
171 | """
172 | The sSE (Channel Squeeze and Spatial Excitation) block from the
173 | `Concurrent Spatial and Channel ‘Squeeze & Excitation’
174 | in Fully Convolutional Networks`__ paper.
175 | Adapted from
176 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178
177 | Shape:
178 | - Input: (batch, channels, height, width)
179 | - Output: (batch, channels, height, width) (same shape as input)
180 | __ https://arxiv.org/abs/1803.02579
181 | """
182 |
183 | def __init__(self, in_channels: int):
184 | """
185 | Args:
186 | in_channels: The number of channels
187 | in the feature map of the input.
188 | """
189 | super().__init__()
190 | self.conv = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1)
191 |
192 | def forward(self, x: torch.Tensor):
193 | """Forward call."""
194 | input_x = x
195 |
196 | x = self.conv(x)
197 | x = torch.sigmoid(x)
198 |
199 | x = torch.mul(input_x, x)
200 | return x
201 |
202 |
203 | class scSE(nn.Module): # noqa: N801
204 | """
205 | The scSE (Concurrent Spatial and Channel Squeeze and Channel Excitation)
206 | block from the `Concurrent Spatial and Channel ‘Squeeze & Excitation’
207 | in Fully Convolutional Networks`__ paper.
208 | Adapted from
209 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178
210 | Shape:
211 | - Input: (batch, channels, height, width)
212 | - Output: (batch, channels, height, width) (same shape as input)
213 | __ https://arxiv.org/abs/1803.02579
214 | """
215 |
216 | def __init__(self, in_channels: int, r: int = 16):
217 | """
218 | Args:
219 | in_channels: The number of channels
220 | in the feature map of the input.
221 | r: The reduction ratio of the intermediate channels.
222 | Default: 16.
223 | """
224 | super().__init__()
225 | self.cse_block = cSE(in_channels, r)
226 | self.sse_block = sSE(in_channels)
227 |
228 | def forward(self, x: torch.Tensor):
229 | """Forward call."""
230 | cse = self.cse_block(x)
231 | sse = self.sse_block(x)
232 | x = torch.add(cse, sse)
233 | return x
234 |
235 |
236 | ## This is a concise version of BsiNet whose modules are better encapsulated
237 |
238 | class BsiNet(nn.Module):
239 |
240 | output_downscaled = 1
241 | module = NetModule
242 |
243 | def __init__(
244 | self,
245 | input_channels: int = 3,
246 | filters_base: int = 32,
247 | down_filter_factors=(1, 2, 4, 8, 16),
248 | up_filter_factors=(1, 2, 4, 8, 16),
249 | bottom_s=4,
250 | num_classes=1,
251 | add_output=True,
252 | ):
253 | super().__init__()
254 | self.num_classes = num_classes
255 | assert len(down_filter_factors) == len(up_filter_factors)
256 | assert down_filter_factors[-1] == up_filter_factors[-1]
257 | down_filter_sizes = [filters_base * s for s in down_filter_factors]
258 | up_filter_sizes = [filters_base * s for s in up_filter_factors]
259 | self.down, self.up = nn.ModuleList(), nn.ModuleList()
260 | self.down.append(self.module(input_channels, down_filter_sizes[0]))
261 | for prev_i, nf in enumerate(down_filter_sizes[1:]):
262 | self.down.append(self.module(down_filter_sizes[prev_i], nf))
263 | for prev_i, nf in enumerate(up_filter_sizes[1:]):
264 | self.up.append(
265 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i])
266 | )
267 |
268 | pool = nn.MaxPool2d(2, 2)
269 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s)
270 | upsample = nn.Upsample(scale_factor=2)
271 | upsample_bottom = nn.Upsample(scale_factor=bottom_s)
272 | self.downsamplers = [None] + [pool] * (len(self.down) - 1)
273 | self.downsamplers[-1] = pool_bottom
274 | self.upsamplers = [upsample] * len(self.up)
275 | self.upsamplers[-1] = upsample_bottom
276 | self.add_output = add_output
277 | self.sge = SpatialGroupEnhance(32)
278 |
279 | if add_output:
280 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1)
281 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1)
282 | self.conv_final3 = nn.Conv2d(up_filter_sizes[0], 1, 1)
283 |
284 | def forward(self, x):
285 | xs = []
286 | for downsample, down in zip(self.downsamplers, self.down):
287 | x_in = x if downsample is None else downsample(xs[-1])
288 | x_out = down(x_in)
289 | xs.append(x_out)
290 |
291 | for x_skip, upsample, up in reversed(
292 | list(zip(xs[:-1], self.upsamplers, self.up))
293 | ):
294 |
295 | x_out2 = upsample(x_out)
296 |             x_out = torch.cat([x_out2, x_skip], 1)
297 | x_out = up(x_out)
298 |
299 | if self.add_output:
300 |
301 | x_out = self.sge(x_out)
302 |
303 | x_out1 = self.conv_final1(x_out)
304 | x_out2 = self.conv_final2(x_out)
305 | x_out3 = self.conv_final3(x_out)
306 | if self.num_classes > 1:
307 | x_out1 = F.log_softmax(x_out1,dim=1)
308 | x_out2 = F.log_softmax(x_out2,dim=1)
309 | x_out3 = torch.sigmoid(x_out3)
310 |
311 | return [x_out1, x_out2, x_out3]
312 |
313 |
314 |
315 |
316 | ## This is a more explicit version of BsiNet that shows the building process step by step
317 |
318 | class BsiNet_2(nn.Module):
319 | def __init__(
320 | self,
321 | input_channels: int = 3,
322 | filters_base: int = 32,
323 | num_classes=1,
324 | add_output=True,
325 | ):
326 | super().__init__()
327 | self.num_classes = num_classes
328 | self.add_output = add_output
329 | self.conv1 = NetModule(input_channels, 32)
330 | self.conv2 = NetModule(32, 64)
331 | self.conv3 = NetModule(64, 128)
332 | self.conv4 = NetModule(128, 256)
333 | self.conv5 = NetModule(256, 512)
334 |
335 | self.conv6 = NetModule(768, 256)
336 | self.conv7 = NetModule(384, 128)
337 | self.conv8 = NetModule(192, 64)
338 | self.conv9 = NetModule(96, 32)
339 |
340 | self.pool1 = nn.MaxPool2d(2, 2)
341 | self.pool2 = nn.MaxPool2d(4, 4)
342 | self.upsample1 = nn.Upsample(scale_factor=2)
343 | self.upsample2 = nn.Upsample(scale_factor=4)
344 | self.sge = SpatialGroupEnhance(32)
345 | if add_output:
346 | self.conv_final1 = nn.Conv2d(filters_base, num_classes, 1)
347 | self.conv_final2 = nn.Conv2d(filters_base, num_classes, 1)
348 | self.conv_final3 = nn.Conv2d(filters_base, 1, 1)
349 |
350 | def forward(self, x):
351 | x1 = self.conv1(x)
352 |
353 | x2 = self.conv2(x1)
354 | x2 = self.pool1(x2)
355 |
356 | x3 = self.conv3(x2)
357 | x3 = self.pool1(x3)
358 |
359 | x4 = self.conv4(x3)
360 | x4 = self.pool1(x4)
361 |
362 | x5 = self.conv5(x4)
363 | x5 = self.pool2(x5)
364 |
365 | x_6 = self.upsample2(x5)
366 | x6 = self.conv6(torch.cat([x_6, x4], 1))
367 | x6 = self.upsample1(x6)
368 |
369 | x7 = self.conv7(torch.cat([x6, x3], 1))
370 | x7 = self.upsample1(x7)
371 |
372 | x8 = self.conv8(torch.cat([x7, x2], 1))
373 | x8 = self.upsample1(x8)
374 |
375 | x9 = self.conv9(torch.cat([x8, x1], 1))
376 | x_out = self.sge(x9)
377 |
378 | if self.add_output:
379 |
380 | x_out1 = self.conv_final1(x_out)
381 | x_out2 = self.conv_final2(x_out)
382 | x_out3 = self.conv_final3(x_out)
383 | if self.num_classes > 1:
384 | x_out1 = F.log_softmax(x_out1, dim=1)
385 | x_out2 = F.log_softmax(x_out2, dim=1)
386 | x_out3 = torch.sigmoid(x_out3)
387 |
388 | return [x_out1, x_out2, x_out3]
389 |
390 |
391 |
392 |
393 |
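# A minimal shape sanity check, added for illustration only (not part of the original
# training/testing pipeline): feed a random tensor through BsiNet and print the shapes
# of the three task outputs (mask, contour, distance).
if __name__ == "__main__":
    net = BsiNet(num_classes=2)
    x = torch.randn(1, 3, 256, 256)
    mask_out, contour_out, dist_out = net(x)
    print(mask_out.shape, contour_out.shape, dist_out.shape)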
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | ## Example: a simple script to obtain the distance map and boundary map from a mask
2 | import numpy as np
3 | import os
4 | import cv2
5 | from osgeo import gdal
6 | import scipy.ndimage as sn
7 |
8 | def read_img(filename):
9 | dataset=gdal.Open(filename)
10 |
11 | im_width = dataset.RasterXSize
12 | im_height = dataset.RasterYSize
13 |
14 | im_geotrans = dataset.GetGeoTransform()
15 | im_proj = dataset.GetProjection()
16 | im_data = dataset.ReadAsArray(0,0,im_width,im_height)
17 |
18 | del dataset
19 | return im_proj, im_geotrans, im_width, im_height, im_data
20 |
21 |
22 | def write_img(filename, im_proj, im_geotrans, im_data):
23 | if 'int8' in im_data.dtype.name:
24 | datatype = gdal.GDT_Byte
25 | elif 'int16' in im_data.dtype.name:
26 | datatype = gdal.GDT_UInt16
27 | else:
28 | datatype = gdal.GDT_Float32
29 |
30 | if len(im_data.shape) == 3:
31 | im_bands, im_height, im_width = im_data.shape
32 | else:
33 | im_bands, (im_height, im_width) = 1,im_data.shape
34 |
35 | driver = gdal.GetDriverByName("GTiff")
36 | dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
37 |
38 | dataset.SetGeoTransform(im_geotrans)
39 | dataset.SetProjection(im_proj)
40 |
41 | if im_bands == 1:
42 | dataset.GetRasterBand(1).WriteArray(im_data)
43 | else:
44 | for i in range(im_bands):
45 | dataset.GetRasterBand(i+1).WriteArray(im_data[i])
46 |
47 | del dataset
48 |
49 |
50 |
51 | maskRoot = r"C:\Users\hnzzy\Desktop\mask"
52 | distRoot = r"C:\Users\hnzzy\Desktop\dist"
53 | boundaryRoot = r"C:\Users\hnzzy\Desktop\boundary"
54 |
55 | for imgPath in os.listdir(maskRoot):
56 | input_path = os.path.join(maskRoot, imgPath)
57 | boundaryOutPath = os.path.join(boundaryRoot, imgPath)
58 | distOutPath = os.path.join(distRoot, imgPath)
59 | im_proj, im_geotrans, im_width, im_height, im_data = read_img(input_path)
60 | result = cv2.distanceTransform(src=im_data, distanceType=cv2.DIST_L2, maskSize=3)
61 | min_value = np.min(result)
62 | max_value = np.max(result)
63 | scaled_image = ((result - min_value) / (max_value - min_value)) * 255
64 | result = scaled_image.astype(np.uint8)
65 | # result = result.astype(np.uint8)
66 | write_img(distOutPath, im_proj, im_geotrans, result)
67 |     ## distance map (you can also use the bwdist function in Matlab to obtain a distance map)
68 |     ### boundary map (you can also use the bwperim function in Matlab to obtain a boundary map)
69 | boundary = cv2.Canny(im_data, 100, 200)
70 | ## dilation
71 | # kernel = np.ones((3, 3), np.uint8)
72 | # boundary = cv2.dilate(boundary, kernel, iterations=1)
73 | write_img(boundaryOutPath, im_proj, im_geotrans, boundary)
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from torch.utils.data import DataLoader
4 | from dataset import DatasetImageMaskContourDist
5 | from models import BsiNet
6 | from tqdm import tqdm
7 | import numpy as np
8 | import cv2
9 | from utils import create_validation_arg_parser
10 | from torch import nn
11 |
12 | def build_model(model_type):
13 |
14 | if model_type == "bsinet":
15 | model = BsiNet(num_classes=2)
16 |
17 | return model
18 |
19 |
20 | if __name__ == "__main__":
21 |
22 | args = create_validation_arg_parser().parse_args()
23 |     # NOTE: the hard-coded assignments below override the command-line arguments shown in the README
24 |     args.model_file = './bsi/150.pt'
25 | args.save_path = './save'
26 | args.model_type = 'bsinet'
27 | args.distance_type = 'dist_contour'
28 | args.test_path = './test'
29 |
30 |
31 | test_path = args.test_path + '/' + 'image'
32 | model_file = args.model_file
33 | save_path = args.save_path
34 | model_type = args.model_type
35 |
36 | cuda_no = args.cuda_no
37 | CUDA_SELECT = "cuda:{}".format(cuda_no)
38 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")
39 |
40 | img_name = []
41 | for img_file in os.listdir(test_path):
42 | img_name.append(img_file[:-4])
43 | valLoader = DataLoader(DatasetImageMaskContourDist(test_path, img_name,args.distance_type))
44 |
45 | if not os.path.exists(save_path):
46 | os.mkdir(save_path)
47 |
48 | model = build_model(model_type)
49 | model = nn.DataParallel(model)
50 | model = model.to(device)
51 | model.load_state_dict(torch.load(model_file))
52 | model.eval()
53 |
54 | for i, (img_file_name, inputs, targets1, targets2, targets3) in enumerate(
55 | tqdm(valLoader)
56 | ):
57 |
58 | inputs = inputs.to(device)
59 | outputs1, outputs2, outputs3 = model(inputs)
60 |
61 | ## TTA
62 | # outputs4, outputs5, outputs6 = model(torch.flip(inputs, [-1]))
63 | # predict_2 = torch.flip(outputs4, [-1])
64 | # outputs7, outputs8, outputs9 = model(torch.flip(inputs, [-2]))
65 | # predict_3 = torch.flip(outputs7, [-2])
66 | # outputs10, outputs11, outputs12 = model(torch.flip(inputs, [-1, -2]))
67 | # predict_4 = torch.flip(outputs10, [-1, -2])
68 | # predict_list = outputs1 + predict_2 + predict_3 + predict_4
69 | # pred1 = predict_list/4.0
70 |
71 | outputs1 = outputs1.detach().cpu().numpy().squeeze()
72 |
73 |         res = np.zeros((256, 256))  # assumes 256 x 256 image tiles
74 | indices = np.argmax(outputs1, axis=0)
75 | res[indices == 1] = 255
76 | res[indices == 0] = 0
77 |         res = np.array(res, dtype='uint8')  # convert to 8-bit
78 | output_path = os.path.join(
79 | save_path, img_file_name[0]+'.tif'
80 | )
81 | cv2.imwrite(output_path, res)
82 |
83 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | import random
5 | import torch
6 | from dataset import DatasetImageMaskContourDist
7 | from losses import LossBsiNet
8 | from models import BsiNet
9 | from tensorboardX import SummaryWriter
10 | from torch import nn
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | from utils import visualize, create_train_arg_parser,evaluate
14 | # from torchsummary import summary
15 | from sklearn.model_selection import train_test_split
16 |
17 | def define_loss(loss_type, weights=[1, 1, 1]):
18 |
19 | if loss_type == "bsinet":
20 | criterion = LossBsiNet(weights)
21 |
22 | return criterion
23 |
24 |
25 | def build_model(model_type):
26 |
27 | if model_type == "bsinet":
28 | model = BsiNet(num_classes=2)
29 |
30 | return model
31 |
32 |
33 | def train_model(model, inputs, targets, model_type, criterion, optimizer):
34 |
35 | if model_type == "bsinet":
36 |
37 | optimizer.zero_grad()
38 |
39 | with torch.set_grad_enabled(True):
40 | outputs = model(inputs)
41 | loss = criterion(
42 | outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2]
43 | )
44 | loss.backward()
45 | optimizer.step()
46 |
47 | return loss
48 |
49 |
50 | if __name__ == "__main__":
51 |
52 | args = create_train_arg_parser().parse_args()
53 |
54 | args.distance_type = 'dist_contour'
55 | # args.pretrained_model_path = './best_merge_model_article/85.pt'
56 |
57 | args.train_path = './train/image/'
58 | # args.val_path = './XJ_goole/test/image/'
59 | args.model_type = 'bsinet'
60 | args.save_path = './model'
61 |
62 | CUDA_SELECT = "cuda:{}".format(args.cuda_no)
63 | log_path = args.save_path + "/summary"
64 | writer = SummaryWriter(log_dir=log_path)
65 |
66 | logging.basicConfig(
67 | filename="".format(args.object_type),
68 | filemode="a",
69 | format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
70 | datefmt="%Y-%m-%d %H:%M",
71 | level=logging.INFO,
72 | )
73 | logging.info("")
74 |
75 | # train_file_names = glob.glob(os.path.join(args.train_path, "*.tif"))
76 | # random.shuffle(train_file_names)
77 | # val_file_names = glob.glob(os.path.join(args.val_path, "*.tif"))
78 |
79 | train_file_names = glob.glob(os.path.join(args.train_path, "*.tif"))
80 | random.shuffle(train_file_names)
81 |
82 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_file_names]
83 | train_file, val_file = train_test_split(img_ids, test_size=0.2, random_state=41)
84 |
85 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")
86 | print(device)
87 | model = build_model(args.model_type)
88 |
89 |     if torch.cuda.device_count() > 0:  # originally 0
90 | print("Let's use", torch.cuda.device_count(), "GPUs!")
91 | model = nn.DataParallel(model)
92 |
93 | model = model.to(device)
94 | # summary(model, input_size=(3, 256, 256))
95 |
96 | epoch_start = "0"
97 | if args.use_pretrained:
98 | print("Loading Model {}".format(os.path.basename(args.pretrained_model_path)))
99 |         model.load_state_dict(torch.load(args.pretrained_model_path))  # a strict=False flag was previously added here
100 | epoch_start = os.path.basename(args.pretrained_model_path).split(".")[0]
101 | print(epoch_start)
102 | print('train',args.use_pretrained)
103 | trainLoader = DataLoader(
104 | DatasetImageMaskContourDist(args.train_path,train_file, args.distance_type),
105 | batch_size=args.batch_size,drop_last=False, shuffle=True
106 | )
107 | devLoader = DataLoader(
108 | DatasetImageMaskContourDist(args.train_path,val_file, args.distance_type),drop_last=True,
109 | )
110 | displayLoader = DataLoader(
111 | DatasetImageMaskContourDist(args.train_path,val_file, args.distance_type),
112 | batch_size=args.val_batch_size,drop_last=True, shuffle=True
113 | )
114 |
115 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
116 | # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
117 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(1e10), eta_min=1e-5)
118 |     # scheduler = optim.lr_scheduler.StepLR(optimizer, 50, 0.1)  # newly added
119 | criterion = define_loss(args.model_type)
120 |
121 |
122 | for epoch in tqdm(
123 | range(int(epoch_start) + 1, int(epoch_start) + 1 + args.num_epochs)
124 | ):
125 |
126 | global_step = epoch * len(trainLoader)
127 | running_loss = 0.0
128 |
129 | for i, (img_file_name, inputs, targets1, targets2,targets3) in enumerate(
130 | tqdm(trainLoader)
131 | ):
132 |
133 | model.train()
134 |
135 | inputs = inputs.to(device)
136 | targets1 = targets1.to(device)
137 | targets2 = targets2.to(device)
138 | targets3 = targets3.to(device)
139 |
140 | targets = [targets1, targets2,targets3]
141 |
142 |
143 |             loss = train_model(model, inputs, targets, args.model_type, criterion, optimizer)
144 |
145 | writer.add_scalar("loss", loss.item(), epoch)
146 |
147 | running_loss += loss.item() * inputs.size(0)
148 | scheduler.step()
149 |
150 | epoch_loss = running_loss / len(train_file_names)
151 | print(epoch_loss)
152 |
153 | if epoch % 1 == 0:
154 |
155 | dev_loss, dev_time = evaluate(device, epoch, model, devLoader, writer)
156 | writer.add_scalar("loss_valid", dev_loss, epoch)
157 | visualize(device, epoch, model, displayLoader, writer, args.val_batch_size)
158 | print("Global Loss:{} Val Loss:{}".format(epoch_loss, dev_loss))
159 | else:
160 | print("Global Loss:{} ".format(epoch_loss))
161 |
162 | logging.info("epoch:{} train_loss:{} ".format(epoch, epoch_loss))
163 | if epoch % 5 == 0:
164 | torch.save(
165 | model.state_dict(), os.path.join(args.save_path, str(epoch) + ".pt")
166 | )
167 |
168 |
169 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torchvision
5 | from torch.nn import functional as F
6 | import time
7 | import argparse
8 |
9 |
10 | def evaluate(device, epoch, model, data_loader, writer):
11 | model.eval()
12 | losses = []
13 | start = time.perf_counter()
14 | with torch.no_grad():
15 |
16 | for iter, data in enumerate(tqdm(data_loader)):
17 |
18 | _, inputs, targets, _,_ = data
19 | inputs = inputs.to(device)
20 | targets = targets.to(device)
21 | outputs = model(inputs)
22 | loss = F.nll_loss(outputs[0], targets.squeeze(1))
23 | losses.append(loss.item())
24 |
25 | writer.add_scalar("Dev_Loss", np.mean(losses), epoch)
26 |
27 | return np.mean(losses), time.perf_counter() - start
28 |
29 |
30 | def visualize(device, epoch, model, data_loader, writer, val_batch_size, train=True):
31 | def save_image(image, tag, val_batch_size):
32 | image -= image.min()
33 | image /= image.max()
34 | grid = torchvision.utils.make_grid(
35 | image, nrow=int(np.sqrt(val_batch_size)), pad_value=0, padding=25
36 | )
37 | writer.add_image(tag, grid, epoch)
38 |
39 | model.eval()
40 | with torch.no_grad():
41 | for iter, data in enumerate(tqdm(data_loader)):
42 | _, inputs, targets, _,_ = data
43 |
44 | inputs = inputs.to(device)
45 |
46 | targets = targets.to(device)
47 | outputs = model(inputs)
48 |
49 | output_mask = outputs[0].detach().cpu().numpy()
50 | output_final = np.argmax(output_mask, axis=1).astype(float)
51 | output_final = torch.from_numpy(output_final).unsqueeze(1)
52 |
53 |             if train:
54 | save_image(targets.float(), "Target_train",val_batch_size)
55 | save_image(output_final, "Prediction_train",val_batch_size)
56 | else:
57 | save_image(targets.float(), "Target", val_batch_size)
58 | save_image(output_final, "Prediction", val_batch_size)
59 |
60 | break
61 |
62 |
63 | def create_train_arg_parser():
64 |
65 | parser = argparse.ArgumentParser(description="train setup for segmentation")
66 | parser.add_argument("--train_path", type=str, help="path to img tif files")
67 | parser.add_argument("--val_path", type=str, help="path to img tif files")
68 | parser.add_argument(
69 | "--model_type",
70 | type=str,
71 | help="select model type: bsinet",
72 | )
73 | parser.add_argument("--object_type", type=str, help="Dataset.")
74 | parser.add_argument(
75 | "--distance_type",
76 | type=str,
77 | default="dist_contour",
78 | help="select distance transform type - dist_mask,dist_contour,dist_contour_tif",
79 | )
80 | parser.add_argument("--batch_size", type=int, default=4, help="train batch size")
81 | parser.add_argument(
82 | "--val_batch_size", type=int, default=4, help="validation batch size"
83 | )
84 | parser.add_argument("--num_epochs", type=int, default=150, help="number of epochs")
85 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number")
86 | parser.add_argument(
87 | "--use_pretrained", type=bool, default=False, help="Load pretrained checkpoint."
88 | )
89 | parser.add_argument(
90 | "--pretrained_model_path",
91 | type=str,
92 | default=None,
93 | help="If use_pretrained is true, provide checkpoint.",
94 | )
95 | parser.add_argument("--save_path", type=str, help="Model save path.")
96 |
97 | return parser
98 |
99 |
100 | def create_validation_arg_parser():
101 |
102 |     parser = argparse.ArgumentParser(description="validation/test setup for segmentation")
103 | parser.add_argument(
104 | "--model_type",
105 | type=str,
106 | help="select model type: bsinet",
107 | )
108 | parser.add_argument("--test_path", type=str, help="path to img tif files")
109 | parser.add_argument("--model_file", type=str, help="model_file")
110 | parser.add_argument("--save_path", type=str, help="results save path.")
111 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number")
112 |
113 | return parser
114 |
115 |
116 |
--------------------------------------------------------------------------------