├── Data_Enhance.py
├── README.md
├── data_remove.py
├── dataset.py
├── gdalUtil.py
├── imgs
│   ├── FHAPD.jpg
│   └── HBGNet.jpg
├── lib
│   └── pvtv2.py
├── losses.py
├── models.py
├── test.py
├── train.py
└── utils.py
/Data_Enhance.py:
--------------------------------------------------------------------------------
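"""
Offline data augmentation (sketch of this script's behavior): for every .tif image
in img_path, save horizontally/vertically flipped, 90-degree rotated, brightness-
enhanced and contrast-enhanced copies alongside the original file.
"""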
1 | from PIL import Image
2 | import os
3 | import glob
4 | from PIL import ImageEnhance
5 | import sys
6 |
7 | img_path = r'D:\LJ2\SBA2\SD_train\926_new_train\2'  # folder containing both the input and output images
8 |
9 | def get_image_paths(folder):
10 | return glob.glob(os.path.join(folder, '*.tif'))
11 |
12 |
13 | def create_read_img(filename):
14 | # read the image
15 | im = Image.open(filename)
16 |
17 | out_h = im.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
18 | out_w = im.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
19 | out_90 = im.transpose(Image.ROTATE_90)  # rotate 90 degrees (counterclockwise in PIL)
20 | # out_180 = im.transpose(Image.ROTATE_180)
21 | # out_270 = im.transpose(Image.ROTATE_270)
22 |
23 | # brightness enhancement
24 | enh_bri = ImageEnhance.Brightness(im)
25 | brightness = 1.5
26 | image_brightened = enh_bri.enhance(brightness)
27 | image_brightened.save(filename[:-4] + '_brighter.tif')
28 | #
29 | # # color (chroma) enhancement
30 | # enh_col = ImageEnhance.Color(im)
31 | # color = 1.5
32 | # image_colored = enh_col.enhance(color)
33 | # image_colored.save(filename[:-4] + '_color.tif')
34 | #
35 | # contrast enhancement
36 | enh_con = ImageEnhance.Contrast(im)
37 | contrast = 1.5
38 | image_contrasted = enh_con.enhance(contrast)
39 | image_contrasted.save(filename[:-4] + '_contrast.tif')
40 |
41 | # sharpness enhancement
42 | # enh_sha = ImageEnhance.Sharpness(im)
43 | # sharpness = 3.0
44 | # image_sharped = enh_sha.enhance(sharpness)
45 | # image_sharped.save(filename[:-4] + '_sharp.tif')
46 |
47 | #
48 | out_h.save(filename[:-4] + '_h.tif')
49 | out_w.save(filename[:-4] + '_w.tif')
50 | out_90.save(filename[:-4] + '_90.tif')
51 | # out_180.save(filename[:-4] + '_180.tif')
52 | # out_270.save(filename[:-4] + '_270.tif')
53 |
54 | #print(filename)
55 | imgs = get_image_paths(img_path)
56 | for i in imgs:
57 | create_read_img(i)
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HBGNet
2 |
3 | A large-scale VHR parcel dataset and a novel hierarchical semantic boundary-guided network for agricultural parcel delineation (https://www.sciencedirect.com/science/article/pii/S0924271625000395)
4 |
5 | [Project](https://github.com/NanNanmei/HBGNet)
6 |
7 | ## Introduction
8 |
9 | We develop a hierarchical semantic boundary-guided network (HBGNet) to fully leverage boundary semantics and thereby improve agricultural parcel (AP) delineation. It integrates two branches: a core branch for AP feature extraction and an auxiliary branch for boundary feature mining. Specifically, the boundary extraction branch employs a module based on the Laplace convolution operator to enhance the model's awareness of parcel boundaries.
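A minimal sketch of the idea behind the Laplace-based boundary enhancement (the full module is `LaplaceConv2d` in `models.py`; the channel sizes here are placeholders):

```python
import torch
import torch.nn as nn

# 8-neighborhood Laplace kernel used to initialize an edge-aware convolution
laplace = torch.tensor([[1., 1., 1.],
                        [1., -8., 1.],
                        [1., 1., 1.]])
conv = nn.Conv2d(3, 32, kernel_size=3, padding=1)
conv.weight = nn.Parameter(laplace.repeat(32, 3, 1, 1))  # one kernel copy per (out, in) channel pair
boundary_features = conv(torch.randn(1, 3, 256, 256))    # responds strongly at parcel boundaries
```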
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Using the code:
20 |
21 | The code is stable with Python 3.9.0 and CUDA >= 11.4.
22 |
23 | - Clone this repository:
24 | ```bash
25 | git clone https://github.com/NanNanmei/HBGNet.git
26 | cd HBGNet
27 | ```
28 |
29 | Install the main dependencies with conda or pip:
30 |
31 | ```
32 | PyTorch
33 | OpenCV
34 | numpy
35 | tqdm
36 | timm
37 | ...
38 | ```
39 |
40 | ## Preprocessing
41 | You can use https://github.com/long123524/BsiNet-torch/blob/main/preprocess.py to obtain the contour (boundary) and distance maps.
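If you want to generate them yourself, the following is a minimal sketch (assuming binary 0/255 masks; the `preprocess.py` script above is the reference implementation, and this simple version does not preserve the GeoTIFF metadata):

```python
import cv2
import numpy as np
from scipy import ndimage

mask = cv2.imread("inputs/train_mask/001.tif", 0)                  # 0/255 parcel mask
contours, _ = cv2.findContours((mask > 0).astype(np.uint8),
                               cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
boundary = np.zeros_like(mask)
cv2.drawContours(boundary, contours, -1, 255, 2)                   # contour (boundary) map
dist = ndimage.distance_transform_edt(mask > 0)                    # Euclidean distance transform
dist = (255 * dist / (dist.max() + 1e-8)).astype(np.uint8)         # normalized distance map
cv2.imwrite("inputs/train_boundary/001.tif", boundary)
cv2.imwrite("inputs/train_dist/001.tif", dist)
```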
42 |
43 | ## Data Format
44 |
45 | Make sure to put the files as the following structure:
46 |
47 | ```
48 | inputs
49 | ├── train_image
50 | │   ├── 001.tif
51 | │   ├── 002.tif
52 | │   ├── 003.tif
53 | │   └── ...
54 | ├── train_mask
55 | │   ├── 001.tif
56 | │   ├── 002.tif
57 | │   ├── 003.tif
58 | │   └── ...
59 | ├── train_boundary
60 | │   ├── 001.tif
61 | │   ├── 002.tif
62 | │   ├── 003.tif
63 | │   └── ...
64 | └── train_dist
65 |     ├── 001.tif
66 |     ├── 002.tif
67 |     ├── 003.tif
68 |     └── ...
69 |
70 |
71 | ```
72 |
73 | Test datasets follow the same structure as above.
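A minimal sketch of reading this layout with `DatasetImageMaskContourDist` from `dataset.py` (file IDs are assumed to be the tile names without the `.tif` extension; the exact arguments used in `train.py` may differ):

```python
import glob, os
from torch.utils.data import DataLoader
from dataset import DatasetImageMaskContourDist

train_dir = "inputs/train_image"
train_ids = [os.path.splitext(os.path.basename(p))[0]
             for p in glob.glob(os.path.join(train_dir, "*.tif"))]
train_loader = DataLoader(DatasetImageMaskContourDist(train_dir, train_ids),
                          batch_size=4, shuffle=True)
for name, image, mask, contour, dist in train_loader:
    pass  # feed image to the model, supervise with (mask, contour, dist)
```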
74 |
75 | ## Pretrained weight
76 |
77 | The weight of PVT-V2 pretrained on ImageNet dataset can be downloaded from: https://drive.google.com/file/d/1uzeVfA4gEQ772vzLntnkqvWePSw84F6y/view?usp=sharing
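A minimal sketch of loading the downloaded checkpoint into the backbone (the file name `pvt_v2_b2.pth` and the exact wiring inside `train.py`/`models.py` are assumptions):

```python
import torch
from lib.pvtv2 import pvt_v2_b2

backbone = pvt_v2_b2()
state = torch.load("pvt_v2_b2.pth", map_location="cpu")
# keep only the keys that exist in this backbone (e.g. drop the ImageNet classification head)
state = {k: v for k, v in state.items() if k in backbone.state_dict()}
backbone.load_state_dict(state, strict=False)
```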
78 |
79 | ### A large-scale VHR parcel dataset
80 | Our vision for FHAPD is a continuously updated dataset covering different agricultural landscapes in China. So far, I have carried out AP delineation in several regions of China using HBGNet, such as Shanghai, Inner Mongolia, Guangdong, Shandong and Gansu, and will update the dataset immediately after the official publication of this article.
81 | I have also built an AP dataset for Africa and have tried to add it to FHAPD as a distinctive region.
82 |
83 | Link: https://pan.baidu.com/s/1OS7G0H27zGexxRfqTcKizw?pwd=8die (extraction code: 8die)
84 |
85 | If you have any problems, please email zhaohang201@mails.ucas.ac.cn.
86 |
87 | ## Training and testing
88 |
89 | 1. Train the model.
90 | ```
91 | python train.py
92 | ```
93 | 2. Test the model.
94 | ```
95 | python test.py
96 | ```
97 | ### Citation:
98 | If you find this work useful or interesting, please consider citing the following references.
99 | ```
100 | Citation 1:
101 | @article{zhao2025,
102 | title={A large-scale VHR parcel dataset and a novel hierarchical semantic boundary-guided network for agricultural parcel delineation},
103 | author={Zhao, Hang and Wu, Bingfang and Zhang, Miao and Long, Jiang and Tian, Fuyou and Xie, Yan and Zeng, Hongwei and Zheng, Zhaoju and Ma, Zonghan and Wang, Mingxing and others},
104 | journal={ISPRS Journal of Photogrammetry and Remote Sensing},
105 | volume={221},
106 | pages={1--19},
107 | year={2025},
108 | publisher={Elsevier}
109 | }
110 | Citation 2:
111 | @article{zhao2024,
112 | title={Irregular agricultural field delineation using a dual-branch architecture from high-resolution remote sensing images},
113 | author={Zhao, Hang and Long, Jiang and Zhang, Miao and Wu, Bingfang and Xu, Chenxi and Tian, Fuyou and Ma, Zonghan},
114 | journal={IEEE Geoscience and Remote Sensing Letters},
115 | volume={21},
116 | pages={1--5},
117 | year={2024},
118 | publisher={IEEE}
119 | }
120 | Citation 3:
121 | @article{long2024,
122 | title={Integrating Segment Anything Model derived boundary prior and high-level semantics for cropland extraction from high-resolution remote sensing images},
123 | author={Long, Jiang and Zhao, Hang and Li, Mengmeng and Wang, Xiaoqin and Lu, Chengwen},
124 | journal={IEEE Geoscience and Remote Sensing Letters},
125 | volume={21},
126 | pages={1--5},
127 | year={2024},
128 | publisher={IEEE}
129 | }
130 | ```
131 |
132 | ### Acknowledgement
133 | We are very grateful for these excellent works [BsiNet](https://github.com/long123524/BsiNet-torch), [SEANet](https://github.com/long123524/SEANet_torch), and [HGINet](https://github.com/long123524/HGINet-torch), which have provided the basis for our framework.
134 |
--------------------------------------------------------------------------------
/data_remove.py:
--------------------------------------------------------------------------------
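"""
Sketch of this script's behavior: remove label tiles whose foreground (value 255)
ratio is below the threshold s_g, then delete the images that no longer have a
matching label file.
"""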
1 | import cv2
2 | import os
3 | import numpy as np
4 | import sys
5 | import glob
6 |
7 | imaPath = r'E:\Change_detection\HRSCD\images_2012\2012\D14'  # folder containing the images to process (D35)
8 | labelPath = r'E:\Change_detection\HRSCD\labels_land_cover_2006\2006\caijian'  # folder containing the label images to process (D35)
9 | s_g = 0.3  # ratio factor in 0-1; tiles whose foreground ratio is below it are deleted
10 |
11 | # image1 = []
12 | # imageList1 = glob.glob(os.path.join(imaPath, '*.tif'))
13 | # for item in imageList1:
14 | # image1.append(os.path.basename(item))
15 | # image2 = []
16 | # imageList2 = glob.glob(os.path.join(labelPath, '*.tif'))
17 |
18 | #
19 | # for item in imageList2:
20 | # image2.append(os.path.basename(item))
21 |
22 | lablist= os.listdir(labelPath)
23 | # imaList = os.listdir(imaPath)
24 | for labels in lablist:
25 |
26 | label_path = os.path.join(labelPath,labels)
27 |
28 | img = cv2.imread(label_path, 0)
29 | s=img.size
30 |
31 | # first binarize the image to remove noise
32 | ret, img = cv2.threshold(img, 0.5, 255, cv2.THRESH_BINARY)
33 | area = 0
34 | height, width = img.shape
35 | yuzhi = s_g  # threshold ratio
36 | yuzhi = s*yuzhi
37 | yuzhi = np.array(yuzhi, dtype='uint64')  # convert to an unsigned 64-bit (8-byte) integer
38 | for i in range(height):
39 | for j in range(width):
40 | if img[i, j] == 255:
41 | area += 1
42 | if area <= yuzhi:
43 | os.remove(label_path)
44 | else:
45 | print('')
46 |
47 | image1 = []
48 | imageList1 = glob.glob(os.path.join(labelPath, '*.tif'))
49 | for item in imageList1:
50 | image1.append(os.path.basename(item))
51 |
52 | image2 = []
53 | imageList2 = glob.glob(os.path.join(imaPath, '*.tif'))
54 | for item in imageList2:
55 | image2.append(os.path.basename(item))
56 |
57 | a = list(set(image2).difference(set(image1)))
58 | print(a)
59 | for x in a:
60 | os.remove(os.path.join(imaPath, x))
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | The role of this file completes the data reading
3 | "dist_mask" is obtained by using Euclidean distance transformation on the mask
4 | "dist_contour" is obtained by using quasi-Euclidean distance transformation on the mask
5 | """
6 |
7 | import torch
8 | import numpy as np
9 | import cv2
10 | from PIL import Image, ImageFile
11 |
12 | from skimage import io
13 | import imageio
14 | from torch.utils.data import Dataset
15 | from torchvision import transforms
16 | from scipy import io
17 | import os
18 | from osgeo import gdal
19 | import tifffile as tiff
20 | ### Reading and saving of remote sensing images (Keep coordinate information)
21 | def readTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0):
22 | dataset = gdal.Open(fileName)
23 | if dataset is None:
24 | print(fileName + " could not be opened")
25 | # number of columns of the raster
26 | width = dataset.RasterXSize
27 | # number of rows of the raster
28 | height = dataset.RasterYSize
29 | # number of bands
30 | bands = dataset.RasterCount
31 | # read the data
32 | if(data_width == 0 and data_height == 0):
33 | data_width = width
34 | data_height = height
35 | data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
36 | # get the affine transform (geotransform) parameters
37 | geotrans = dataset.GetGeoTransform()
38 | # get the projection information
39 | proj = dataset.GetProjection()
40 | return width, height, bands, data, geotrans, proj
41 |
42 |
43 | # save a remote sensing image (keeping the geo-referencing)
44 | def writeTiff(im_data, im_geotrans, im_proj, path):
45 | if 'int8' in im_data.dtype.name:
46 | datatype = gdal.GDT_Byte
47 | elif 'int16' in im_data.dtype.name:
48 | datatype = gdal.GDT_UInt16
49 | else:
50 | datatype = gdal.GDT_Float32
51 | if len(im_data.shape) == 3:
52 | im_bands, im_height, im_width = im_data.shape
53 | else:
54 | im_bands, (im_height, im_width) = 1, im_data.shape
55 | # create the output file
56 | driver = gdal.GetDriverByName("GTiff")
57 | dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
58 | if dataset is not None:
59 | dataset.SetGeoTransform(im_geotrans)  # write the affine transform parameters
60 | dataset.SetProjection(im_proj)  # write the projection
61 | if im_bands == 1:
62 | dataset.GetRasterBand(1).WriteArray(im_data)
63 | else:
64 | for i in range(im_bands):
65 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
66 | del dataset
67 |
68 |
69 | #######
70 | # class DatasetImageMaskContourDist(Dataset):
71 | #
72 | # def __init__(self, file_names):
73 | #
74 | # self.file_names = file_names
75 | # # self.distance_type = distance_type
76 | # # self.dir = dir
77 | #
78 | # def __len__(self):
79 | #
80 | # return len(self.file_names)
81 | #
82 | # def __getitem__(self, idx):
83 | #
84 | # img_file_name = self.file_names[idx]
85 | # image = load_image(img_file_name)
86 | # mask = load_mask(img_file_name)
87 | # contour = load_contour(img_file_name)
88 | # # dist = load_distance(os.path.join(self.dir,img_file_name+'.tif'), self.distance_type)
89 | #
90 | # return img_file_name, image, mask, contour
91 |
92 | ### Use this dataset class for training
93 | class DatasetImageMaskContourDist(Dataset):
94 |
95 | def __init__(self, dir, file_names):
96 |
97 | self.file_names = file_names
98 | # self.distance_type = distance_type
99 | self.dir = dir
100 |
101 | def __len__(self):
102 |
103 | return len(self.file_names)
104 |
105 | def __getitem__(self, idx):
106 |
107 | img_file_name = self.file_names[idx]
108 | image = load_image(os.path.join(self.dir,img_file_name+'.tif'))
109 | mask = load_mask(os.path.join(self.dir,img_file_name+'.tif'))
110 | contour = load_contour(os.path.join(self.dir,img_file_name+'.tif'))
111 | # dist = load_distance(os.path.join(self.dir,img_file_name+'.tif'), self.distance_type)
112 | dist = load_distance(os.path.join(self.dir, img_file_name+'.tif'))
113 |
114 | return img_file_name, image, mask, contour, dist
115 |
116 |
117 |
118 |
119 | def load_image(path):
120 |
121 | img = Image.open(path)
122 | data_transforms = transforms.Compose(
123 | [
124 | # transforms.Resize(256),
125 | transforms.ToTensor(),
126 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
127 |
128 | ]
129 | )
130 | img = data_transforms(img)
131 |
132 | return img
133 | #
134 | #
135 | # def load_mask(path):
136 | # # mask = cv2.imread(path.replace("image", "mask").replace("tif", "tif"), 0)
137 | # mask = cv2.imread(path.replace("train_image", "train_mask").replace("tif", "tif"), 0)
138 | # # im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(path.replace("image", "mask").replace("tif", "tif"))
139 | # mask = mask/255.
140 | # # mask[mask == 255] = 1
141 | # # mask[mask == 0] = 0
142 | #
143 | # return torch.from_numpy(np.expand_dims(mask, 0)).float()
144 |
145 | # def load_image(path):
146 | # img = tiff.imread(path).transpose([2, 0, 1]) # array([3, 512, 512])
147 | # img = img.astype('uint8')
148 | # img = (img - img.min())/(img.max()-img.min())
149 | #
150 | # return torch.from_numpy(img).float()
151 |
152 |
153 |
154 |
155 | def load_mask(path):
156 | # mask = cv2.imread(path.replace("train_images", "train_labels").replace("tif", "tif"), 0)
157 | # im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(path.replace("train_image", "train_mask").replace("image", "label"))
158 | im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(
159 | path.replace("train_image", "train_mask"))
160 | mask = mask.astype('uint8')
161 | mask = mask/255.
162 | # mask = np.reshape(mask, mask.shape + (1,))
163 | ###mask = mask/225.
164 | # mask[mask == 1] = 1
165 | # mask[mask == 0] = 0
166 | # print(mask.shape())
167 | # return torch.from_numpy(mask).long()
168 |
169 | return torch.from_numpy(np.expand_dims(mask, 0)).float()
170 |
171 |
172 | def load_contour(path):
173 |
174 | # contour = cv2.imread(path.replace("train_image", "train_boundary").replace("tif", "tif"), 0)
175 | # contour = contour.astype('uint8')
176 | # im_width, im_height, im_bands, contour, im_geotrans, im_proj = readTif(path.replace("train_image", "train_boundary").replace("image", "label"))
177 | im_width, im_height, im_bands, contour, im_geotrans, im_proj = readTif(
178 | path.replace("train_image", "train_boundary"))
179 | contour = contour.astype('uint8')
180 | contour = contour/255.
181 | # contour[contour ==255] = 1
182 | # contour[contour == 0] = 0
183 |
184 |
185 | return torch.from_numpy(np.expand_dims(contour, 0)).long()
186 |
187 | def load_distance(path):
188 | # im_width, im_height, im_bands, dist, im_geotrans, im_proj = readTif(path.replace("train_image", "train_dist").replace("image", "label"))
189 | im_width, im_height, im_bands, dist, im_geotrans, im_proj = readTif(
190 | path.replace("train_image", "train_dist"))
191 | dist = dist.astype('uint8')
192 | dist = dist/255.
193 | # dist = np.reshape(dist, dist.shape+(1,))
194 | return torch.from_numpy(np.expand_dims(dist, 0)).float()
195 |
196 |
197 |
--------------------------------------------------------------------------------
/gdalUtil.py:
--------------------------------------------------------------------------------
1 | import os, sys, time
2 | import numpy as np
3 | from osgeo import ogr, gdal, gdalconst
4 | from osgeo import gdal_array as ga
5 |
6 | def del_file(path):
7 | for i in os.listdir(path):
8 | path_file = os.path.join(path, i)
9 | if os.path.isfile(path_file):
10 | os.remove(path_file)
11 | else:
12 | del_file(path_file)
13 |
14 |
15 | def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
16 | """
17 | :param bands: input data as a numpy array
18 | :param img_min: minimum value of the target bit depth (for 8-bit data the minimum is 0 and the maximum is 255)
19 | :param img_max: maximum value of the target bit depth
20 | :return: the linearly stretched array
21 | """
22 | out = np.zeros_like(bands).astype(np.float32)
23 | a = img_min
24 | b = img_max
25 | c = np.percentile(bands[:, :], lower_percent)
26 | d = np.percentile(bands[:, :], higher_percent)
27 | t = a + (bands[:, :] - c) * (b - a) / (d - c)
28 | t[t < a] = a
29 | t[t > b] = b
30 | out[:, :] = t
31 | return out
32 |
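# Example usage (sketch): percentile-stretch a single band to the 8-bit range.
#   band8 = stretch_n(band, 0, 255, lower_percent=2, higher_percent=98).astype(np.uint8)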
33 |
34 | def read_img(filename):
35 | dataset=gdal.Open(filename)
36 |
37 | im_width = dataset.RasterXSize
38 | im_height = dataset.RasterYSize
39 |
40 | im_geotrans = dataset.GetGeoTransform()
41 | im_proj = dataset.GetProjection()
42 | im_data = dataset.ReadAsArray(0,0,im_width,im_height)
43 |
44 | del dataset
45 | return im_proj, im_geotrans, im_width, im_height, im_data
46 |
47 |
48 | def write_img(filename, im_proj, im_geotrans, im_data):
49 | if 'int8' in im_data.dtype.name:
50 | datatype = gdal.GDT_Byte
51 | elif 'int16' in im_data.dtype.name:
52 | datatype = gdal.GDT_UInt16
53 | else:
54 | datatype = gdal.GDT_Float32
55 |
56 | if len(im_data.shape) == 3:
57 | im_bands, im_height, im_width = im_data.shape
58 | else:
59 | im_bands, (im_height, im_width) = 1,im_data.shape
60 |
61 | driver = gdal.GetDriverByName("GTiff")
62 | dataset = driver.Create(filename, im_width, im_height, im_bands, datatype, options=['COMPRESS=LZW'])
63 |
64 | dataset.SetGeoTransform(im_geotrans)
65 | dataset.SetProjection(im_proj)
66 |
67 | if im_bands == 1:
68 | dataset.GetRasterBand(1).WriteArray(im_data)
69 | else:
70 | for i in range(im_bands):
71 | dataset.GetRasterBand(i+1).WriteArray(im_data[i])
72 |
73 | del dataset
74 |
75 |
76 | def image_resampling(source_file, target_file, scale=5.):
77 | """
78 | image resampling
79 | :param source_file: the path of source file
80 | :param target_file: the path of target file
81 | :param scale: pixel scaling
82 | :return: None
83 | """
84 | dataset = gdal.Open(source_file, gdalconst.GA_ReadOnly)
85 | band_count = dataset.RasterCount  # number of bands
86 |
87 | if band_count == 0 or not scale > 0:
88 | print("参数异常")
89 | return
90 |
91 | cols = dataset.RasterXSize  # number of columns
92 | rows = dataset.RasterYSize  # number of rows
93 | cols = int(cols * scale)  # compute the new number of columns and rows
94 | rows = int(rows * scale)
95 |
96 | geotrans = list(dataset.GetGeoTransform())
97 | print(dataset.GetGeoTransform())
98 | print(geotrans)
99 | geotrans[1] = geotrans[1] / scale  # pixel width becomes 1/scale of the original
100 | geotrans[5] = geotrans[5] / scale  # pixel height becomes 1/scale of the original
101 | print(geotrans)
102 |
103 | if os.path.exists(target_file) and os.path.isfile(target_file):  # if an image with the same name already exists
104 | os.remove(target_file)  # delete it
105 |
106 | band1 = dataset.GetRasterBand(1)
107 | data_type = band1.DataType
108 | target = dataset.GetDriver().Create(target_file, xsize=cols, ysize=rows, bands=band_count,
109 | eType=data_type)
110 | target.SetProjection(dataset.GetProjection())  # set the projection
111 | target.SetGeoTransform(geotrans)  # set the geotransform parameters
112 | total = band_count + 1
113 | for index in range(1, total):
114 | # read the band data
115 | print("writing band " + str(index))
116 | data = dataset.GetRasterBand(index).ReadAsArray(buf_xsize=cols, buf_ysize=rows)
117 | out_band = target.GetRasterBand(index)
118 | # out_band.SetNoDataValue(dataset.GetRasterBand(index).GetNoDataValue())
119 | out_band.WriteArray(data)  # write the data into the new image
120 | out_band.FlushCache()
121 | out_band.ComputeBandStats(False)  # compute band statistics
122 | print("finished writing")
123 | del dataset
124 |
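# Example usage (sketch): scale=2. doubles the pixel count per axis (pixel size is halved).
#   image_resampling("source.tif", "resampled.tif", scale=2.)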
125 |
126 | def sample_clip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
127 | """
128 | according to sampling point, generating image slices
129 | :param shp: the path of shape file
130 | :param tif: the path of image
131 | :param outputdir: the directory of output
132 | :param sampletype: line or polygon
133 | :param size: the size of images slices
134 | :param fieldName: the name of field
135 | :param n: the start number
136 | :return:
137 | """
138 | time1 = time.perf_counter()  # time.clock() was removed in Python 3.8+
139 |
140 | gdal.AllRegister()
141 | lc = gdal.Open(tif)
142 | im_width = lc.RasterXSize
143 | im_height = lc.RasterYSize
144 | im_geotrans = lc.GetGeoTransform()
145 | bandscount = lc.RasterCount
146 | im_proj = lc.GetProjection()
147 | print(im_width, im_height)
148 | gdal.AllRegister()
149 | gdal.SetConfigOption("gdal_FILENAME_IS_UTF8", "YES")
150 |
151 | driver = ogr.GetDriverByName('ESRI Shapefile')
152 | dsshp = driver.Open(shp, 0)
153 | if dsshp is None:
154 | print('Could not open ' + 'sites.shp')
155 | sys.exit(1)
156 | layer = dsshp.GetLayer()
157 | xValues = []
158 | yValues = []
159 | m = layer.GetFeatureCount()
160 | feature = layer.GetNextFeature()
161 | print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(bandscount, m, sampletype,
162 | int(size)))
163 |
164 | if n is not None:
165 | pass
166 | else:
167 | n = 1
168 | while feature:
169 | if n < 10:
170 | dirname = "0000000" + str(n)
171 | elif n >= 10 and n < 100:
172 | dirname = "000000" + str(n)
173 | elif n >= 100 and n < 1000:
174 | dirname = "00000" + str(n)
175 | else:
176 | dirname = "0000" + str(n)
177 |
178 | # print dirname
179 | dirpath = os.path.join(outputdir, dirname + "_V1")
180 | if not os.path.exists(dirpath):
181 | os.mkdir(dirpath)
182 | tifname = dirname + ".tif"
183 | if "poly" in sampletype or "POLY" in sampletype:
184 | shpname = dirname + "_V1_POLY.shp"
185 | if "line" in sampletype or "LINE" in sampletype:
186 | shpname = dirname + "_V1_LINE.shp"
187 | geometry = feature.GetGeometryRef()
188 | x = geometry.GetX()
189 | y = geometry.GetY()
190 | print(x, y)
191 | print(im_geotrans)
192 | xValues.append(x)
193 | yValues.append(y)
194 | newform = []
195 | newform = list(im_geotrans)
196 | # print newform
197 | newform[0] = x - im_geotrans[1] * int(size) / 2.0
198 | newform[3] = y - im_geotrans[5] * int(size) / 2.0
199 | print(newform[0], newform[3])
200 | newformtuple = tuple(newform)
201 | x1 = x - int(size) / 2 * im_geotrans[1]
202 | y1 = y - int(size) / 2 * im_geotrans[5]
203 | x2 = x + int(size) / 2 * im_geotrans[1]
204 | y2 = y - int(size) / 2 * im_geotrans[5]
205 | x3 = x - int(size) / 2 * im_geotrans[1]
206 | y3 = y + int(size) / 2 * im_geotrans[5]
207 | x4 = x + int(size) / 2 * im_geotrans[1]
208 | y4 = y + int(size) / 2 * im_geotrans[5]
209 | Xpix = (x1 - im_geotrans[0]) / im_geotrans[1]
210 | # Xpix=(newform[0]-im_geotrans[0])
211 |
212 | Ypix = (newform[3] - im_geotrans[3]) / im_geotrans[5]
213 | # Ypix=abs(newform[3]-im_geotrans[3])
214 | print("#################")
215 | print(Xpix, Ypix)
216 |
217 | # **************create tif**********************
218 | # print"start creating {0}".format(tifname)
219 | pBuf = None
220 | pBuf = lc.ReadAsArray(int(Xpix), int(Ypix), int(size), int(size))
221 | # print pBuf.dtype.name
222 | driver = gdal.GetDriverByName("GTiff")
223 | create_option = []
224 | if 'int8' in pBuf.dtype.name:
225 | datatype = gdal.GDT_Byte
226 | elif 'int16' in pBuf.dtype.name:
227 | datatype = gdal.GDT_UInt16
228 | else:
229 | datatype = gdal.GDT_Float32
230 | outtif = os.path.join(dirpath, tifname)
231 | ds = driver.Create(outtif, int(size), int(size), int(bandscount), datatype, options=create_option)
232 | if ds == None:
233 | print("2222")
234 | ds.SetProjection(im_proj)
235 | ds.SetGeoTransform(newformtuple)
236 | ds.FlushCache()
237 | if bandscount > 1:
238 | for i in range(int(bandscount)):
239 | outBand = ds.GetRasterBand(i + 1)
240 | outBand.WriteArray(pBuf[i])
241 | else:
242 | outBand = ds.GetRasterBand(1)
243 | outBand.WriteArray(pBuf)
244 | ds.FlushCache()
245 | # print "creating {0} successfully".format(tifname)
246 | # **************create shp**********************
247 | # print"start creating shps"
248 | gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
249 | gdal.SetConfigOption("SHAPE_ENCODING", "")
250 | strVectorFile = os.path.join(dirpath, shpname)
251 | ogr.RegisterAll()
252 | driver = ogr.GetDriverByName('ESRI Shapefile')
253 | ds = driver.Open(shp)
254 | layer0 = ds.GetLayerByIndex(0)
255 | prosrs = layer0.GetSpatialRef()
256 | # geosrs = osr.SpatialReference()
257 |
258 | oDriver = ogr.GetDriverByName("ESRI Shapefile")
259 | if oDriver == None:
260 | print("1")
261 | return
262 |
263 | oDS = oDriver.CreateDataSource(strVectorFile)
264 | if oDS == None:
265 | print("2")
266 | return
267 |
268 | papszLCO = []
269 | if "line" in sampletype or "LINE" in sampletype:
270 | oLayer = oDS.CreateLayer("TestPolygon", prosrs, ogr.wkbLineString, papszLCO)
271 | if "poly" in sampletype or "POLY" in sampletype:
272 | oLayer = oDS.CreateLayer("TestPolygon", prosrs, ogr.wkbPolygon, papszLCO)
273 | if oLayer == None:
274 | print("3")
275 | return
276 |
277 | oFieldName = ogr.FieldDefn(fieldName, ogr.OFTString)
278 | oFieldName.SetWidth(50)
279 | oLayer.CreateField(oFieldName, 1)
280 | oDefn = oLayer.GetLayerDefn()
281 | oFeatureRectangle = ogr.Feature(oDefn)
282 |
283 | geomRectangle = ogr.CreateGeometryFromWkt(
284 | "POLYGON (({0} {1},{2} {3},{4} {5},{6} {7},{0} {1}))".format(x1, y1, x2, y2, x4, y4, x3, y3))
285 | oFeatureRectangle.SetGeometry(geomRectangle)
286 | oLayer.CreateFeature(oFeatureRectangle)
287 | print("{0} ok".format(dirname))
288 | n = n + 1
289 | feature = layer.GetNextFeature()
290 | time2 = time.perf_counter()
291 | print('Process Running time: %s min' % ((time2 - time1) / 60))
292 |
293 | return n
--------------------------------------------------------------------------------
/imgs/FHAPD.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NanNanmei/HBGNet/b3480a4b95967d62a1a969f1e27b24fdf0a53126/imgs/FHAPD.jpg
--------------------------------------------------------------------------------
/imgs/HBGNet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NanNanmei/HBGNet/b3480a4b95967d62a1a969f1e27b24fdf0a53126/imgs/HBGNet.jpg
--------------------------------------------------------------------------------
/lib/pvtv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | from timm.models.vision_transformer import _cfg
9 | from timm.models.registry import register_model
10 |
11 | import math
12 |
13 |
14 | class Mlp(nn.Module):
15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16 | super().__init__()
17 | out_features = out_features or in_features
18 | hidden_features = hidden_features or in_features
19 | self.fc1 = nn.Linear(in_features, hidden_features)
20 | self.dwconv = DWConv(hidden_features)
21 | self.act = act_layer()
22 | self.fc2 = nn.Linear(hidden_features, out_features)
23 | self.drop = nn.Dropout(drop)
24 |
25 | self.apply(self._init_weights)
26 |
27 | def _init_weights(self, m):
28 | if isinstance(m, nn.Linear):
29 | trunc_normal_(m.weight, std=.02)
30 | if isinstance(m, nn.Linear) and m.bias is not None:
31 | nn.init.constant_(m.bias, 0)
32 | elif isinstance(m, nn.LayerNorm):
33 | nn.init.constant_(m.bias, 0)
34 | nn.init.constant_(m.weight, 1.0)
35 | elif isinstance(m, nn.Conv2d):
36 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
37 | fan_out //= m.groups
38 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
39 | if m.bias is not None:
40 | m.bias.data.zero_()
41 |
42 | def forward(self, x, H, W):
43 | x = self.fc1(x)
44 | x = self.dwconv(x, H, W)
45 | x = self.act(x)
46 | x = self.drop(x)
47 | x = self.fc2(x)
48 | x = self.drop(x)
49 | return x
50 |
51 |
52 | class Attention(nn.Module):
53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
54 | super().__init__()
55 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
56 |
57 | self.dim = dim
58 | self.num_heads = num_heads
59 | head_dim = dim // num_heads
60 | self.scale = qk_scale or head_dim ** -0.5
61 |
62 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
64 | self.attn_drop = nn.Dropout(attn_drop)
65 | self.proj = nn.Linear(dim, dim)
66 | self.proj_drop = nn.Dropout(proj_drop)
67 |
68 | self.sr_ratio = sr_ratio
69 | if sr_ratio > 1:
70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
71 | self.norm = nn.LayerNorm(dim)
72 |
73 | self.apply(self._init_weights)
74 |
75 | def _init_weights(self, m):
76 | if isinstance(m, nn.Linear):
77 | trunc_normal_(m.weight, std=.02)
78 | if isinstance(m, nn.Linear) and m.bias is not None:
79 | nn.init.constant_(m.bias, 0)
80 | elif isinstance(m, nn.LayerNorm):
81 | nn.init.constant_(m.bias, 0)
82 | nn.init.constant_(m.weight, 1.0)
83 | elif isinstance(m, nn.Conv2d):
84 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
85 | fan_out //= m.groups
86 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
87 | if m.bias is not None:
88 | m.bias.data.zero_()
89 |
90 | def forward(self, x, H, W):
91 | B, N, C = x.shape
92 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
93 |
94 | if self.sr_ratio > 1:
95 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
96 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
97 | x_ = self.norm(x_)
98 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
99 | else:
100 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
101 | k, v = kv[0], kv[1]
102 |
103 | attn = (q @ k.transpose(-2, -1)) * self.scale
104 | attn = attn.softmax(dim=-1)
105 | attn = self.attn_drop(attn)
106 |
107 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
108 | x = self.proj(x)
109 | x = self.proj_drop(x)
110 |
111 | return x
112 |
113 |
114 | class Block(nn.Module):
115 |
116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
118 | super().__init__()
119 | self.norm1 = norm_layer(dim)
120 | self.attn = Attention(
121 | dim,
122 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
123 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
124 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
125 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
126 | self.norm2 = norm_layer(dim)
127 | mlp_hidden_dim = int(dim * mlp_ratio)
128 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
129 |
130 | self.apply(self._init_weights)
131 |
132 | def _init_weights(self, m):
133 | if isinstance(m, nn.Linear):
134 | trunc_normal_(m.weight, std=.02)
135 | if isinstance(m, nn.Linear) and m.bias is not None:
136 | nn.init.constant_(m.bias, 0)
137 | elif isinstance(m, nn.LayerNorm):
138 | nn.init.constant_(m.bias, 0)
139 | nn.init.constant_(m.weight, 1.0)
140 | elif isinstance(m, nn.Conv2d):
141 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
142 | fan_out //= m.groups
143 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
144 | if m.bias is not None:
145 | m.bias.data.zero_()
146 |
147 | def forward(self, x, H, W):
148 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
149 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
150 |
151 | return x
152 |
153 |
154 | class OverlapPatchEmbed(nn.Module):
155 | """ Image to Patch Embedding
156 | """
157 |
158 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
159 | super().__init__()
160 | img_size = to_2tuple(img_size)
161 | patch_size = to_2tuple(patch_size)
162 |
163 | self.img_size = img_size
164 | self.patch_size = patch_size
165 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
166 | self.num_patches = self.H * self.W
167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
168 | padding=(patch_size[0] // 2, patch_size[1] // 2))
169 | self.norm = nn.LayerNorm(embed_dim)
170 |
171 | self.apply(self._init_weights)
172 |
173 | def _init_weights(self, m):
174 | if isinstance(m, nn.Linear):
175 | trunc_normal_(m.weight, std=.02)
176 | if isinstance(m, nn.Linear) and m.bias is not None:
177 | nn.init.constant_(m.bias, 0)
178 | elif isinstance(m, nn.LayerNorm):
179 | nn.init.constant_(m.bias, 0)
180 | nn.init.constant_(m.weight, 1.0)
181 | elif isinstance(m, nn.Conv2d):
182 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
183 | fan_out //= m.groups
184 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
185 | if m.bias is not None:
186 | m.bias.data.zero_()
187 |
188 | def forward(self, x):
189 | x = self.proj(x)
190 | _, _, H, W = x.shape
191 | x = x.flatten(2).transpose(1, 2)
192 | x = self.norm(x)
193 |
194 | return x, H, W
195 |
196 |
197 | class PyramidVisionTransformerImpr(nn.Module):
198 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
199 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
200 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
201 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
202 | super().__init__()
203 | self.num_classes = num_classes
204 | self.depths = depths
205 |
206 | # patch_embed
207 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
208 | embed_dim=embed_dims[0])
209 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
210 | embed_dim=embed_dims[1])
211 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
212 | embed_dim=embed_dims[2])
213 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
214 | embed_dim=embed_dims[3])
215 |
216 | # transformer encoder
217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
218 | cur = 0
219 | self.block1 = nn.ModuleList([Block(
220 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
222 | sr_ratio=sr_ratios[0])
223 | for i in range(depths[0])])
224 | self.norm1 = norm_layer(embed_dims[0])
225 |
226 | cur += depths[0]
227 | self.block2 = nn.ModuleList([Block(
228 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
229 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
230 | sr_ratio=sr_ratios[1])
231 | for i in range(depths[1])])
232 | self.norm2 = norm_layer(embed_dims[1])
233 |
234 | cur += depths[1]
235 | self.block3 = nn.ModuleList([Block(
236 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
237 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
238 | sr_ratio=sr_ratios[2])
239 | for i in range(depths[2])])
240 | self.norm3 = norm_layer(embed_dims[2])
241 |
242 | cur += depths[2]
243 | self.block4 = nn.ModuleList([Block(
244 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
246 | sr_ratio=sr_ratios[3])
247 | for i in range(depths[3])])
248 | self.norm4 = norm_layer(embed_dims[3])
249 |
250 | # classification head
251 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
252 |
253 | self.apply(self._init_weights)
254 |
255 | def _init_weights(self, m):
256 | if isinstance(m, nn.Linear):
257 | trunc_normal_(m.weight, std=.02)
258 | if isinstance(m, nn.Linear) and m.bias is not None:
259 | nn.init.constant_(m.bias, 0)
260 | elif isinstance(m, nn.LayerNorm):
261 | nn.init.constant_(m.bias, 0)
262 | nn.init.constant_(m.weight, 1.0)
263 | elif isinstance(m, nn.Conv2d):
264 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
265 | fan_out //= m.groups
266 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
267 | if m.bias is not None:
268 | m.bias.data.zero_()
269 |
270 | def init_weights(self, pretrained=None):
271 | if isinstance(pretrained, str):
272 | logger = 1
273 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
274 |
275 | def reset_drop_path(self, drop_path_rate):
276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
277 | cur = 0
278 | for i in range(self.depths[0]):
279 | self.block1[i].drop_path.drop_prob = dpr[cur + i]
280 |
281 | cur += self.depths[0]
282 | for i in range(self.depths[1]):
283 | self.block2[i].drop_path.drop_prob = dpr[cur + i]
284 |
285 | cur += self.depths[1]
286 | for i in range(self.depths[2]):
287 | self.block3[i].drop_path.drop_prob = dpr[cur + i]
288 |
289 | cur += self.depths[2]
290 | for i in range(self.depths[3]):
291 | self.block4[i].drop_path.drop_prob = dpr[cur + i]
292 |
293 | def freeze_patch_emb(self):
294 | self.patch_embed1.requires_grad = False
295 |
296 | @torch.jit.ignore
297 | def no_weight_decay(self):
298 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
299 |
300 | def get_classifier(self):
301 | return self.head
302 |
303 | def reset_classifier(self, num_classes, global_pool=''):
304 | self.num_classes = num_classes
305 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
306 |
307 | # def _get_pos_embed(self, pos_embed, patch_embed, H, W):
308 | # if H * W == self.patch_embed1.num_patches:
309 | # return pos_embed
310 | # else:
311 | # return F.interpolate(
312 | # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
313 | # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
314 |
315 | def forward_features(self, x):
316 | B = x.shape[0]
317 | outs = []
318 |
319 | # stage 1
320 | x, H, W = self.patch_embed1(x)
321 | for i, blk in enumerate(self.block1):
322 | x = blk(x, H, W)
323 | x = self.norm1(x)
324 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
325 | outs.append(x)
326 |
327 | # stage 2
328 | x, H, W = self.patch_embed2(x)
329 | for i, blk in enumerate(self.block2):
330 | x = blk(x, H, W)
331 | x = self.norm2(x)
332 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
333 | outs.append(x)
334 |
335 | # stage 3
336 | x, H, W = self.patch_embed3(x)
337 | for i, blk in enumerate(self.block3):
338 | x = blk(x, H, W)
339 | x = self.norm3(x)
340 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
341 | outs.append(x)
342 |
343 | # stage 4
344 | x, H, W = self.patch_embed4(x)
345 | for i, blk in enumerate(self.block4):
346 | x = blk(x, H, W)
347 | x = self.norm4(x)
348 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
349 | outs.append(x)
350 |
351 | return outs
352 |
353 | # return x.mean(dim=1)
354 |
355 | def forward(self, x):
356 | x = self.forward_features(x)
357 | # x = self.head(x)
358 |
359 | return x
360 |
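# Note (sketch): the backbone returns a 4-level feature pyramid at strides 4/8/16/32.
# For pvt_v2_b2 and a (B, 3, 352, 352) input, the output shapes are
# (B, 64, 88, 88), (B, 128, 44, 44), (B, 320, 22, 22), (B, 512, 11, 11).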
361 |
362 | class DWConv(nn.Module):
363 | def __init__(self, dim=768):
364 | super(DWConv, self).__init__()
365 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
366 |
367 | def forward(self, x, H, W):
368 | B, N, C = x.shape
369 | x = x.transpose(1, 2).view(B, C, H, W)
370 | x = self.dwconv(x)
371 | x = x.flatten(2).transpose(1, 2)
372 |
373 | return x
374 |
375 |
376 | def _conv_filter(state_dict, patch_size=16):
377 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
378 | out_dict = {}
379 | for k, v in state_dict.items():
380 | if 'patch_embed.proj.weight' in k:
381 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
382 | out_dict[k] = v
383 |
384 | return out_dict
385 |
386 |
387 | @register_model
388 | class pvt_v2_b0(PyramidVisionTransformerImpr):
389 | def __init__(self, **kwargs):
390 | super(pvt_v2_b0, self).__init__(
391 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
392 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
393 | drop_rate=0.0, drop_path_rate=0.1)
394 |
395 |
396 |
397 | @register_model
398 | class pvt_v2_b1(PyramidVisionTransformerImpr):
399 | def __init__(self, **kwargs):
400 | super(pvt_v2_b1, self).__init__(
401 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
402 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
403 | drop_rate=0.0, drop_path_rate=0.1)
404 |
405 | @register_model
406 | class pvt_v2_b2(PyramidVisionTransformerImpr):
407 | def __init__(self, **kwargs):
408 | super(pvt_v2_b2, self).__init__(
409 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
410 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
411 | drop_rate=0.0, drop_path_rate=0.1)
412 |
413 | @register_model
414 | class pvt_v2_b3(PyramidVisionTransformerImpr):
415 | def __init__(self, **kwargs):
416 | super(pvt_v2_b3, self).__init__(
417 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
418 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
419 | drop_rate=0.0, drop_path_rate=0.1)
420 |
421 | @register_model
422 | class pvt_v2_b4(PyramidVisionTransformerImpr):
423 | def __init__(self, **kwargs):
424 | super(pvt_v2_b4, self).__init__(
425 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
426 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
427 | drop_rate=0.0, drop_path_rate=0.1)
428 |
429 |
430 | @register_model
431 | class pvt_v2_b5(PyramidVisionTransformerImpr):
432 | def __init__(self, **kwargs):
433 | super(pvt_v2_b5, self).__init__(
434 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
435 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
436 | drop_rate=0.0, drop_path_rate=0.1)
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | """Calculating the loss
2 | You can build the loss function of HBGNet by combining multiple losses
3 | """
4 |
5 | import torch
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def dice_loss(prediction, target):
12 | """Calculating the dice loss
13 | Args:
14 | prediction = predicted image
15 | target = Targeted image
16 | Output:
17 | dice_loss"""
18 |
19 | smooth = 1.0
20 |
21 | i_flat = prediction.view(-1)
22 | t_flat = target.view(-1)
23 |
24 | intersection = (i_flat * t_flat).sum()
25 |
26 | return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))
27 |
28 |
29 | def calc_loss(prediction, target, bce_weight=0.5):
30 | """Calculating the loss and metrics
31 | Args:
32 | prediction = predicted image
33 | target = Targeted image
34 | metrics = Metrics printed
35 | bce_weight = 0.5 (default)
36 | Output:
37 | loss : dice loss of the epoch """
38 | bce = F.binary_cross_entropy_with_logits(prediction, target)
39 | prediction = torch.sigmoid(prediction)
40 | dice = dice_loss(prediction, target)
41 |
42 | loss = bce * bce_weight + dice * (1 - bce_weight)
43 |
44 | return loss
45 |
46 |
47 |
48 | class log_cosh_dice_loss(nn.Module):
49 | def __init__(self, num_classes=1, smooth=1, alpha=0.7):
50 | super(log_cosh_dice_loss, self).__init__()
51 | self.smooth = smooth
52 | self.alpha = alpha
53 | self.num_classes = num_classes
54 |
55 | def forward(self, outputs, targets):
56 | x = self.dice_loss(outputs, targets)
57 | return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)
58 |
59 | def dice_loss(self, y_pred, y_true):
60 | """[function to compute dice loss]
61 | Args:
62 | y_true ([float32]): [ground truth image]
63 | y_pred ([float32]): [predicted image]
64 | Returns:
65 | [float32]: [loss value]
66 | """
67 | smooth = 1.
68 | y_true = torch.flatten(y_true)
69 | y_pred = torch.flatten(y_pred)
70 | intersection = torch.sum((y_true * y_pred))
71 | coeff = (2. * intersection + smooth) / (torch.sum(y_true) + torch.sum(y_pred) + smooth)
72 | return (1. - coeff)
73 |
74 |
75 | def focal_loss(predict, label, alpha=0.6, beta=2):
76 | probs = torch.sigmoid(predict)
77 | # per-element cross-entropy loss (reduction='none' keeps the per-pixel values)
78 | ce_loss = nn.BCELoss(reduction='none')
79 | ce_loss = ce_loss(probs,label)
80 | alpha_ = torch.ones_like(predict) * alpha
81 | # alpha for positive labels, 1 - alpha for negative labels
82 | alpha_ = torch.where(label > 0, alpha_, 1.0 - alpha_)
83 | probs_ = torch.where(label > 0, probs, 1.0 - probs)
84 | # loss weight matrix
85 | loss_matrix = alpha_ * torch.pow((1.0 - probs_), beta)
86 | # final loss matrix: the weights multiply the corresponding CE values, so less accurate predictions produce a larger loss
87 | loss = loss_matrix * ce_loss
88 | loss = torch.sum(loss)
89 | return loss
90 |
91 |
92 |
93 | class Loss:
94 | def __init__(self, dice_weight=0.0, class_weights=None, num_classes=1, device=None):
95 | self.device = device
96 | if class_weights is not None:
97 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to(
98 | self.device
99 | )
100 | else:
101 | nll_weight = None
102 | self.nll_loss = nn.NLLLoss(weight=nll_weight)  # NLLLoss2d was removed from PyTorch; NLLLoss handles 2D targets
103 | self.dice_weight = dice_weight
104 | self.num_classes = num_classes
105 |
106 | def __call__(self, outputs, targets):
107 | loss = self.nll_loss(outputs, targets)
108 | if self.dice_weight:
109 | eps = 1e-7
110 | cls_weight = self.dice_weight / self.num_classes
111 | for cls in range(self.num_classes):
112 | dice_target = (targets == cls).float()
113 | dice_output = outputs[:, cls].exp()
114 | intersection = (dice_output * dice_target).sum()
115 | # union without intersection
116 | uwi = dice_output.sum() + dice_target.sum() + eps
117 | loss += (1 - intersection / uwi) * cls_weight
118 | loss /= (1 + self.dice_weight)
119 | return loss
120 |
121 |
122 | class LossMulti:
123 | def __init__(
124 | self, jaccard_weight=0.0, class_weights=None, num_classes=1, device=None
125 | ):
126 | self.device = device
127 | if class_weights is not None:
128 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to(
129 | self.device
130 | )
131 | else:
132 | nll_weight = None
133 |
134 | self.nll_loss = nn.NLLLoss(weight=nll_weight)
135 | self.jaccard_weight = jaccard_weight
136 | self.num_classes = num_classes
137 |
138 | def __call__(self, outputs, targets):
139 |
140 | targets = targets.squeeze(1)
141 |
142 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)
143 |
144 | if self.jaccard_weight:
145 | eps = 1e-7  # originally 1e-7
146 | for cls in range(self.num_classes):
147 | jaccard_target = (targets == cls).float()
148 | jaccard_output = outputs[:, cls].exp()
149 | intersection = (jaccard_output * jaccard_target).sum()
150 |
151 | union = jaccard_output.sum() + jaccard_target.sum()
152 | loss -= (
153 | torch.log((intersection + eps) / (union - intersection + eps))
154 | * self.jaccard_weight
155 | )
156 | return loss
157 |
158 |
159 |
160 | class BCEDiceLoss(nn.Module):
161 | def __init__(self):
162 | super().__init__()
163 |
164 | def forward(self, input, target):
165 | bce = F.binary_cross_entropy_with_logits(input, target)
166 | smooth = 1e-5
167 | input = torch.sigmoid(input)
168 | num = target.size(0)
169 | input = input.view(num, -1)
170 | target = target.view(num, -1)
171 | intersection = (input * target)
172 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
173 | dice = 1 - dice.sum() / num
174 | return 0.5 * bce + dice
175 |
176 |
177 | class LossF:
178 | def __init__(self, weights=[1, 1, 1]):
179 | self.criterion1 = BCEDiceLoss()              # mask loss: BCE + Dice, following SEANet
180 | self.criterion2 = LossMulti(num_classes=2)   # contour loss: NLL, following BsiNet
181 | self.criterion3 = nn.MSELoss()               # distance loss: MSE, following BsiNet
182 | self.weights = weights
183 |
184 | def __call__(self, outputs1, outputs2,outputs3, targets1, targets2, targets3):
185 | #
186 | criterion = (
187 | self.weights[0] * self.criterion1(outputs1, targets1)
188 | + self.weights[1] * self.criterion2(outputs2, targets2)
189 | + self.weights[2] * self.criterion3(outputs3, targets3)
190 | )
191 |
192 | return criterion
193 |
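# Example usage (sketch), assuming model outputs ordered as mask / contour / distance:
#   criterion = LossF(weights=[1, 1, 1])
#   loss = criterion(mask_logits, contour_log_probs, dist_pred,
#                    mask_target, contour_target, dist_target)
# BCEDiceLoss expects raw logits for the mask head, LossMulti expects
# log-probabilities (e.g. F.log_softmax) for the contour head, and
# nn.MSELoss compares the regressed distance map with its target.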
194 | class LossF_noEdgeTask:
195 | def __init__(self, weights=[1, 1]):
196 | self.criterion1 = BCEDiceLoss()              # mask loss: BCE + Dice, following SEANet
197 | # self.criterion2 = LossMulti(num_classes=2) # contour loss: NLL, following BsiNet
198 | self.criterion3 = nn.MSELoss()               # distance loss: MSE, following BsiNet
199 | self.weights = weights
200 |
201 | def __call__(self, outputs1,outputs3, targets1, targets3):
202 | #
203 | criterion = (
204 | self.weights[0] * self.criterion1(outputs1, targets1)
205 | + self.weights[1] * self.criterion3(outputs3, targets3)
206 | )
207 |
208 | return criterion
209 |
210 |
211 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | """Model construction
2 | 1. We offer two versions of BsiNet, one concise and the other clear
3 | 2. The clear version is designed for user understanding and modification
4 | 3. You can use the attention mechanisms we provide to build a new multi-task model
5 | 4. You can also add your own modules or change the location of the attention mechanisms to build a better model
6 | 5. ...
7 | ..... Baseline + GLCA + BGM + MSF
8 | """
9 |
10 | from torch.nn.parameter import Parameter
11 | # from timm.models.registry import register_model
12 | # import math
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | # from functools import partial
17 | # from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18 | from lib.pvtv2 import pvt_v2_b2
19 | # from tensorboardX import SummaryWriter
20 |
21 | def conv3x3(in_, out):
22 | return nn.Conv2d(in_, out, 3, padding=1)
23 |
24 |
25 | class Conv3BN(nn.Module):
26 | def __init__(self, in_: int, out: int, bn=False):
27 | super().__init__()
28 | self.conv = conv3x3(in_, out)
29 | self.bn = nn.BatchNorm2d(out) if bn else None
30 | self.activation = nn.ReLU(inplace=True)
31 |
32 | def forward(self, x):
33 | x = self.conv(x)
34 | if self.bn is not None:
35 | x = self.bn(x)
36 | x = self.activation(x)
37 | return x
38 |
39 |
40 | class NetModule(nn.Module):
41 | def __init__(self, in_: int, out: int):
42 | super().__init__()
43 | self.l1 = Conv3BN(in_, out)
44 | self.l2 = Conv3BN(out, out)
45 |
46 | def forward(self, x):
47 | x = self.l1(x)
48 | x = self.l2(x)
49 | return x
50 |
51 |
52 | # SE (Squeeze-and-Excitation) attention mechanism
53 | class SELayer(nn.Module):
54 | def __init__(self, channel, reduction=16):
55 | super(SELayer, self).__init__()
56 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
57 | self.fc = nn.Sequential(
58 | nn.Linear(channel, channel // reduction, bias=False),
59 | nn.ReLU(inplace=True),
60 | nn.Linear(channel // reduction, channel, bias=False),
61 | nn.Sigmoid()
62 | )
63 |
64 | def forward(self, x):
65 | b, c, _, _ = x.size()
66 | y = self.avg_pool(x).view(b, c)
67 | y = self.fc(y).view(b, c, 1, 1)
68 | return x * y.expand_as(x)
69 |
70 |
71 |
72 | class SpatialGroupEnhance(nn.Module):
73 | def __init__(self, groups = 64):
74 | super(SpatialGroupEnhance, self).__init__()
75 | self.groups = groups
76 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
77 | self.weight = Parameter(torch.zeros(1, groups, 1, 1))
78 | self.bias = Parameter(torch.ones(1, groups, 1, 1))
79 | self.sig = nn.Sigmoid()
80 |
81 | def forward(self, x): # (b, c, h, w)
82 | b, c, h, w = x.size()
83 | x = x.view(b * self.groups, -1, h, w)
84 | xn = x * self.avg_pool(x)
85 | xn = xn.sum(dim=1, keepdim=True)
86 | t = xn.view(b * self.groups, -1)
87 | t = t - t.mean(dim=1, keepdim=True)
88 | std = t.std(dim=1, keepdim=True) + 1e-5
89 | t = t / std
90 | t = t.view(b, self.groups, h, w)
91 | t = t * self.weight + self.bias
92 | t = t.view(b * self.groups, 1, h, w)
93 | x = x * self.sig(t)
94 | x = x.view(b, c, h, w)
95 | return x
96 |
97 |
98 |
99 | # scSE attention module
100 | class cSE(nn.Module): # noqa: N801
101 | """
102 | The channel-wise SE (Squeeze and Excitation) block from the
103 | `Squeeze-and-Excitation Networks`__ paper.
104 | Adapted from
105 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65939
106 | and
107 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178
108 | Shape:
109 | - Input: (batch, channels, height, width)
110 | - Output: (batch, channels, height, width) (same shape as input)
111 | __ https://arxiv.org/abs/1709.01507
112 | """
113 |
114 | def __init__(self, in_channels: int, r: int = 16):
115 | """
116 | Args:
117 | in_channels: The number of channels
118 | in the feature map of the input.
119 | r: The reduction ratio of the intermediate channels.
120 | Default: 16.
121 | """
122 | super().__init__()
123 | self.linear1 = nn.Linear(in_channels, in_channels // r)
124 | self.linear2 = nn.Linear(in_channels // r, in_channels)
125 |
126 | def forward(self, x: torch.Tensor):
127 | """Forward call."""
128 | input_x = x
129 |
130 | x = x.view(*(x.shape[:-2]), -1).mean(-1)
131 | x = F.relu(self.linear1(x), inplace=True)
132 | x = self.linear2(x)
133 | x = x.unsqueeze(-1).unsqueeze(-1)
134 | x = torch.sigmoid(x)
135 |
136 | x = torch.mul(input_x, x)
137 | return x
138 |
139 |
140 | class sSE(nn.Module): # noqa: N801
141 | """
142 | The sSE (Channel Squeeze and Spatial Excitation) block from the
143 | `Concurrent Spatial and Channel ‘Squeeze & Excitation’
144 | in Fully Convolutional Networks`__ paper.
145 | Adapted from
146 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178
147 | Shape:
148 | - Input: (batch, channels, height, width)
149 | - Output: (batch, channels, height, width) (same shape as input)
150 | __ https://arxiv.org/abs/1803.02579
151 | """
152 |
153 | def __init__(self, in_channels: int):
154 | """
155 | Args:
156 | in_channels: The number of channels
157 | in the feature map of the input.
158 | """
159 | super().__init__()
160 | self.conv = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1)
161 |
162 | def forward(self, x: torch.Tensor):
163 | """Forward call."""
164 | input_x = x
165 |
166 | x = self.conv(x)
167 | x = torch.sigmoid(x)
168 |
169 | x = torch.mul(input_x, x)
170 | return x
171 |
172 |
173 | class scSE(nn.Module): # noqa: N801
174 | """
175 | The scSE (Concurrent Spatial and Channel Squeeze and Channel Excitation)
176 | block from the `Concurrent Spatial and Channel ‘Squeeze & Excitation’
177 | in Fully Convolutional Networks`__ paper.
178 | Adapted from
179 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178
180 | Shape:
181 | - Input: (batch, channels, height, width)
182 | - Output: (batch, channels, height, width) (same shape as input)
183 | __ https://arxiv.org/abs/1803.02579
184 | """
185 |
186 | def __init__(self, in_channels: int, r: int = 16):
187 | """
188 | Args:
189 | in_channels: The number of channels
190 | in the feature map of the input.
191 | r: The reduction ratio of the intermediate channels.
192 | Default: 16.
193 | """
194 | super().__init__()
195 | self.cse_block = cSE(in_channels, r)
196 | self.sse_block = sSE(in_channels)
197 |
198 | def forward(self, x: torch.Tensor):
199 | """Forward call."""
200 | cse = self.cse_block(x)
201 | sse = self.sse_block(x)
202 | x = torch.add(cse, sse)
203 | return x
204 |
205 |
206 | ##This is a concise version of the BsiNet whose modules are better packaged
207 |
208 | class BsiNet(nn.Module):
209 |
210 | output_downscaled = 1
211 | module = NetModule
212 |
213 | def __init__(
214 | self,
215 | input_channels: int = 3,
216 | filters_base: int = 32,
217 | down_filter_factors=(1, 2, 4, 8, 16),
218 | up_filter_factors=(1, 2, 4, 8, 16),
219 | bottom_s=4,
220 | num_classes=1,
221 | add_output=True,
222 | ):
223 | super().__init__()
224 | self.num_classes = num_classes
225 | assert len(down_filter_factors) == len(up_filter_factors)
226 | assert down_filter_factors[-1] == up_filter_factors[-1]
227 | down_filter_sizes = [filters_base * s for s in down_filter_factors]
228 | up_filter_sizes = [filters_base * s for s in up_filter_factors]
229 | self.down, self.up = nn.ModuleList(), nn.ModuleList()
230 | self.down.append(self.module(input_channels, down_filter_sizes[0]))
231 | for prev_i, nf in enumerate(down_filter_sizes[1:]):
232 | self.down.append(self.module(down_filter_sizes[prev_i], nf))
233 | for prev_i, nf in enumerate(up_filter_sizes[1:]):
234 | self.up.append(
235 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i])
236 | )
237 |
238 | pool = nn.MaxPool2d(2, 2)
239 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s)
240 | upsample = nn.Upsample(scale_factor=2)
241 | upsample_bottom = nn.Upsample(scale_factor=bottom_s)
242 | self.downsamplers = [None] + [pool] * (len(self.down) - 1)
243 | self.downsamplers[-1] = pool_bottom
244 | self.upsamplers = [upsample] * len(self.up)
245 | self.upsamplers[-1] = upsample_bottom
246 | self.add_output = add_output
247 | self.sge = SpatialGroupEnhance(32)
248 |
249 | # self.ca1 = ChannelAttention()
250 | # self.sa1 = SpatialAttention()
251 |
252 | if add_output:
253 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1)
254 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1)
255 | self.conv_final3 = nn.Conv2d(up_filter_sizes[0], 1, 1)
256 |
257 | def forward(self, x):
258 | xs = []
259 | for downsample, down in zip(self.downsamplers, self.down):
260 | x_in = x if downsample is None else downsample(xs[-1])
261 | x_out = down(x_in)
262 | xs.append(x_out)
263 |
264 | for x_skip, upsample, up in reversed(
265 | list(zip(xs[:-1], self.upsamplers, self.up))
266 | ):
267 |
268 | x_out2 = upsample(x_out)
269 |             x_out = torch.cat([x_out2, x_skip], 1)
270 | x_out = up(x_out)
271 |
272 | if self.add_output:
273 |
274 | x_out = self.sge(x_out)
275 |
276 | x_out1 = self.conv_final1(x_out)
277 | x_out2 = self.conv_final2(x_out)
278 | x_out3 = self.conv_final3(x_out)
279 | if self.num_classes > 1:
280 | x_out1 = F.log_softmax(x_out1,dim=1)
281 | x_out2 = F.log_softmax(x_out2,dim=1)
282 | x_out3 = torch.sigmoid(x_out3)
283 |
284 | return [x_out1, x_out2, x_out3]
285 |
286 |
287 |
288 |
289 | class LaplaceConv2d(nn.Module):
290 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=True):
291 | super(LaplaceConv2d, self).__init__()
292 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
293 |
294 | # Generate Laplace kernel
295 |         laplace_kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], dtype=torch.float32)  ## 8-neighborhood Laplacian kernel
296 | laplace_kernel = laplace_kernel.unsqueeze(0).unsqueeze(0)
297 | laplace_kernel = laplace_kernel.repeat((out_channels, in_channels, 1, 1))
298 | self.conv.weight = nn.Parameter(laplace_kernel)
299 | self.conv.bias.data.fill_(0)
300 | self.bn = nn.BatchNorm2d(out_channels)
301 | self.relu = nn.ReLU(inplace=True)
302 |
303 | def forward(self, x):
304 | x1 = self.conv(x)
305 | x1 = self.relu(self.bn(x1))
306 |
307 | return x1
308 |
309 |
310 | ###### CBAM attention module
311 | class CBAM(nn.Module):
312 | def __init__(self, channel, reduction=16, spatial_kernel=7):
313 | super(CBAM, self).__init__()
314 |         # channel attention: squeeze H and W down to 1
315 | self.max_pool = nn.AdaptiveMaxPool2d(1)
316 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
317 | # shared MLP
318 | self.mlp = nn.Sequential(
319 | nn.Conv2d(channel, channel // reduction, 1, bias=False),
320 | nn.ReLU(inplace=True),
321 | nn.Conv2d(channel // reduction, channel, 1, bias=False)
322 | )
323 | # spatial attention
324 | self.conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel,
325 | padding=spatial_kernel // 2, bias=False)
326 | self.sigmoid = nn.Sigmoid()
327 | def forward(self, x):
328 | max_out = self.mlp(self.max_pool(x))
329 | avg_out = self.mlp(self.avg_pool(x))
330 | channel_out = self.sigmoid(max_out + avg_out)
331 | x = channel_out * x
332 | max_out, _ = torch.max(x, dim=1, keepdim=True)
333 | avg_out = torch.mean(x, dim=1, keepdim=True)
334 | spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1)))
335 | x = spatial_out * x
336 | return x
337 |
338 |
339 |
340 | ### Boundary enhancement / boundary-guided module (BGM)
341 | class Boundary_guided_module(nn.Module):
342 | def __init__(self, in_channel1,in_channel2,out_channel):
343 | super(Boundary_guided_module, self).__init__()
344 | self.sigmoid = nn.Sigmoid()
345 |         self.conv1 = nn.Conv2d(in_channel1, out_channel, 1)  ## 1x1 conv to reduce the channel count
346 |         self.conv2 = nn.Conv2d(in_channel2, out_channel, 1)  ## 1x1 conv to reduce the channel count
347 | self.max_pool = nn.AdaptiveMaxPool2d(1)
348 |
349 | def forward(self,edge,semantic):
350 | x = self.conv1(edge)
351 | x, _ = torch.max(x, dim=1, keepdim=True)
352 | x = self.sigmoid(x)
353 |         semantic = self.conv2(semantic)  # compute the 1x1 projection once and reuse it
354 |         x = x * semantic + semantic
355 | return x
356 |
357 |
358 | class Long_distance(nn.Module):
359 | '''Spatial reasoning module'''
360 |
361 | # codes from DANet 'Dual attention network for scene segmentation'
362 | def __init__(self, in_dim):
363 | super(Long_distance, self).__init__()
364 | self.chanel_in = in_dim
365 |
366 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
367 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
368 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
369 |         self.gamma = nn.Parameter(torch.zeros(1))  # nn.Parameter() registers the tensor as a learnable parameter updated by gradient descent; such parameters are typically weights or biases
370 |
371 | self.softmax = nn.Softmax(dim=-1)
372 |
373 | def forward(self, x):
374 | ''' inputs :
375 | x : input feature maps( B X C X H X W)
376 | returns :
377 | out : attention value + input feature
378 | attention: B X (HxW) X (HxW) '''
379 | m_batchsize, C, height, width = x.size()
380 | proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
381 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
382 | energy = torch.bmm(proj_query, proj_key)
383 | attention = self.softmax(energy)
384 | proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
385 |
386 | out = torch.bmm(proj_value, attention.permute(0, 2, 1))
387 | out = out.view(m_batchsize, C, height, width)
388 | out = x + self.gamma * out
389 |
390 | return out
391 |
392 | ## Local and global spatial context (near- and long-range context extraction)
393 | class near_and_long(nn.Module):
394 | def __init__(self, in_channel,out_channel):
395 | super(near_and_long, self).__init__()
396 | self.long = Long_distance(in_channel)
397 | self.near = Residual(in_channel,out_channel)
398 | self.conv1 = nn.Conv2d(in_channel+out_channel,out_channel,1)
399 |
400 | def forward(self,x):
401 | x1 = self.long(x)
402 | x2 = self.near(x)
403 | fuse = torch.cat([x1,x2], 1)
404 | fuse = self.conv1(fuse)
405 |
406 | return fuse
407 |
408 |
409 | class multi_scale_fuseion(nn.Module):
410 | def __init__(self, in_channel,out_channel):
411 | super(multi_scale_fuseion, self).__init__()
412 | self.c1 = Conv(in_channel,out_channel, kernel_size=1, padding=0)
413 | self.c2 = Conv(in_channel,out_channel, kernel_size=3, padding=1)
414 | self.c3 = Conv(in_channel,out_channel, kernel_size=7, padding=3)
415 | self.c4 = Conv(in_channel,out_channel, kernel_size=11, padding=5)
416 | self.s1 = Conv(out_channel*4,out_channel, kernel_size=1, padding=0)
417 | self.attention = CBAM(out_channel)
418 |
419 | def forward(self,x):
420 | x1 = self.c1(x)
421 | x2 = self.c2(x)
422 | x3 = self.c3(x)
423 | x4 = self.c4(x)
424 | x5 = torch.cat([x1,x2,x3,x4], 1)
425 | x5 = self.s1(x5)
426 | x6 = self.attention(x5)
427 |
428 | return x6
429 |
430 |
431 | class Residual(nn.Module):
432 | def __init__(self, input_dim, output_dim, stride=1, padding=1):
433 | super(Residual, self).__init__()
434 |
435 | self.conv_block = nn.Sequential(
436 | nn.BatchNorm2d(input_dim),
437 | nn.ReLU(),
438 | nn.Conv2d(
439 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
440 | ),
441 | nn.BatchNorm2d(output_dim),
442 | nn.ReLU(),
443 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
444 | )
445 | self.conv_skip = nn.Sequential(
446 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
447 | nn.BatchNorm2d(output_dim),
448 | )
449 |
450 | def forward(self, x):
451 |
452 | return self.conv_block(x) + self.conv_skip(x)
453 |
454 |
455 | class Conv(nn.Module):
456 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=True, padding=1,relu=True, bias=True):
457 | super(Conv, self).__init__()
458 | self.inp_dim = inp_dim
459 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding, bias=bias)
460 | self.relu = relu
461 | self.bn = bn
462 | if relu:
463 | self.relu = nn.ReLU(inplace=True)
464 | if bn:
465 | self.bn = nn.BatchNorm2d(out_dim)
466 |
467 | def forward(self, x):
468 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
469 |
470 | x = self.conv(x)
471 |         if self.bn:    # self.bn holds False (not None) when bn=False, so test truthiness
472 |             x = self.bn(x)
473 |         if self.relu:
474 |             x = self.relu(x)
475 | return x
476 |
477 |
478 |
479 |
480 | class Field(nn.Module):
481 | def __init__(self, channel=32,num_classes=2,drop_rate=0.4):
482 | super(Field, self).__init__()
483 |
484 | self.drop = nn.Dropout2d(drop_rate)
485 | self.backbone = pvt_v2_b2() # [64, 128, 320, 512]
486 |         path = r"E:\zhaohang\DeepL\code\model_New\preweight\pvt_v2_b2.pth"  # local path to the pretrained PVTv2-b2 weights; adjust to your environment
487 | save_model = torch.load(path)
488 | model_dict = self.backbone.state_dict()
489 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
490 | model_dict.update(state_dict)
491 | self.backbone.load_state_dict(model_dict)
492 | self.channel = channel
493 | self.num_classes = num_classes
494 |
495 | self.edge_lap = LaplaceConv2d(in_channels=3,out_channels=1)
496 | self.conv1 = Residual(1,32)
497 | self.conv2 = nn.Conv2d(64, 16, 1)
498 | self.attention1 = CBAM(32)
499 | self.attention2 = CBAM(64)
500 | self.fuse1 = near_and_long(512,256)
501 | self.fuse2 = near_and_long(320, 128)
502 | self.fuse3 = near_and_long(128, 64)
503 | self.fuse4 = near_and_long(64, 32)
504 | ###
505 | self.up1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
506 | self.up2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
507 | self.up3 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
508 | self.up4 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)
509 |
510 | ##
511 | self.boundary1 = Boundary_guided_module(64,64,32)
512 | self.boundary2 = Boundary_guided_module(64,64, 16)
513 | self.boundary3 = Boundary_guided_module(64,128, 16)
514 | self.boundary4 = Boundary_guided_module(64, 256, 16)
515 | ##
516 | self.multi_fusion = multi_scale_fuseion(64,64)
517 | self.sigmoid = nn.Sigmoid()
518 | self.out_feature = nn.Conv2d(64,1,1)
519 | self.edge_feature = nn.Conv2d(64, num_classes, 1)
520 |         # distance-map output head
521 | self.dis_feature = nn.Conv2d(64,1,1)
522 |
523 | def forward(self, x):
524 | #---------------------------------------------------------------------#
525 | #
526 |
527 | edge = self.edge_lap(x)
528 | edge = self.conv1(edge)
529 |
530 | pvt = self.backbone(x)
531 | x1 = pvt[0] ##(b,64,64,64)
532 | x1 = self.drop(x1)
533 | x2 = pvt[1] ##(b,128,32,32)
534 | x2 = self.drop(x2)
535 | x3 = pvt[2] ##(b,320,16,16)
536 | x3 = self.drop(x3)
537 |         x4 = pvt[3] ##(b,512,8,8)
538 | x4 = self.drop(x4)
539 | ##global and local context aggregation
540 | x1 = self.fuse4(x1) #32
541 | x2 = self.fuse3(x2) #64
542 | x3 = self.fuse2(x3) #128
543 | x4 = self.fuse1(x4) #256
544 | ### boundary guided module
545 | x1 = self.up1(x1)
546 | edge = torch.cat([edge,x1],1) #64
547 | edge = self.attention2(edge)
548 | edge1 = self.conv2(edge)
549 | # bs1 = self.boundary1(edge,x1)
550 | x2 = self.up2(x2)
551 | bs2 = self.boundary2(edge, x2)
552 | x3 = self.up3(x3)
553 | bs3 = self.boundary3(edge, x3)
554 | x4 = self.up4(x4)
555 | bs4 = self.boundary4(edge, x4)
556 | ###multi-scale feature fusion module
557 | ms = torch.cat([edge1,bs2,bs3,bs4],1)
558 | out = self.multi_fusion(ms)
559 |
560 | edge_out = self.edge_feature(edge)
561 | edge_out = F.log_softmax(edge_out, dim=1)
562 | mask_out = self.out_feature(out)
563 | # mask_out = F.log_softmax(mask_out,dim=1)
564 |         # distance-map output (dist_out)
565 | dis_out = self.dis_feature(out)
566 |
567 |
568 | return [mask_out,edge_out,dis_out]
569 |
570 | if __name__ == "__main__":
571 | tensor = torch.randn((8, 3, 512, 512))
572 | net = Field()
573 |     # print the name of every layer in the model
574 | # for name, module in net.named_modules():
575 | # print(f'Layer Name: {name}')
576 | outputs = net(tensor)
577 | for output in outputs:
578 | print(output.size())
579 |
580 |
581 | # with SummaryWriter(logdir="network") as w:
582 | # w.add_graph(net, tensor)
583 | #
584 | # w.close()
585 |
--------------------------------------------------------------------------------
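A side note on the boundary branch defined above: `LaplaceConv2d` initialises its weights with a fixed 8-neighborhood Laplacian, which sums to zero and therefore responds only at intensity discontinuities. A minimal sketch (not part of the repository; it assumes only PyTorch) illustrating why that kernel works as a boundary prior:

```python
# Minimal illustration: the 8-neighborhood Laplace kernel used by LaplaceConv2d
# gives zero response on flat regions and a non-zero response at intensity steps.
import torch
import torch.nn.functional as F

kernel = torch.tensor([[1., 1., 1.],
                       [1., -8., 1.],
                       [1., 1., 1.]]).view(1, 1, 3, 3)

# Synthetic image: left half 0, right half 1 -> a single vertical edge.
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0

response = F.conv2d(img, kernel, padding=1)
print(response[0, 0])  # non-zero only along the step (and at the zero-padded border)
```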
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from torch.utils.data import DataLoader
4 | from dataset import DatasetImageMaskContourDist
5 | import glob
6 | from models import Field
7 | from tqdm import tqdm
8 | import numpy as np
9 | import cv2
10 | from utils import create_validation_arg_parser
11 | from torch import nn
12 |
13 | def build_model(model_type):
14 |
15 | if model_type == "field":
16 | model = Field(num_classes=2)
17 |
18 | return model
19 |
20 |
21 | if __name__ == "__main__":
22 | args = create_validation_arg_parser().parse_args()
23 | args.model_file = r"E:\zhaohang\DeepL\模型训练pt文件_lmx机子\JS_zh\80.pt"
24 | args.save_path = r"D:\DeepL\data\paper\JS\test\pre_mask1"
25 | args.model_type = 'field'
26 | args.test_path = r"D:\DeepL\data\paper\JS\test\train_image"
27 |
28 | # args = create_validation_arg_parser().parse_args()
29 | # args.model_file = r"D:\DeepL\data\paper\JS\ablation_study\model_F1\model_pt\50.pt"
30 | # args.save_path = r"D:\DeepL\data\paper\JS\ablation_study\model_F1\pre_mask1"
31 | # args.model_type = 'field'
32 | # args.test_path = r"D:\DeepL\data\paper\JS\ablation_study\model_F1\train_image"
33 |
34 |
35 | test_path = os.path.join(args.test_path, "*.tif")
36 | model_file = args.model_file
37 | save_path = args.save_path
38 | model_type = args.model_type
39 |
40 | cuda_no = args.cuda_no
41 | CUDA_SELECT = "cuda:{}".format(cuda_no)
42 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")
43 |
44 | test_file_names = glob.glob(test_path)
45 | # print(test_file_names)
46 | # valLoader = DataLoader(DatasetImageMaskContourDist(test_file_names))
47 |     test_file_names = [os.path.splitext(os.path.basename(p))[0] for p in test_file_names]  # keep only the image id, mirroring train.py
48 | valLoader = DataLoader(DatasetImageMaskContourDist(args.test_path, test_file_names))
49 | # valLoader = DataLoader(DatasetImageMaskContourDist(test_file_names))
50 |
51 | if not os.path.exists(save_path):
52 | os.mkdir(save_path)
53 |
54 | model = build_model(model_type)
55 |     model = nn.DataParallel(model)  # added manually so that checkpoints saved from a DataParallel model load correctly
56 | model = model.to(device)
57 | model.load_state_dict(torch.load(model_file))
58 | model.eval()
59 |
60 | for i, (img_file_name, inputs, targets1, targets2, targets3) in enumerate(
61 | tqdm(valLoader)
62 | ):
63 |
64 | inputs = inputs.to(device)
65 | outputs1, outputs2 ,outputs3= model(inputs)
66 |
67 | ## TTA
68 | # outputs4, outputs5, outputs6 = model(torch.flip(inputs, [-1]))
69 | # predict_2 = torch.flip(outputs4, [-1])
70 | # outputs7, outputs8, outputs9 = model(torch.flip(inputs, [-2]))
71 | # predict_3 = torch.flip(outputs7, [-2])
72 | # outputs10, outputs11, outputs12 = model(torch.flip(inputs, [-1, -2]))
73 | # predict_4 = torch.flip(outputs10, [-1, -2])
74 | # predict_list = outputs1 + predict_2 + predict_3 + predict_4
75 | # pred1 = predict_list/4.0
76 |
77 | outputs1 = outputs1.detach().cpu().numpy().squeeze()
78 |
79 |
80 |
81 |         res = np.zeros((256, 256))  # assumes 256x256 test tiles; adjust if your tiles differ
82 | res[outputs1>0.5] = 255
83 | res[outputs1<=0.5] = 0
84 |
85 | res = np.array(res, dtype='uint8')
86 | output_path = os.path.join(
87 | # save_path, os.path.basename(img_file_name[0])
88 | save_path, os.path.basename(img_file_name[0] + ".tif")
89 | )
90 | cv2.imwrite(output_path, res)
--------------------------------------------------------------------------------
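The TTA block commented out in `test.py` averages the mask prediction over flipped views of the input. A compact sketch of the same idea (only a sketch; `model` and `inputs` are assumed to be defined exactly as in the loop above):

```python
# Sketch of the test-time augmentation (TTA) that test.py keeps commented out:
# average the mask head over the identity view and three flipped views.
import torch

def predict_with_tta(model, inputs):
    preds = []
    for dims in (None, [-1], [-2], [-1, -2]):   # identity, H-flip, V-flip, both
        x = inputs if dims is None else torch.flip(inputs, dims)
        mask, _, _ = model(x)                   # the model returns [mask, edge, dist]
        preds.append(mask if dims is None else torch.flip(mask, dims))
    return sum(preds) / len(preds)
```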
/train.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | import random
5 | import torch
6 | from dataset import DatasetImageMaskContourDist
7 | from losses import LossF
8 | from models import Field
9 |
10 | from tensorboardX import SummaryWriter
11 | from torch import nn
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 | from utils import visualize, create_train_arg_parser,evaluate
15 | # from torchsummary import summary
16 | from sklearn.model_selection import train_test_split
17 |
18 | #
19 | def define_loss(loss_type, weights=[1, 1, 1]):
20 |
21 | if loss_type == "field":
22 | criterion = LossF(weights)
23 |
24 | return criterion
25 |
26 |
27 | def build_model(model_type):
28 |
29 | if model_type == "field":
30 | model = Field(num_classes=2)
31 |
32 | return model
33 |
34 |
35 | def train_model(model, inputs, targets, model_type, criterion, optimizer):  # pass inputs explicitly instead of relying on the module-level variable
36 |
37 | if model_type == "field":
38 |
39 | optimizer.zero_grad()
40 |
41 | with torch.set_grad_enabled(True):
42 | outputs = model(inputs)
43 | loss = criterion(
44 | outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2]
45 | )
46 | loss.backward()
47 | optimizer.step()
48 |
49 | return loss
50 |
51 |
52 | if __name__ == "__main__":
53 |
54 | args = create_train_arg_parser().parse_args()
55 | # args.pretrained_model_path = r"E:\zhaohang\data\onesoil_2m_pt\56.pt"
56 | # args.pretrained_model_path = r"E:\zhaohang\DeepL\模型训练pt文件_lmx机子\dikuai_all_ihave_ZH\60.pt"
57 |
58 | args.train_path = r"F:\PingAn\NX\Ningxia\output\train_image"
59 | args.model_type = 'field'
60 | args.save_path = r"F:\PingAn\NX\Ningxia\output\model_pt"
61 |
62 | CUDA_SELECT = "cuda:{}".format(args.cuda_no)
63 | log_path = args.save_path + "/summary"
64 | writer = SummaryWriter(log_dir=log_path)
65 |
66 | logging.basicConfig(
67 |         filename="".format(args.object_type),  # empty filename, so logging falls back to stderr
68 | filemode="a",
69 | format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
70 | datefmt="%Y-%m-%d %H:%M",
71 | level=logging.INFO,
72 | )
73 | logging.info("")
74 |
75 | train_file_names = glob.glob(os.path.join(args.train_path, "*.tif"))
76 | random.shuffle(train_file_names)
77 |
78 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_file_names]
79 | train_file, val_file = train_test_split(img_ids, test_size=0.2, random_state=41) #
80 |
81 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")
82 | print(device)
83 | model = build_model(args.model_type)
84 |
85 |     if torch.cuda.device_count() > 2:  # originally 0
86 | print("Let's use", torch.cuda.device_count(), "GPUs!")
87 | model = nn.DataParallel(model)
88 |
89 | model = model.to(device)
90 | # summary(model, input_size=(3, 256, 256))
91 |
92 | epoch_start = "0"
93 | if args.use_pretrained:
94 | print("Loading Model {}".format(os.path.basename(args.pretrained_model_path)))
95 | model.load_state_dict(torch.load(args.pretrained_model_path))
96 | epoch_start = os.path.basename(args.pretrained_model_path).split(".")[0]
97 | print(epoch_start)
98 | print('train',args.use_pretrained)
99 |
100 | trainLoader = DataLoader(
101 | DatasetImageMaskContourDist(args.train_path,train_file),
102 | batch_size=args.batch_size,drop_last=False, shuffle=True
103 | )
104 | devLoader = DataLoader(
105 | DatasetImageMaskContourDist(args.train_path,val_file),drop_last=False,
106 | )
107 | displayLoader = DataLoader(
108 | DatasetImageMaskContourDist(args.train_path,val_file),
109 | batch_size=args.val_batch_size,drop_last=False, shuffle=True
110 | )
111 |
112 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
113 | # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
114 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(1e10), eta_min=1e-5)  # T_max this large keeps the LR effectively constant
115 | # scheduler = optim.lr_scheduler.StepLR(optimizer, 50, 0.1)
116 | criterion = define_loss(args.model_type)
117 |
118 |
119 | for epoch in tqdm(
120 | range(int(epoch_start) + 1, int(epoch_start) + 1 + args.num_epochs)
121 | ):
122 |
123 | global_step = epoch * len(trainLoader)
124 | running_loss = 0.0
125 |
126 | for i, (img_file_name, inputs, targets1, targets2, targets3) in enumerate(
127 | tqdm(trainLoader)
128 | ):
129 |
130 | model.train()
131 |
132 | inputs = inputs.to(device)
133 | targets1 = targets1.to(device)
134 | targets2 = targets2.to(device)
135 | targets3 = targets3.to(device)
136 |
137 | targets = [targets1, targets2, targets3]
138 |
139 |
140 |             loss = train_model(model, inputs, targets, args.model_type, criterion, optimizer)
141 |
142 | writer.add_scalar("loss", loss.item(), epoch)
143 |
144 | running_loss += loss.item() * inputs.size(0)
145 | scheduler.step()
146 |
147 | epoch_loss = running_loss / len(train_file_names)
148 | print(epoch_loss)
149 |
150 | if epoch % 1 == 0:
151 |
152 | dev_loss, dev_time = evaluate(device, epoch, model, devLoader, writer)
153 | writer.add_scalar("loss_valid", dev_loss, epoch)
154 | visualize(device, epoch, model, displayLoader, writer, args.val_batch_size)
155 | print("Global Loss:{} Val Loss:{}".format(epoch_loss, dev_loss))
156 | else:
157 | print("Global Loss:{} ".format(epoch_loss))
158 |
159 |
160 | logging.info("epoch:{} train_loss:{} ".format(epoch, epoch_loss))
161 | if epoch % 5 == 0:
162 | torch.save(
163 | model.state_dict(), os.path.join(args.save_path, str(epoch) + ".pt")
164 | )
165 |
166 |
167 |
--------------------------------------------------------------------------------
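Checkpoints written by `train.py` carry a `module.` key prefix whenever the model was wrapped in `nn.DataParallel`; `test.py` copes with this by wrapping the model again before calling `load_state_dict`. A hedged alternative (the checkpoint name `80.pt` is purely illustrative) is to strip the prefix so a plain single-GPU model can load the weights:

```python
# Sketch: load a checkpoint saved from a DataParallel-wrapped model into a
# plain Field instance by stripping the "module." prefix from the state_dict keys.
import torch
from models import Field

state = torch.load("80.pt", map_location="cpu")   # illustrative checkpoint path
state = {k[len("module."):] if k.startswith("module.") else k: v
         for k, v in state.items()}

model = Field(num_classes=2)
model.load_state_dict(state)
model.eval()
```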
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torchvision
5 | from torch.nn import functional as F
6 | import time
7 | import argparse
8 | from losses import *
9 |
10 | def evaluate(device, epoch, model, data_loader, writer):
11 | model.eval()
12 | losses = []
13 | start = time.perf_counter()
14 | with torch.no_grad():
15 |
16 | for iter, data in enumerate(tqdm(data_loader)):
17 |
18 | _, inputs, targets, _, _ = data
19 | # _, inputs, targets, _ = data
20 | inputs = inputs.to(device)
21 | targets = targets.to(device)
22 | outputs = model(inputs)
23 | crition = BCEDiceLoss()
24 | loss = crition(outputs[0],targets)
25 | # loss = F.nll_loss(outputs[0], targets.squeeze(1))
26 | losses.append(loss.item())
27 |
28 | writer.add_scalar("Dev_Loss", np.mean(losses), epoch)
29 |
30 | return np.mean(losses), time.perf_counter() - start
31 |
32 |
33 | def visualize(device, epoch, model, data_loader, writer, val_batch_size, train=True):
34 | def save_image(image, tag, val_batch_size):
35 | image -= image.min()
36 | image /= image.max()
37 | grid = torchvision.utils.make_grid(
38 | image, nrow=int(np.sqrt(val_batch_size)), pad_value=0, padding=25
39 | )
40 | writer.add_image(tag, grid, epoch)
41 |
42 | model.eval()
43 | with torch.no_grad():
44 | for iter, data in enumerate(tqdm(data_loader)):
45 | _, inputs, targets, _, _ = data
46 | # _, inputs, targets, _ = data
47 |
48 | inputs = inputs.to(device)
49 |
50 | targets = targets.to(device)
51 | outputs = model(inputs)
52 |
53 | output_mask = outputs[0].detach().cpu().numpy()
54 | output_mask[output_mask>0.5]= 255
55 | output_mask[output_mask <=0.5] = 0
56 | # output_final = np.argmax(output_mask, axis=1).astype(float)
57 | output_final = torch.from_numpy(output_mask)
58 |
59 |             if train:  # train is a bool; comparing it to the string "True" was always False
60 | save_image(targets.float(), "Target_train",val_batch_size)
61 | save_image(output_final, "Prediction_train",val_batch_size)
62 | else:
63 | save_image(targets.float(), "Target", val_batch_size)
64 | save_image(output_final, "Prediction", val_batch_size)
65 |
66 | break
67 |
68 |
69 | def create_train_arg_parser():
70 |
71 | parser = argparse.ArgumentParser(description="train setup for segmentation")
72 | parser.add_argument("--train_path", type=str, help="path to img tif files")
73 | parser.add_argument("--val_path", type=str, help="path to img tif files")
74 | parser.add_argument(
75 | "--model_type",
76 | type=str,
77 |         help="select model type: field",
78 | )
79 | parser.add_argument("--object_type", type=str, help="Dataset.")
80 | parser.add_argument(
81 | "--distance_type",
82 | type=str,
83 | default="dist_contour",
84 | help="select distance transform type - dist_mask,dist_contour,dist_contour_tif",
85 | )
86 | parser.add_argument("--batch_size", type=int, default=8, help="train batch size")
87 | parser.add_argument(
88 | "--val_batch_size", type=int, default=8, help="validation batch size"
89 | )
90 | parser.add_argument("--num_epochs", type=int, default=100, help="number of epochs")
91 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number")
92 | parser.add_argument(
93 | "--use_pretrained", type=bool, default=False, help="Load pretrained checkpoint."
94 | )
95 | parser.add_argument(
96 | "--pretrained_model_path",
97 | type=str,
98 | default=None,
99 | help="If use_pretrained is true, provide checkpoint.",
100 | )
101 | parser.add_argument("--save_path", type=str, help="Model save path.")
102 |
103 | return parser
104 |
105 |
106 | def create_validation_arg_parser():
107 |
108 |     parser = argparse.ArgumentParser(description="test setup for segmentation")
109 | parser.add_argument(
110 | "--model_type",
111 | type=str,
112 |         help="select model type: field",
113 | )
114 | parser.add_argument("--test_path", type=str, help="path to img tif files")
115 | parser.add_argument("--model_file", type=str, help="model_file")
116 | parser.add_argument("--save_path", type=str, help="results save path.")
117 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number")
118 |
119 | return parser
120 |
121 |
--------------------------------------------------------------------------------
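As a usage note for the parsers above, here is a small illustration of building the training arguments programmatically (the flag names come from `create_train_arg_parser`; the paths are hypothetical, and `train.py` currently overrides `train_path`/`save_path` in code, so command-line values for them only take effect once those hard-coded assignments are removed):

```python
# Illustration only: construct the training argument namespace with the flags
# defined in utils.create_train_arg_parser().
from utils import create_train_arg_parser

args = create_train_arg_parser().parse_args([
    "--model_type", "field",
    "--train_path", "./data/train_image",   # hypothetical paths
    "--save_path", "./checkpoints",
    "--batch_size", "8",
    "--num_epochs", "100",
])
print(args.model_type, args.batch_size, args.num_epochs)
```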