├── Data_Enhance.py
├── README.md
├── data_remove.py
├── dataset.py
├── gdalUtil.py
├── imgs
│   ├── FHAPD.jpg
│   └── HBGNet.jpg
├── lib
│   └── pvtv2.py
├── losses.py
├── models.py
├── test.py
├── train.py
└── utils.py

/Data_Enhance.py:
--------------------------------------------------------------------------------
from PIL import Image
import os
import glob
from PIL import ImageEnhance

img_path = r'D:\LJ2\SBA2\SD_train\926_new_train\2'  # folder holding the input images; augmented copies are written to the same folder


def get_image_paths(folder):
    return glob.glob(os.path.join(folder, '*.tif'))


def create_read_img(filename):
    # Read the image
    im = Image.open(filename)

    out_h = im.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
    out_w = im.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
    out_90 = im.transpose(Image.ROTATE_90)  # rotate 90 degrees counter-clockwise
    # out_180 = im.transpose(Image.ROTATE_180)
    # out_270 = im.transpose(Image.ROTATE_270)

    # Brightness enhancement
    enh_bri = ImageEnhance.Brightness(im)
    brightness = 1.5
    image_brightened = enh_bri.enhance(brightness)
    image_brightened.save(filename[:-4] + '_brighter.tif')

    # # Color (saturation) enhancement
    # enh_col = ImageEnhance.Color(im)
    # color = 1.5
    # image_colored = enh_col.enhance(color)
    # image_colored.save(filename[:-4] + '_color.tif')

    # Contrast enhancement
    enh_con = ImageEnhance.Contrast(im)
    contrast = 1.5
    image_contrasted = enh_con.enhance(contrast)
    image_contrasted.save(filename[:-4] + '_contrast.tif')

    # Sharpness enhancement
    # enh_sha = ImageEnhance.Sharpness(im)
    # sharpness = 3.0
    # image_sharped = enh_sha.enhance(sharpness)
    # image_sharped.save(filename[:-4] + '_sharp.tif')

    # Save the geometric augmentations
    out_h.save(filename[:-4] + '_h.tif')
    out_w.save(filename[:-4] + '_w.tif')
    out_90.save(filename[:-4] + '_90.tif')
    # out_180.save(filename[:-4] + '_180.tif')
    # out_270.save(filename[:-4] + '_270.tif')

    # print(filename)


if __name__ == '__main__':
    # The file list is collected before augmenting, so files written during this run are not augmented again
    imgs = get_image_paths(img_path)
    for i in imgs:
        create_read_img(i)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HBGNet

A large-scale VHR parcel dataset and a novel hierarchical semantic boundary-guided network for agricultural parcel delineation (https://www.sciencedirect.com/science/article/pii/S0924271625000395)

[Project](https://github.com/NanNanmei/HBGNet)

## Introduction

We develop a hierarchical semantic boundary-guided network (HBGNet) to fully leverage boundary semantics and thereby improve agricultural parcel (AP) delineation. It integrates two branches: a core branch for AP feature extraction and an auxiliary branch for boundary feature mining. Specifically, the boundary extraction branch employs a module based on the Laplace convolution operator to enhance the model's awareness of parcel boundaries.
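
As a rough illustration of the idea behind a Laplace-based boundary module (this is not the HBGNet module itself), the sketch below assumes a fixed 3×3 Laplacian kernel applied depthwise to a PyTorch feature map with `torch.nn.functional.conv2d`; `laplacian_edges` is a hypothetical helper name. Because the Laplacian responds to second-order intensity changes, a binary parcel mask gives non-zero values only near its boundary.

```python
import torch
import torch.nn.functional as F


def laplacian_edges(x: torch.Tensor) -> torch.Tensor:
    """Apply a fixed 3x3 Laplacian kernel to every channel of x, shaped (B, C, H, W).

    Hypothetical helper for illustration only; not the HBGNet boundary module.
    """
    kernel = torch.tensor(
        [[0.0, 1.0, 0.0],
         [1.0, -4.0, 1.0],
         [0.0, 1.0, 0.0]],
        dtype=x.dtype,
        device=x.device,
    )
    channels = x.shape[1]
    # One copy of the kernel per channel, used as a depthwise convolution (groups=C).
    weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    # padding=1 keeps H and W unchanged; values outside the image are treated as zero.
    return F.conv2d(x, weight, padding=1, groups=channels)


if __name__ == "__main__":
    # A binary "parcel" mask: a filled 4x4 square inside an empty 8x8 tile.
    mask = torch.zeros(1, 1, 8, 8)
    mask[:, :, 2:6, 2:6] = 1.0
    edges = laplacian_edges(mask)
    # Prints 1 where the response is non-zero, i.e. along the square's border.
    print((edges != 0).int()[0, 0])
```

Interior pixels of the square and pixels far from it return zero; only the rings just inside and just outside its border respond.
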

## Using the code:

The code runs stably with Python 3.9.0 and CUDA >= 11.4.

- Clone this repository:
```bash
git clone https://github.com/NanNanmei/HBGNet.git
cd HBGNet
```

Install the dependencies with conda or pip:

```
PyTorch
OpenCV
numpy
tqdm
timm
...
```

## Preprocessing
You can use https://github.com/long123524/BsiNet-torch/blob/main/preprocess.py to obtain the contour and distance maps.

## Data Format

Make sure the files are organized in the following structure:

```
inputs
└── 
    ├── train_image
    │   ├── 001.tif
    │   ├── 002.tif
    │   ├── 003.tif
    │   └── ...
    ├── train_mask
    │   ├── 001.tif
    │   ├── 002.tif
    │   ├── 003.tif
    │   └── ...
    ├── train_boundary
    │   ├── 001.tif
    │   ├── 002.tif
    │   ├── 003.tif
    │   └── ...
    └── train_dist
        ├── 001.tif
        ├── 002.tif
        ├── 003.tif
        └── ...
```

Test datasets use the same structure as above.

## Pretrained weight

The PVT-v2 weights pretrained on the ImageNet dataset can be downloaded from: https://drive.google.com/file/d/1uzeVfA4gEQ772vzLntnkqvWePSw84F6y/view?usp=sharing

### A large-scale VHR parcel dataset
Our vision is for FHAPD to become a continuously updated dataset covering different agricultural landscapes in China. So far, I have carried out AP delineation with HBGNet in many regions of China, such as Shanghai, Inner Mongolia, Guangdong, Shandong and Gansu, and will add these data immediately after the official publication of this article.
I have also built an AP dataset for Africa and plan to add it to FHAPD as a distinct region.

Link: https://pan.baidu.com/s/1OS7G0H27zGexxRfqTcKizw?pwd=8die (extraction code: 8die)

If you have any problems, please email zhaohang201@mails.ucas.ac.cn.

## Training and testing

1. Train the model.
```
python train.py
```
2. Test the model.
```
python test.py
```
### Citation:
If you find this work useful or interesting, please consider citing the following references.
```
Citation 1:
@article{zhao2025,
  title={A large-scale VHR parcel dataset and a novel hierarchical semantic boundary-guided network for agricultural parcel delineation},
  author={Zhao, Hang and Wu, Bingfang and Zhang, Miao and Long, Jiang and Tian, Fuyou and Xie, Yan and Zeng, Hongwei and Zheng, Zhaoju and Ma, Zonghan and Wang, Mingxing and others},
  journal={ISPRS Journal of Photogrammetry and Remote Sensing},
  volume={221},
  pages={1--19},
  year={2025},
  publisher={Elsevier}
}
Citation 2:
@article{zhao2024,
  title={Irregular agricultural field delineation using a dual-branch architecture from high-resolution remote sensing images},
  author={Zhao, Hang and Long, Jiang and Zhang, Miao and Wu, Bingfang and Xu, Chenxi and Tian, Fuyou and Ma, Zonghan},
  journal={IEEE Geoscience and Remote Sensing Letters},
  volume={21},
  pages={1--5},
  year={2024},
  publisher={IEEE}
}
Citation 3:
@article{long2024,
  title={Integrating Segment Anything Model derived boundary prior and high-level semantics for cropland extraction from high-resolution remote sensing images},
  author={Long, Jiang and Zhao, Hang and Li, Mengmeng and Wang, Xiaoqin and Lu, Chengwen},
  journal={IEEE Geoscience and Remote Sensing Letters},
  volume={21},
  pages={1--5},
  year={2024},
  publisher={IEEE}
}
```

### Acknowledgement
We are very grateful for the excellent works [BsiNet](https://github.com/long123524/BsiNet-torch), [SEANet](https://github.com/long123524/SEANet_torch) and [HGINet](https://github.com/long123524/HGINet-torch), which provided the basis for our framework.

--------------------------------------------------------------------------------
/data_remove.py:
--------------------------------------------------------------------------------
import cv2
import os
import numpy as np
import glob

imaPath = r'E:\Change_detection\HRSCD\images_2012\2012\D14'  # folder with the images to process (D35)
labelPath = r'E:\Change_detection\HRSCD\labels_land_cover_2006\2006\caijian'  # folder with the label images to process (D35)
s_g = 0.3  # ratio threshold in [0, 1]; label tiles whose foreground fraction is at or below it are deleted

# image1 = []
# imageList1 = glob.glob(os.path.join(imaPath, '*.tif'))
# for item in imageList1:
#     image1.append(os.path.basename(item))
# image2 = []
# imageList2 = glob.glob(os.path.join(labelPath, '*.tif'))
#
# for item in imageList2:
#     image2.append(os.path.basename(item))

lablist = os.listdir(labelPath)
# imaList = os.listdir(imaPath)
for labels in lablist:

    label_path = os.path.join(labelPath, labels)

    img = cv2.imread(label_path, 0)  # read as single-channel grayscale
    if img is None:
        continue  # skip files OpenCV cannot read (e.g. sidecar files)
    s = img.size

    # Binarize first to remove image noise (any non-zero pixel becomes 255)
    ret, img = cv2.threshold(img, 0.5, 255, cv2.THRESH_BINARY)
    yuzhi = s_g  # ratio threshold
    yuzhi = s * yuzhi  # threshold as a pixel count
    yuzhi = np.array(yuzhi, dtype='uint64')  # convert to an 8-byte unsigned integer
    # Count foreground (255) pixels
    area = int(np.count_nonzero(img == 255))
    if area <= yuzhi:
        os.remove(label_path)

# Names of the label tiles that were kept
image1 = []
imageList1 = glob.glob(os.path.join(labelPath, '*.tif'))
for item in imageList1:
    image1.append(os.path.basename(item))

# Names of all image tiles
image2 = []
imageList2 = glob.glob(os.path.join(imaPath, '*.tif'))
for item in imageList2:
    image2.append(os.path.basename(item))

# Delete the images whose label tile no longer exists
a = list(set(image2).difference(set(image1)))
print(a)
for x in a:
    os.remove(os.path.join(imaPath, x))

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
"""
This file handles data loading.
"dist_mask" is obtained by applying a Euclidean distance transform to the mask.
"dist_contour" is obtained by applying a quasi-Euclidean distance transform to the mask.
"""

import torch
import numpy as np
import cv2  # only used by the commented-out loaders below
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms
import os
from osgeo import gdal
import tifffile as tiff  # only used by the commented-out loader below


### Reading and saving of remote sensing images (keeps the georeferencing information)
def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    dataset = gdal.Open(fileName)
    if dataset is None:
        raise IOError(fileName + " cannot be opened")
    # Number of raster columns
    width = dataset.RasterXSize
    # Number of raster rows
    height = dataset.RasterYSize
    # Number of bands
    bands = dataset.RasterCount
    # Read the data (the whole raster by default)
    if data_width == 0 and data_height == 0:
        data_width = width
        data_height = height
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
    # Affine geotransform
    geotrans = dataset.GetGeoTransform()
    # Projection
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj


# Save a remote sensing image
def writeTiff(im_data, im_geotrans, im_proj, path):
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:  # uint8 / int8
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':  # keep the sign instead of writing UInt16
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    # Create the output file
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise IOError("Could not create " + path)
    dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
    dataset.SetProjection(im_proj)  # write the projection
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset


#######
# class DatasetImageMaskContourDist(Dataset):
#
#     def __init__(self, file_names):
#
#         self.file_names = file_names
#         # self.distance_type = distance_type
#         # self.dir = dir
#
#     def __len__(self):
#
#         return len(self.file_names)
#
#     def __getitem__(self, idx):
#
#         img_file_name = self.file_names[idx]
#         image = load_image(img_file_name)
#         mask = load_mask(img_file_name)
#         contour = load_contour(img_file_name)
#         # dist = load_distance(os.path.join(self.dir,img_file_name+'.tif'), self.distance_type)
#
#         return img_file_name, image, mask, contour

### Use this class for training
class DatasetImageMaskContourDist(Dataset):

    def __init__(self, dir, file_names):
        # dir: folder of the training images; the mask, boundary and distance paths
        #      are derived by replacing "train_image" in the image path
        # file_names: image names without the ".tif" extension
        self.file_names = file_names
        # self.distance_type = distance_type
        self.dir = dir

    def __len__(self):

        return len(self.file_names)

    def __getitem__(self, idx):

        img_file_name = self.file_names[idx]
        img_path = os.path.join(self.dir, img_file_name + '.tif')
        image = load_image(img_path)
        mask = load_mask(img_path)
        contour = load_contour(img_path)
        # dist = load_distance(img_path, self.distance_type)
        dist = load_distance(img_path)

        return img_file_name, image, mask, contour, dist


def load_image(path):

    img = Image.open(path)
    data_transforms = transforms.Compose(
        [
            # transforms.Resize(256),
            transforms.ToTensor(),
            # Standard ImageNet mean/std normalization
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    img = data_transforms(img)

    return img
#
#
# def load_mask(path):
#     # mask = cv2.imread(path.replace("image", "mask").replace("tif", "tif"), 0)
#     mask = cv2.imread(path.replace("train_image", "train_mask").replace("tif", "tif"), 0)
#     # im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(path.replace("image", "mask").replace("tif", "tif"))
#     mask = mask/255.
#     # mask[mask == 255] = 1
#     # mask[mask == 0] = 0
#
#     return torch.from_numpy(np.expand_dims(mask, 0)).float()

# def load_image(path):
#     img = tiff.imread(path).transpose([2, 0, 1])  # array([3, 512, 512])
#     img = img.astype('uint8')
#     img = (img - img.min())/(img.max()-img.min())
#
#     return torch.from_numpy(img).float()


def load_mask(path):
    # mask = cv2.imread(path.replace("train_images", "train_labels").replace("tif", "tif"), 0)
    # im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(path.replace("train_image", "train_mask").replace("image", "label"))
    im_width, im_height, im_bands, mask, im_geotrans, im_proj = readTif(
        path.replace("train_image", "train_mask"))
    mask = mask.astype('uint8')
    mask = mask / 255.  # masks are expected to be stored as 0/255
    # mask = np.reshape(mask, mask.shape + (1,))
    ###mask = mask/225.
    # mask[mask == 1] = 1
    # mask[mask == 0] = 0
    # print(mask.shape())
    # return torch.from_numpy(mask).long()

    return torch.from_numpy(np.expand_dims(mask, 0)).float()


def load_contour(path):

    # contour = cv2.imread(path.replace("train_image", "train_boundary").replace("tif", "tif"), 0)
    # contour = contour.astype('uint8')
    # im_width, im_height, im_bands, contour, im_geotrans, im_proj = readTif(path.replace("train_image", "train_boundary").replace("image", "label"))
    im_width, im_height, im_bands, contour, im_geotrans, im_proj = readTif(
        path.replace("train_image", "train_boundary"))
    contour = contour.astype('uint8')
    contour = contour / 255.
    # contour[contour ==255] = 1
    # contour[contour == 0] = 0

    return torch.from_numpy(np.expand_dims(contour, 0)).long()


def load_distance(path):
    # im_width, im_height, im_bands, dist, im_geotrans, im_proj = readTif(path.replace("train_image", "train_dist").replace("image", "label"))
    im_width, im_height, im_bands, dist, im_geotrans, im_proj = readTif(
        path.replace("train_image", "train_dist"))
    dist = dist.astype('uint8')
    dist = dist / 255.
    # dist = np.reshape(dist, dist.shape+(1,))
    return torch.from_numpy(np.expand_dims(dist, 0)).float()

--------------------------------------------------------------------------------
/gdalUtil.py:
--------------------------------------------------------------------------------
import os, sys, time
import numpy as np
from osgeo import ogr, gdal, gdalconst
from osgeo import gdal_array as ga


def del_file(path):
    """Recursively delete every file under path (the directories themselves are kept)."""
    for i in os.listdir(path):
        path_file = os.path.join(path, i)
        if os.path.isfile(path_file):
            os.remove(path_file)
        else:
            del_file(path_file)


def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """
    Percentile-based linear stretch of the data to [img_min, img_max].
    :param bands: input data as a numpy array
    :param img_min: minimum value of the target bit depth (for 8 bit: 0)
    :param img_max: maximum value of the target bit depth (for 8 bit: 255)
    :param lower_percent: percentile mapped to img_min
    :param higher_percent: percentile mapped to img_max
    :return: the stretched data as float32
    """
    out = np.zeros_like(bands).astype(np.float32)
    a = img_min
    b = img_max
    c = np.percentile(bands[:, :], lower_percent)
    d = np.percentile(bands[:, :], higher_percent)
    t = a + (bands[:, :] - c) * (b - a) / (d - c)
    t[t < a] = a
    t[t > b] = b
    out[:, :] = t
    return out


def read_img(filename):
    dataset = gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:  # uint8 / int8
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':  # keep the sign instead of writing UInt16
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype, options=['COMPRESS=LZW'])

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    del dataset


def image_resampling(source_file, target_file, scale=5.):
    """
    image resampling
    :param source_file: the path of source file
    :param target_file: the path of target file
    :param scale: pixel scaling factor (scale > 1 gives more, smaller pixels)
    :return: None
    """
    dataset = gdal.Open(source_file, gdalconst.GA_ReadOnly)
    band_count = dataset.RasterCount  # number of bands

    if band_count == 0 or not scale > 0:
        print("Invalid parameters")
        return

    cols = dataset.RasterXSize  # number of columns
    rows = dataset.RasterYSize  # number of rows
    cols = int(cols * scale)  # compute the new numbers of columns and rows
    rows = int(rows * scale)

    geotrans = list(dataset.GetGeoTransform())
    print(dataset.GetGeoTransform())
    print(geotrans)
    geotrans[1] = geotrans[1] / scale  # pixel width becomes 1/scale of the original
    geotrans[5] = geotrans[5] / scale  # pixel height becomes 1/scale of the original
    print(geotrans)

    if os.path.exists(target_file) and os.path.isfile(target_file):  # if an image with the same name already exists
        os.remove(target_file)  # delete it

    band1 = dataset.GetRasterBand(1)
    data_type = band1.DataType
    target = dataset.GetDriver().Create(target_file, xsize=cols, ysize=rows, bands=band_count,
                                        eType=data_type)
    target.SetProjection(dataset.GetProjection())  # set the projection
    target.SetGeoTransform(geotrans)  # set the geotransform
    total = band_count + 1
    for index in range(1, total):
        # Read the band data, resampled to the new size
        print("Writing band " + str(index))
        data = dataset.GetRasterBand(index).ReadAsArray(buf_xsize=cols, buf_ysize=rows)
        out_band = target.GetRasterBand(index)
        # out_band.SetNoDataValue(dataset.GetRasterBand(index).GetNoDataValue())
        out_band.WriteArray(data)  # write the data into the new image
        out_band.FlushCache()
        out_band.ComputeBandStats(False)  # compute statistics
    print("Writing finished")
    del dataset
    del target  # close the output so it is flushed to disk


def sample_clip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
    """
    Generate image slices (and footprint shapefiles) centred on the sampling points
    :param shp: the path of the point shapefile
    :param tif: the path of the image
    :param outputdir: the directory of output
    :param sampletype: line or polygon
    :param size: the size of the image slices in pixels
    :param fieldName: the name of the attribute field
    :param n: the start number (defaults to 1)
    :return: the next unused slice number
    """
    # time.clock() was removed in Python 3.8; perf_counter() is the replacement for timing
    time1 = time.perf_counter()

    gdal.AllRegister()
    lc = gdal.Open(tif)
    im_width = lc.RasterXSize
    im_height = lc.RasterYSize
    im_geotrans = lc.GetGeoTransform()
    bandscount = lc.RasterCount
    im_proj = lc.GetProjection()
    print(im_width, im_height)
    gdal.AllRegister()
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")

    driver = ogr.GetDriverByName('ESRI Shapefile')
    dsshp = driver.Open(shp, 0)
    if dsshp is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = dsshp.GetLayer()
    xValues = []
    yValues = []
    m = layer.GetFeatureCount()
    feature = layer.GetNextFeature()
    print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(bandscount, m, sampletype,
                                                                                     int(size)))

    if n is None:
        n = 1
    while feature:
        # Zero-padded folder name, e.g. 00000001
        if n < 10:
            dirname = "0000000" + str(n)
        elif n < 100:
            dirname = "000000" + str(n)
        elif n < 1000:
            dirname = "00000" + str(n)
        else:
            dirname = "0000" + str(n)

        # print dirname
        dirpath = os.path.join(outputdir, dirname + "_V1")
        if not os.path.exists(dirpath):
            os.mkdir(dirpath)
        tifname = dirname + ".tif"
        if "poly" in sampletype or "POLY" in sampletype:
            shpname = dirname + "_V1_POLY.shp"
        if "line" in sampletype or "LINE" in sampletype:
            shpname = dirname + "_V1_LINE.shp"
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        print(x, y)
        print(im_geotrans)
        xValues.append(x)
        yValues.append(y)
        newform = list(im_geotrans)
        # print newform
        # Geotransform of the slice: origin moved half a slice up-left of the sample point
        newform[0] = x - im_geotrans[1] * int(size) / 2.0
        newform[3] = y - im_geotrans[5] * int(size) / 2.0
        print(newform[0], newform[3])
        newformtuple = tuple(newform)
        # Corner coordinates of the slice footprint
        x1 = x - int(size) / 2 * im_geotrans[1]
        y1 = y - int(size) / 2 * im_geotrans[5]
        x2 = x + int(size) / 2 * im_geotrans[1]
        y2 = y - int(size) / 2 * im_geotrans[5]
        x3 = x - int(size) / 2 * im_geotrans[1]
        y3 = y + int(size) / 2 * im_geotrans[5]
        x4 = x + int(size) / 2 * im_geotrans[1]
        y4 = y + int(size) / 2 * im_geotrans[5]
        # Pixel offsets of the slice's upper-left corner in the source image
        Xpix = (x1 - im_geotrans[0]) / im_geotrans[1]
        # Xpix=(newform[0]-im_geotrans[0])

        Ypix = (newform[3] - im_geotrans[3]) / im_geotrans[5]
        # Ypix=abs(newform[3]-im_geotrans[3])
        print("#################")
        print(Xpix, Ypix)

        # **************create tif**********************
        # print"start creating {0}".format(tifname)
        pBuf = lc.ReadAsArray(int(Xpix), int(Ypix), int(size), int(size))
        # print pBuf.dtype.name
        driver = gdal.GetDriverByName("GTiff")
        create_option = []
        if 'int8' in pBuf.dtype.name:  # uint8 / int8
            datatype = gdal.GDT_Byte
        elif pBuf.dtype.name == 'int16':  # keep the sign instead of writing UInt16
            datatype = gdal.GDT_Int16
        elif pBuf.dtype.name == 'uint16':
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        outtif = os.path.join(dirpath, tifname)
        ds = driver.Create(outtif, int(size), int(size), int(bandscount), datatype, options=create_option)
        if ds is None:
            raise RuntimeError("Could not create " + outtif)
        ds.SetProjection(im_proj)
        ds.SetGeoTransform(newformtuple)
        ds.FlushCache()
        if bandscount > 1:
            for i in range(int(bandscount)):
                outBand = ds.GetRasterBand(i + 1)
                outBand.WriteArray(pBuf[i])
        else:
            outBand = ds.GetRasterBand(1)
            outBand.WriteArray(pBuf)
        ds.FlushCache()
        # print "creating {0} successfully".format(tifname)
        # **************create shp**********************
        # print"start creating shps"
        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
        gdal.SetConfigOption("SHAPE_ENCODING", "")
        strVectorFile = os.path.join(dirpath, shpname)
        ogr.RegisterAll()
        driver = ogr.GetDriverByName('ESRI Shapefile')
        ds = driver.Open(shp)
        layer0 = ds.GetLayerByIndex(0)
        prosrs = layer0.GetSpatialRef()
        # geosrs = osr.SpatialReference()

        oDriver = ogr.GetDriverByName("ESRI Shapefile")
        if oDriver is None:
            print("ESRI Shapefile driver is not available")
            return

        oDS = oDriver.CreateDataSource(strVectorFile)
        if oDS is None:
            print("Could not create " + strVectorFile)
            return

        papszLCO = []
        if "line" in sampletype or "LINE" in sampletype:
            oLayer = oDS.CreateLayer("TestPolygon", prosrs, ogr.wkbLineString, papszLCO)
        if "poly" in sampletype or "POLY" in sampletype:
            oLayer = oDS.CreateLayer("TestPolygon", prosrs, ogr.wkbPolygon, papszLCO)
        if oLayer is None:
            print("Could not create the output layer")
            return

        oFieldName = ogr.FieldDefn(fieldName, ogr.OFTString)
        oFieldName.SetWidth(50)
        oLayer.CreateField(oFieldName, 1)
        oDefn = oLayer.GetLayerDefn()
        oFeatureRectangle = ogr.Feature(oDefn)

        # Footprint polygon through corners 1 -> 2 -> 4 -> 3 -> 1
        geomRectangle = ogr.CreateGeometryFromWkt(
            "POLYGON (({0} {1},{2} {3},{4} {5},{6} {7},{0} {1}))".format(x1, y1, x2, y2, x4, y4, x3, y3))
        oFeatureRectangle.SetGeometry(geomRectangle)
        oLayer.CreateFeature(oFeatureRectangle)
        print("{0} ok".format(dirname))
        n = n + 1
        feature = layer.GetNextFeature()
    time2 = time.perf_counter()
    print('Process Running time: %s min' % ((time2 - time1) / 60))

    return n

--------------------------------------------------------------------------------
/imgs/FHAPD.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NanNanmei/HBGNet/b3480a4b95967d62a1a969f1e27b24fdf0a53126/imgs/FHAPD.jpg
--------------------------------------------------------------------------------
/imgs/HBGNet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NanNanmei/HBGNet/b3480a4b95967d62a1a969f1e27b24fdf0a53126/imgs/HBGNet.jpg
--------------------------------------------------------------------------------
/lib/pvtv2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg

import math


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Keys/values are computed on a feature map downsampled by sr_ratio
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
| nn.init.constant_(m.weight, 1.0) 181 | elif isinstance(m, nn.Conv2d): 182 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 183 | fan_out //= m.groups 184 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 185 | if m.bias is not None: 186 | m.bias.data.zero_() 187 | 188 | def forward(self, x): 189 | x = self.proj(x) 190 | _, _, H, W = x.shape 191 | x = x.flatten(2).transpose(1, 2) 192 | x = self.norm(x) 193 | 194 | return x, H, W 195 | 196 | 197 | class PyramidVisionTransformerImpr(nn.Module): 198 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 199 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 200 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 201 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 202 | super().__init__() 203 | self.num_classes = num_classes 204 | self.depths = depths 205 | 206 | # patch_embed 207 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 208 | embed_dim=embed_dims[0]) 209 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 210 | embed_dim=embed_dims[1]) 211 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 212 | embed_dim=embed_dims[2]) 213 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 214 | embed_dim=embed_dims[3]) 215 | 216 | # transformer encoder 217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 218 | cur = 0 219 | self.block1 = nn.ModuleList([Block( 220 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 222 | sr_ratio=sr_ratios[0]) 223 | for i in 
range(depths[0])]) 224 | self.norm1 = norm_layer(embed_dims[0]) 225 | 226 | cur += depths[0] 227 | self.block2 = nn.ModuleList([Block( 228 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 229 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 230 | sr_ratio=sr_ratios[1]) 231 | for i in range(depths[1])]) 232 | self.norm2 = norm_layer(embed_dims[1]) 233 | 234 | cur += depths[1] 235 | self.block3 = nn.ModuleList([Block( 236 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 237 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 238 | sr_ratio=sr_ratios[2]) 239 | for i in range(depths[2])]) 240 | self.norm3 = norm_layer(embed_dims[2]) 241 | 242 | cur += depths[2] 243 | self.block4 = nn.ModuleList([Block( 244 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 246 | sr_ratio=sr_ratios[3]) 247 | for i in range(depths[3])]) 248 | self.norm4 = norm_layer(embed_dims[3]) 249 | 250 | # classification head 251 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 252 | 253 | self.apply(self._init_weights) 254 | 255 | def _init_weights(self, m): 256 | if isinstance(m, nn.Linear): 257 | trunc_normal_(m.weight, std=.02) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | elif isinstance(m, nn.Conv2d): 264 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 265 | fan_out //= m.groups 266 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 267 | if m.bias is not None: 268 | m.bias.data.zero_() 269 | 270 | def init_weights(self, 
pretrained=None): 271 | if isinstance(pretrained, str): 272 | logger = 1 273 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 274 | 275 | def reset_drop_path(self, drop_path_rate): 276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 277 | cur = 0 278 | for i in range(self.depths[0]): 279 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 280 | 281 | cur += self.depths[0] 282 | for i in range(self.depths[1]): 283 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 284 | 285 | cur += self.depths[1] 286 | for i in range(self.depths[2]): 287 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 288 | 289 | cur += self.depths[2] 290 | for i in range(self.depths[3]): 291 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 292 | 293 | def freeze_patch_emb(self): 294 | self.patch_embed1.requires_grad = False 295 | 296 | @torch.jit.ignore 297 | def no_weight_decay(self): 298 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 299 | 300 | def get_classifier(self): 301 | return self.head 302 | 303 | def reset_classifier(self, num_classes, global_pool=''): 304 | self.num_classes = num_classes 305 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 306 | 307 | # def _get_pos_embed(self, pos_embed, patch_embed, H, W): 308 | # if H * W == self.patch_embed1.num_patches: 309 | # return pos_embed 310 | # else: 311 | # return F.interpolate( 312 | # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 313 | # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 314 | 315 | def forward_features(self, x): 316 | B = x.shape[0] 317 | outs = [] 318 | 319 | # stage 1 320 | x, H, W = self.patch_embed1(x) 321 | for i, blk in enumerate(self.block1): 322 | x = blk(x, H, W) 323 | x = self.norm1(x) 324 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 325 | outs.append(x) 326 | 327 | # 
stage 2 328 | x, H, W = self.patch_embed2(x) 329 | for i, blk in enumerate(self.block2): 330 | x = blk(x, H, W) 331 | x = self.norm2(x) 332 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 333 | outs.append(x) 334 | 335 | # stage 3 336 | x, H, W = self.patch_embed3(x) 337 | for i, blk in enumerate(self.block3): 338 | x = blk(x, H, W) 339 | x = self.norm3(x) 340 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 341 | outs.append(x) 342 | 343 | # stage 4 344 | x, H, W = self.patch_embed4(x) 345 | for i, blk in enumerate(self.block4): 346 | x = blk(x, H, W) 347 | x = self.norm4(x) 348 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 349 | outs.append(x) 350 | 351 | return outs 352 | 353 | # return x.mean(dim=1) 354 | 355 | def forward(self, x): 356 | x = self.forward_features(x) 357 | # x = self.head(x) 358 | 359 | return x 360 | 361 | 362 | class DWConv(nn.Module): 363 | def __init__(self, dim=768): 364 | super(DWConv, self).__init__() 365 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 366 | 367 | def forward(self, x, H, W): 368 | B, N, C = x.shape 369 | x = x.transpose(1, 2).view(B, C, H, W) 370 | x = self.dwconv(x) 371 | x = x.flatten(2).transpose(1, 2) 372 | 373 | return x 374 | 375 | 376 | def _conv_filter(state_dict, patch_size=16): 377 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 378 | out_dict = {} 379 | for k, v in state_dict.items(): 380 | if 'patch_embed.proj.weight' in k: 381 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 382 | out_dict[k] = v 383 | 384 | return out_dict 385 | 386 | 387 | @register_model 388 | class pvt_v2_b0(PyramidVisionTransformerImpr): 389 | def __init__(self, **kwargs): 390 | super(pvt_v2_b0, self).__init__( 391 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 392 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 393 | 
drop_rate=0.0, drop_path_rate=0.1) 394 | 395 | 396 | 397 | @register_model 398 | class pvt_v2_b1(PyramidVisionTransformerImpr): 399 | def __init__(self, **kwargs): 400 | super(pvt_v2_b1, self).__init__( 401 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 402 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 403 | drop_rate=0.0, drop_path_rate=0.1) 404 | 405 | @register_model 406 | class pvt_v2_b2(PyramidVisionTransformerImpr): 407 | def __init__(self, **kwargs): 408 | super(pvt_v2_b2, self).__init__( 409 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 410 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 411 | drop_rate=0.0, drop_path_rate=0.1) 412 | 413 | @register_model 414 | class pvt_v2_b3(PyramidVisionTransformerImpr): 415 | def __init__(self, **kwargs): 416 | super(pvt_v2_b3, self).__init__( 417 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 418 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 419 | drop_rate=0.0, drop_path_rate=0.1) 420 | 421 | @register_model 422 | class pvt_v2_b4(PyramidVisionTransformerImpr): 423 | def __init__(self, **kwargs): 424 | super(pvt_v2_b4, self).__init__( 425 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 426 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 427 | drop_rate=0.0, drop_path_rate=0.1) 428 | 429 | 430 | @register_model 431 | class pvt_v2_b5(PyramidVisionTransformerImpr): 432 | def __init__(self, **kwargs): 433 | super(pvt_v2_b5, self).__init__( 434 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 435 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, 
eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 436 | drop_rate=0.0, drop_path_rate=0.1) -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | """Calculating the loss 2 | You can build the loss function of BsiNet by combining multiple losses 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def dice_loss(prediction, target): 12 | """Calculating the dice loss 13 | Args: 14 | prediction = predicted probability map 15 | target = target (ground-truth) image 16 | Output: 17 | dice_loss""" 18 | 19 | smooth = 1.0 20 | 21 | i_flat = prediction.view(-1) 22 | t_flat = target.view(-1) 23 | 24 | intersection = (i_flat * t_flat).sum() 25 | 26 | return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth)) 27 | 28 | 29 | def calc_loss(prediction, target, bce_weight=0.5): 30 | """Calculating the combined BCE + Dice loss 31 | Args: 32 | prediction = predicted logits (sigmoid is applied internally) 33 | target = target (ground-truth) image 34 | bce_weight = weight of the BCE term, 0.5 by default; 35 | the Dice term is weighted by 1 - bce_weight 36 | Output: 37 | loss : weighted sum of the BCE and Dice losses """ 38 | bce = F.binary_cross_entropy_with_logits(prediction, target) 39 | prediction = torch.sigmoid(prediction) 40 | dice = dice_loss(prediction, target) 41 | 42 | loss = bce * bce_weight + dice * (1 - bce_weight) 43 | 44 | return loss 45 | 46 | 47 | 48 | class log_cosh_dice_loss(nn.Module): 49 | def __init__(self, num_classes=1, smooth=1, alpha=0.7): 50 | super(log_cosh_dice_loss, self).__init__() 51 | self.smooth = smooth 52 | self.alpha = alpha 53 | self.num_classes = num_classes 54 | 55 | def forward(self, outputs, targets): 56 | x = self.dice_loss(outputs, targets) 57 | return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0) 58 | 59 | def dice_loss(self, y_pred, y_true): 60 | """[function to compute dice loss] 61 | Args: 62 | y_true ([float32]): [ground truth image] 63 | y_pred
([float32]): [predicted image] 64 | Returns: 65 | [float32]: [loss value] 66 | """ 67 | smooth = 1. 68 | y_true = torch.flatten(y_true) 69 | y_pred = torch.flatten(y_pred) 70 | intersection = torch.sum((y_true * y_pred)) 71 | coeff = (2. * intersection + smooth) / (torch.sum(y_true) + torch.sum(y_pred) + smooth) 72 | return (1. - coeff) 73 | 74 | 75 | def focal_loss(predict, label, alpha=0.6, beta=2): 76 | probs = torch.sigmoid(predict) 77 | # per-pixel cross-entropy loss (unreduced, so it can be weighted element-wise below) 78 | ce_loss = nn.BCELoss(reduction='none') 79 | ce_loss = ce_loss(probs,label) 80 | alpha_ = torch.ones_like(predict) * alpha 81 | # alpha for positive labels, 1 - alpha for negative labels 82 | alpha_ = torch.where(label > 0, alpha_, 1.0 - alpha_) 83 | probs_ = torch.where(label > 0, probs, 1.0 - probs) 84 | # loss weight matrix 85 | loss_matrix = alpha_ * torch.pow((1.0 - probs_), beta) 86 | # final loss matrix: each weight times the matching loss value, so less accurate predictions produce a larger loss 87 | loss = loss_matrix * ce_loss 88 | loss = torch.sum(loss) 89 | return loss 90 | 91 | 92 | 93 | class Loss: 94 | def __init__(self, dice_weight=0.0, class_weights=None, num_classes=1, device=None): 95 | self.device = device 96 | if class_weights is not None: 97 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to( 98 | self.device 99 | ) 100 | else: 101 | nll_weight = None 102 | self.nll_loss = nn.NLLLoss(weight=nll_weight) # NLLLoss accepts (N, C, H, W) input; NLLLoss2d is deprecated 103 | self.dice_weight = dice_weight 104 | self.num_classes = num_classes 105 | 106 | def __call__(self, outputs, targets): 107 | loss = self.nll_loss(outputs, targets) 108 | if self.dice_weight: 109 | eps = 1e-7 110 | cls_weight = self.dice_weight / self.num_classes 111 | for cls in range(self.num_classes): 112 | dice_target = (targets == cls).float() 113 | dice_output = outputs[:, cls].exp() 114 | intersection = (dice_output * dice_target).sum() 115 | # union without intersection 116 | uwi = dice_output.sum() + dice_target.sum() + eps 117 | loss += (1 - intersection / uwi) * cls_weight 118 | loss /= (1 + self.dice_weight) 119 | return loss 120 | 121 | 122 | class LossMulti: 123 | def
__init__( 124 | self, jaccard_weight=0.0, class_weights=None, num_classes=1, device=None 125 | ): 126 | self.device = device 127 | if class_weights is not None: 128 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to( 129 | self.device 130 | ) 131 | else: 132 | nll_weight = None 133 | 134 | self.nll_loss = nn.NLLLoss(weight=nll_weight) 135 | self.jaccard_weight = jaccard_weight 136 | self.num_classes = num_classes 137 | 138 | def __call__(self, outputs, targets): 139 | 140 | targets = targets.squeeze(1) 141 | 142 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 143 | 144 | if self.jaccard_weight: 145 | eps = 1e-7 # originally 1e-7 146 | for cls in range(self.num_classes): 147 | jaccard_target = (targets == cls).float() 148 | jaccard_output = outputs[:, cls].exp() 149 | intersection = (jaccard_output * jaccard_target).sum() 150 | 151 | union = jaccard_output.sum() + jaccard_target.sum() 152 | loss -= ( 153 | torch.log((intersection + eps) / (union - intersection + eps)) 154 | * self.jaccard_weight 155 | ) 156 | return loss 157 | 158 | 159 | 160 | class BCEDiceLoss(nn.Module): 161 | def __init__(self): 162 | super().__init__() 163 | 164 | def forward(self, input, target): 165 | bce = F.binary_cross_entropy_with_logits(input, target) 166 | smooth = 1e-5 167 | input = torch.sigmoid(input) 168 | num = target.size(0) 169 | input = input.view(num, -1) 170 | target = target.view(num, -1) 171 | intersection = (input * target) 172 | dice = (2.
* intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) 173 | dice = 1 - dice.sum() / num 174 | return 0.5 * bce + dice 175 | 176 | 177 | class LossF: 178 | def __init__(self, weights=[1, 1, 1]): 179 | self.criterion1 = BCEDiceLoss() # mask loss: BCE + Dice, following SEANet 180 | self.criterion2 = LossMulti(num_classes=2) # contour loss: NLL, following BsiNet 181 | self.criterion3 = nn.MSELoss() # distance loss: MSE, following BsiNet 182 | self.weights = weights 183 | 184 | def __call__(self, outputs1, outputs2,outputs3, targets1, targets2, targets3): 185 | # 186 | criterion = ( 187 | self.weights[0] * self.criterion1(outputs1, targets1) 188 | + self.weights[1] * self.criterion2(outputs2, targets2) 189 | + self.weights[2] * self.criterion3(outputs3, targets3) 190 | ) 191 | 192 | return criterion 193 | 194 | class LossF_noEdgeTask: 195 | def __init__(self, weights=[1, 1]): 196 | self.criterion1 = BCEDiceLoss() # mask loss: BCE + Dice, following SEANet 197 | # self.criterion2 = LossMulti(num_classes=2) # contour loss: NLL, following BsiNet 198 | self.criterion3 = nn.MSELoss() # distance loss: MSE, following BsiNet 199 | self.weights = weights 200 | 201 | def __call__(self, outputs1,outputs3, targets1, targets3): 202 | # 203 | criterion = ( 204 | self.weights[0] * self.criterion1(outputs1, targets1) 205 | + self.weights[1] * self.criterion3(outputs3, targets3) 206 | ) 207 | 208 | return criterion 209 | 210 | 211 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """Model construction 2 | 1. We offer two versions of BsiNet, one concise and the other clear 3 | 2. The clear version is designed for user understanding and modification 4 | 3. You can use the attention mechanisms we provide to build a new multi-task model 5 | 4.
You can also add your own module or change where the attention mechanisms are placed to build a better model 6 | 5........................................................ 7 | ..... Baseline + GLCA + BGM + MSF 8 | """ 9 | 10 | from torch.nn.parameter import Parameter 11 | # from timm.models.registry import register_model 12 | # import math 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | # from functools import partial 17 | # from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 18 | from lib.pvtv2 import pvt_v2_b2 19 | # from tensorboardX import SummaryWriter 20 | 21 | def conv3x3(in_, out): 22 | return nn.Conv2d(in_, out, 3, padding=1) 23 | 24 | 25 | class Conv3BN(nn.Module): 26 | def __init__(self, in_: int, out: int, bn=False): 27 | super().__init__() 28 | self.conv = conv3x3(in_, out) 29 | self.bn = nn.BatchNorm2d(out) if bn else None 30 | self.activation = nn.ReLU(inplace=True) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | if self.bn is not None: 35 | x = self.bn(x) 36 | x = self.activation(x) 37 | return x 38 | 39 | 40 | class NetModule(nn.Module): 41 | def __init__(self, in_: int, out: int): 42 | super().__init__() 43 | self.l1 = Conv3BN(in_, out) 44 | self.l2 = Conv3BN(out, out) 45 | 46 | def forward(self, x): 47 | x = self.l1(x) 48 | x = self.l2(x) 49 | return x 50 | 51 | 52 | # SE (Squeeze-and-Excitation) attention 53 | class SELayer(nn.Module): 54 | def __init__(self, channel, reduction=16): 55 | super(SELayer, self).__init__() 56 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 57 | self.fc = nn.Sequential( 58 | nn.Linear(channel, channel // reduction, bias=False), 59 | nn.ReLU(inplace=True), 60 | nn.Linear(channel // reduction, channel, bias=False), 61 | nn.Sigmoid() 62 | ) 63 | 64 | def forward(self, x): 65 | b, c, _, _ = x.size() 66 | y = self.avg_pool(x).view(b, c) 67 | y = self.fc(y).view(b, c, 1, 1) 68 | return x * y.expand_as(x) 69 | 70 | 71 | 72 | class SpatialGroupEnhance(nn.Module): 73 | def __init__(self, groups = 64):
74 | super(SpatialGroupEnhance, self).__init__() 75 | self.groups = groups 76 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 77 | self.weight = Parameter(torch.zeros(1, groups, 1, 1)) 78 | self.bias = Parameter(torch.ones(1, groups, 1, 1)) 79 | self.sig = nn.Sigmoid() 80 | 81 | def forward(self, x): # (b, c, h, w) 82 | b, c, h, w = x.size() 83 | x = x.view(b * self.groups, -1, h, w) 84 | xn = x * self.avg_pool(x) 85 | xn = xn.sum(dim=1, keepdim=True) 86 | t = xn.view(b * self.groups, -1) 87 | t = t - t.mean(dim=1, keepdim=True) 88 | std = t.std(dim=1, keepdim=True) + 1e-5 89 | t = t / std 90 | t = t.view(b, self.groups, h, w) 91 | t = t * self.weight + self.bias 92 | t = t.view(b * self.groups, 1, h, w) 93 | x = x * self.sig(t) 94 | x = x.view(b, c, h, w) 95 | return x 96 | 97 | 98 | 99 | # scSE attention modules (cSE, sSE, scSE) 100 | class cSE(nn.Module): # noqa: N801 101 | """ 102 | The channel-wise SE (Squeeze and Excitation) block from the 103 | `Squeeze-and-Excitation Networks`__ paper. 104 | Adapted from 105 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65939 106 | and 107 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178 108 | Shape: 109 | - Input: (batch, channels, height, width) 110 | - Output: (batch, channels, height, width) (same shape as input) 111 | __ https://arxiv.org/abs/1709.01507 112 | """ 113 | 114 | def __init__(self, in_channels: int, r: int = 16): 115 | """ 116 | Args: 117 | in_channels: The number of channels 118 | in the feature map of the input. 119 | r: The reduction ratio of the intermediate channels. 120 | Default: 16.
121 | """ 122 | super().__init__() 123 | self.linear1 = nn.Linear(in_channels, in_channels // r) 124 | self.linear2 = nn.Linear(in_channels // r, in_channels) 125 | 126 | def forward(self, x: torch.Tensor): 127 | """Forward call.""" 128 | input_x = x 129 | 130 | x = x.view(*(x.shape[:-2]), -1).mean(-1) 131 | x = F.relu(self.linear1(x), inplace=True) 132 | x = self.linear2(x) 133 | x = x.unsqueeze(-1).unsqueeze(-1) 134 | x = torch.sigmoid(x) 135 | 136 | x = torch.mul(input_x, x) 137 | return x 138 | 139 | 140 | class sSE(nn.Module): # noqa: N801 141 | """ 142 | The sSE (Channel Squeeze and Spatial Excitation) block from the 143 | `Concurrent Spatial and Channel ‘Squeeze & Excitation’ 144 | in Fully Convolutional Networks`__ paper. 145 | Adapted from 146 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178 147 | Shape: 148 | - Input: (batch, channels, height, width) 149 | - Output: (batch, channels, height, width) (same shape as input) 150 | __ https://arxiv.org/abs/1803.02579 151 | """ 152 | 153 | def __init__(self, in_channels: int): 154 | """ 155 | Args: 156 | in_channels: The number of channels 157 | in the feature map of the input. 158 | """ 159 | super().__init__() 160 | self.conv = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1) 161 | 162 | def forward(self, x: torch.Tensor): 163 | """Forward call.""" 164 | input_x = x 165 | 166 | x = self.conv(x) 167 | x = torch.sigmoid(x) 168 | 169 | x = torch.mul(input_x, x) 170 | return x 171 | 172 | 173 | class scSE(nn.Module): # noqa: N801 174 | """ 175 | The scSE (Concurrent Spatial and Channel Squeeze and Channel Excitation) 176 | block from the `Concurrent Spatial and Channel ‘Squeeze & Excitation’ 177 | in Fully Convolutional Networks`__ paper. 
178 | Adapted from 179 | https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178 180 | Shape: 181 | - Input: (batch, channels, height, width) 182 | - Output: (batch, channels, height, width) (same shape as input) 183 | __ https://arxiv.org/abs/1803.02579 184 | """ 185 | 186 | def __init__(self, in_channels: int, r: int = 16): 187 | """ 188 | Args: 189 | in_channels: The number of channels 190 | in the feature map of the input. 191 | r: The reduction ratio of the intermediate channels. 192 | Default: 16. 193 | """ 194 | super().__init__() 195 | self.cse_block = cSE(in_channels, r) 196 | self.sse_block = sSE(in_channels) 197 | 198 | def forward(self, x: torch.Tensor): 199 | """Forward call.""" 200 | cse = self.cse_block(x) 201 | sse = self.sse_block(x) 202 | x = torch.add(cse, sse) 203 | return x 204 | 205 | 206 | ##This is a concise version of the BsiNet whose modules are better packaged 207 | 208 | class BsiNet(nn.Module): 209 | 210 | output_downscaled = 1 211 | module = NetModule 212 | 213 | def __init__( 214 | self, 215 | input_channels: int = 3, 216 | filters_base: int = 32, 217 | down_filter_factors=(1, 2, 4, 8, 16), 218 | up_filter_factors=(1, 2, 4, 8, 16), 219 | bottom_s=4, 220 | num_classes=1, 221 | add_output=True, 222 | ): 223 | super().__init__() 224 | self.num_classes = num_classes 225 | assert len(down_filter_factors) == len(up_filter_factors) 226 | assert down_filter_factors[-1] == up_filter_factors[-1] 227 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 228 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 229 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 230 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 231 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 232 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 233 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 234 | self.up.append( 235 | self.module(down_filter_sizes[prev_i] + nf, 
up_filter_sizes[prev_i]) 236 | ) 237 | 238 | pool = nn.MaxPool2d(2, 2) 239 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 240 | upsample = nn.Upsample(scale_factor=2) 241 | upsample_bottom = nn.Upsample(scale_factor=bottom_s) 242 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 243 | self.downsamplers[-1] = pool_bottom 244 | self.upsamplers = [upsample] * len(self.up) 245 | self.upsamplers[-1] = upsample_bottom 246 | self.add_output = add_output 247 | self.sge = SpatialGroupEnhance(32) 248 | 249 | # self.ca1 = ChannelAttention() 250 | # self.sa1 = SpatialAttention() 251 | 252 | if add_output: 253 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 254 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 255 | self.conv_final3 = nn.Conv2d(up_filter_sizes[0], 1, 1) 256 | 257 | def forward(self, x): 258 | xs = [] 259 | for downsample, down in zip(self.downsamplers, self.down): 260 | x_in = x if downsample is None else downsample(xs[-1]) 261 | x_out = down(x_in) 262 | xs.append(x_out) 263 | 264 | for x_skip, upsample, up in reversed( 265 | list(zip(xs[:-1], self.upsamplers, self.up)) 266 | ): 267 | 268 | x_out2 = upsample(x_out) 269 | x_out= (torch.cat([x_out2, x_skip], 1)) 270 | x_out = up(x_out) 271 | 272 | if self.add_output: 273 | 274 | x_out = self.sge(x_out) 275 | 276 | x_out1 = self.conv_final1(x_out) 277 | x_out2 = self.conv_final2(x_out) 278 | x_out3 = self.conv_final3(x_out) 279 | if self.num_classes > 1: 280 | x_out1 = F.log_softmax(x_out1,dim=1) 281 | x_out2 = F.log_softmax(x_out2,dim=1) 282 | x_out3 = torch.sigmoid(x_out3) 283 | 284 | return [x_out1, x_out2, x_out3] 285 | 286 | 287 | 288 | 289 | class LaplaceConv2d(nn.Module): 290 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=True): 291 | super(LaplaceConv2d, self).__init__() 292 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) 293 | 294 | # Generate Laplace 
kernel 295 | laplace_kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], dtype=torch.float32) ## 8-neighborhood Laplacian 296 | laplace_kernel = laplace_kernel.unsqueeze(0).unsqueeze(0) 297 | laplace_kernel = laplace_kernel.repeat((out_channels, in_channels, 1, 1)) 298 | self.conv.weight = nn.Parameter(laplace_kernel) 299 | self.conv.bias.data.fill_(0) 300 | self.bn = nn.BatchNorm2d(out_channels) 301 | self.relu = nn.ReLU(inplace=True) 302 | 303 | def forward(self, x): 304 | x1 = self.conv(x) 305 | x1 = self.relu(self.bn(x1)) 306 | 307 | return x1 308 | 309 | 310 | ###### CBAM attention 311 | class CBAM(nn.Module): 312 | def __init__(self, channel, reduction=16, spatial_kernel=7): 313 | super(CBAM, self).__init__() 314 | # channel attention: squeeze H and W to 1 315 | self.max_pool = nn.AdaptiveMaxPool2d(1) 316 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 317 | # shared MLP 318 | self.mlp = nn.Sequential( 319 | nn.Conv2d(channel, channel // reduction, 1, bias=False), 320 | nn.ReLU(inplace=True), 321 | nn.Conv2d(channel // reduction, channel, 1, bias=False) 322 | ) 323 | # spatial attention 324 | self.conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel, 325 | padding=spatial_kernel // 2, bias=False) 326 | self.sigmoid = nn.Sigmoid() 327 | def forward(self, x): 328 | max_out = self.mlp(self.max_pool(x)) 329 | avg_out = self.mlp(self.avg_pool(x)) 330 | channel_out = self.sigmoid(max_out + avg_out) 331 | x = channel_out * x 332 | max_out, _ = torch.max(x, dim=1, keepdim=True) 333 | avg_out = torch.mean(x, dim=1, keepdim=True) 334 | spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1))) 335 | x = spatial_out * x 336 | return x 337 | 338 | 339 | 340 | ### Boundary enhancement module / boundary guided module (BGM) 341 | class Boundary_guided_module(nn.Module): 342 | def __init__(self, in_channel1,in_channel2,out_channel): 343 | super(Boundary_guided_module, self).__init__() 344 | self.sigmoid = nn.Sigmoid() 345 | self.conv1 = nn.Conv2d(in_channel1, out_channel, 1) ## 1x1 convolution to reduce the number of channels 346 | self.conv2 =
nn.Conv2d(in_channel2, out_channel, 1) ## 1x1 convolution to reduce the number of channels 347 | self.max_pool = nn.AdaptiveMaxPool2d(1) 348 | 349 | def forward(self,edge,semantic): 350 | x = self.conv1(edge) 351 | x, _ = torch.max(x, dim=1, keepdim=True) 352 | x = self.sigmoid(x) # edge attention map 353 | s = self.conv2(semantic) # project the semantic features once and reuse them 354 | x = x * s + s # edge-weighted features plus a residual connection 355 | return x 356 | 357 | 358 | class Long_distance(nn.Module): 359 | '''Spatial reasoning module''' 360 | 361 | # code adapted from DANet, 'Dual Attention Network for Scene Segmentation' 362 | def __init__(self, in_dim): 363 | super(Long_distance, self).__init__() 364 | self.chanel_in = in_dim 365 | 366 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 367 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 368 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 369 | self.gamma = nn.Parameter(torch.zeros(1)) # nn.Parameter() turns this tensor into a parameter that is updated by gradient descent during training; in neural networks such parameters usually serve as weights or biases. 370 | 371 | self.softmax = nn.Softmax(dim=-1) 372 | 373 | def forward(self, x): 374 | ''' inputs : 375 | x : input feature maps( B X C X H X W) 376 | returns : 377 | out : attention value + input feature 378 | attention: B X (HxW) X (HxW) ''' 379 | m_batchsize, C, height, width = x.size() 380 | proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) 381 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) 382 | energy = torch.bmm(proj_query, proj_key) 383 | attention = self.softmax(energy) 384 | proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) 385 | 386 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 387 | out = out.view(m_batchsize, C, height, width) 388 | out = x + self.gamma * out 389 | 390 | return out 391 | 392 | ## Local and global spatial context (near- and long-distance context extraction) 393 | class near_and_long(nn.Module): 394 | def __init__(self, in_channel,out_channel): 395 | super(near_and_long, self).__init__() 396 | self.long =
Long_distance(in_channel) 397 | self.near = Residual(in_channel,out_channel) 398 | self.conv1 = nn.Conv2d(in_channel+out_channel,out_channel,1) 399 | 400 | def forward(self,x): 401 | x1 = self.long(x) 402 | x2 = self.near(x) 403 | fuse = torch.cat([x1,x2], 1) 404 | fuse = self.conv1(fuse) 405 | 406 | return fuse 407 | 408 | 409 | class multi_scale_fuseion(nn.Module): 410 | def __init__(self, in_channel,out_channel): 411 | super(multi_scale_fuseion, self).__init__() 412 | self.c1 = Conv(in_channel,out_channel, kernel_size=1, padding=0) 413 | self.c2 = Conv(in_channel,out_channel, kernel_size=3, padding=1) 414 | self.c3 = Conv(in_channel,out_channel, kernel_size=7, padding=3) 415 | self.c4 = Conv(in_channel,out_channel, kernel_size=11, padding=5) 416 | self.s1 = Conv(out_channel*4,out_channel, kernel_size=1, padding=0) 417 | self.attention = CBAM(out_channel) 418 | 419 | def forward(self,x): 420 | x1 = self.c1(x) 421 | x2 = self.c2(x) 422 | x3 = self.c3(x) 423 | x4 = self.c4(x) 424 | x5 = torch.cat([x1,x2,x3,x4], 1) 425 | x5 = self.s1(x5) 426 | x6 = self.attention(x5) 427 | 428 | return x6 429 | 430 | 431 | class Residual(nn.Module): 432 | def __init__(self, input_dim, output_dim, stride=1, padding=1): 433 | super(Residual, self).__init__() 434 | 435 | self.conv_block = nn.Sequential( 436 | nn.BatchNorm2d(input_dim), 437 | nn.ReLU(), 438 | nn.Conv2d( 439 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 440 | ), 441 | nn.BatchNorm2d(output_dim), 442 | nn.ReLU(), 443 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 444 | ) 445 | self.conv_skip = nn.Sequential( 446 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 447 | nn.BatchNorm2d(output_dim), 448 | ) 449 | 450 | def forward(self, x): 451 | 452 | return self.conv_block(x) + self.conv_skip(x) 453 | 454 | 455 | class Conv(nn.Module): 456 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=True, padding=1,relu=True, bias=True): 457 | 
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding, bias=bias)
        # Disabled layers are stored as None so that the `is not None` checks in
        # forward() skip them.
        self.relu = nn.ReLU(inplace=True) if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)

        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Field(nn.Module):
    def __init__(self, channel=32, num_classes=2, drop_rate=0.4):
        super(Field, self).__init__()

        self.drop = nn.Dropout2d(drop_rate)
        self.backbone = pvt_v2_b2()  # output channels per stage: [64, 128, 320, 512]
        # Local path to the pretrained PVTv2-B2 weights; adjust it for your machine.
        path = r"E:\zhaohang\DeepL\code\model_New\preweight\pvt_v2_b2.pth"
        save_model = torch.load(path, map_location="cpu")
        model_dict = self.backbone.state_dict()
        # Keep only the pretrained tensors whose names exist in this backbone.
        state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)
        self.channel = channel
        self.num_classes = num_classes

        self.edge_lap = LaplaceConv2d(in_channels=3, out_channels=1)
        self.conv1 = Residual(1, 32)
        self.conv2 = nn.Conv2d(64, 16, 1)
        self.attention1 = CBAM(32)
        self.attention2 = CBAM(64)
        # Global and local context aggregation for each backbone stage.
        self.fuse1 = near_and_long(512, 256)
        self.fuse2 = near_and_long(320, 128)
        self.fuse3 = near_and_long(128, 64)
        self.fuse4 = near_and_long(64, 32)
        # Upsample every stage back to the input resolution (strides 4, 8, 16, 32).
        self.up1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.up3 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
        self.up4 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)

        # Boundary-guided modules (boundary1 is defined but not used in forward()).
        self.boundary1 = Boundary_guided_module(64, 64, 32)
        self.boundary2 = Boundary_guided_module(64, 64, 16)
        self.boundary3 = Boundary_guided_module(64, 128, 16)
        self.boundary4 = Boundary_guided_module(64, 256, 16)

        self.multi_fusion = multi_scale_fuseion(64, 64)
        self.sigmoid = nn.Sigmoid()
        self.out_feature = nn.Conv2d(64, 1, 1)
        self.edge_feature = nn.Conv2d(64, num_classes, 1)
        # Output the distance map
        self.dis_feature = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        # Edge branch: Laplacian edge map of the input image -> 32 channels.
        edge = self.edge_lap(x)
        edge = self.conv1(edge)

        # Shapes below assume a 256 x 256 input.
        pvt = self.backbone(x)
        x1 = pvt[0]  # (b, 64, 64, 64)
        x1 = self.drop(x1)
        x2 = pvt[1]  # (b, 128, 32, 32)
        x2 = self.drop(x2)
        x3 = pvt[2]  # (b, 320, 16, 16)
        x3 = self.drop(x3)
        x4 = pvt[3]  # (b, 512, 8, 8)
        x4 = self.drop(x4)
        # Global and local context aggregation
        x1 = self.fuse4(x1)  # 32 channels
        x2 = self.fuse3(x2)  # 64 channels
        x3 = self.fuse2(x3)  # 128 channels
        x4 = self.fuse1(x4)  # 256 channels
        # Boundary-guided module
        x1 = self.up1(x1)
        edge = torch.cat([edge, x1], 1)  # 64 channels
        edge = self.attention2(edge)
        edge1 = self.conv2(edge)  # 16 channels
        # bs1 = self.boundary1(edge, x1)
        x2 = self.up2(x2)
        bs2 = self.boundary2(edge, x2)
        x3 = self.up3(x3)
        bs3 = self.boundary3(edge, x3)
        x4 = self.up4(x4)
        bs4 = self.boundary4(edge, x4)
        # Multi-scale feature fusion module: 4 x 16 = 64 channels
        ms = torch.cat([edge1, bs2, bs3, bs4], 1)
        out = self.multi_fusion(ms)

        edge_out = self.edge_feature(edge)
        edge_out = F.log_softmax(edge_out, dim=1)  # per-pixel log-probabilities
        mask_out = self.out_feature(out)  # single-channel scores, no activation applied
        # mask_out = F.log_softmax(mask_out, dim=1)
        # Output the distance map (dist_out)
        dis_out = self.dis_feature(out)

        return [mask_out, edge_out, dis_out]


if __name__ == "__main__":
    # Smoke test: forward a random batch and print the output shapes.
    tensor = torch.randn((8, 3, 512, 512))
    net = Field()
    # Print the name of every layer in the model
    # for name, module in net.named_modules():
    #     print(f'Layer Name: {name}')
    outputs = net(tensor)
    for output in outputs:
        print(output.size())

    # with SummaryWriter(logdir="network") as w:
    #     w.add_graph(net, tensor)
    #
    #     w.close()

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import os
from torch.utils.data import DataLoader
from dataset import DatasetImageMaskContourDist
import glob
from models import Field
from tqdm import tqdm
import numpy as np
import cv2
from utils import create_validation_arg_parser
from torch import nn


def build_model(model_type):
    if model_type == "field":
        return Field(num_classes=2)
    raise ValueError("Unknown model_type: {}".format(model_type))


if __name__ == "__main__":
    args = create_validation_arg_parser().parse_args()
    # Hard-coded settings override the command-line values; edit them for your setup.
    args.model_file = r"E:\zhaohang\DeepL\模型训练pt文件_lmx机子\JS_zh\80.pt"
    args.save_path = r"D:\DeepL\data\paper\JS\test\pre_mask1"
    args.model_type = 'field'
    args.test_path = r"D:\DeepL\data\paper\JS\test\train_image"

    # args = create_validation_arg_parser().parse_args()
    # args.model_file = r"D:\DeepL\data\paper\JS\ablation_study\model_F1\model_pt\50.pt"
    # args.save_path = r"D:\DeepL\data\paper\JS\ablation_study\model_F1\pre_mask1"
    # args.model_type = 'field'
    # args.test_path = r"D:\DeepL\data\paper\JS\ablation_study\model_F1\train_image"

    test_path = os.path.join(args.test_path, "*.tif")
    model_file = args.model_file
    save_path = args.save_path
    model_type = args.model_type

    cuda_no = args.cuda_no
    CUDA_SELECT = "cuda:{}".format(cuda_no)
    device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")

    test_file_names = glob.glob(test_path)
    # print(test_file_names)
    # valLoader = DataLoader(DatasetImageMaskContourDist(test_file_names))
    # Image IDs are file names without directory or extension, as in train.py.
    test_file_names = [os.path.splitext(os.path.basename(p))[0] for p in test_file_names]
    valLoader = DataLoader(DatasetImageMaskContourDist(args.test_path, test_file_names))

    os.makedirs(save_path, exist_ok=True)

    model = build_model(model_type)
    # Added by the author: DataParallel prefixes parameter names with "module.",
    # which is needed when the checkpoint was saved from a DataParallel model.
    model = nn.DataParallel(model)
    model = model.to(device)
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.eval()

    with torch.no_grad():
        for i, (img_file_name, inputs, targets1, targets2, targets3) in enumerate(
            tqdm(valLoader)
        ):
            inputs = inputs.to(device)
            outputs1, outputs2, outputs3 = model(inputs)

            # Test-time augmentation (TTA): average predictions over flipped inputs.
            # outputs4, outputs5, outputs6 = model(torch.flip(inputs, [-1]))
            # predict_2 = torch.flip(outputs4, [-1])
            # outputs7, outputs8, outputs9 = model(torch.flip(inputs, [-2]))
            # predict_3 = torch.flip(outputs7, [-2])
            # outputs10, outputs11, outputs12 = model(torch.flip(inputs, [-1, -2]))
            # predict_4 = torch.flip(outputs10, [-1, -2])
            # predict_list = outputs1 + predict_2 + predict_3 + predict_4
            # pred1 = predict_list / 4.0

            outputs1 = outputs1.detach().cpu().numpy().squeeze()

            # Binarize the mask output at 0.5; the result has the same size as the input.
            res = np.where(outputs1 > 0.5, 255, 0).astype(np.uint8)

            output_path = os.path.join(
                # save_path, os.path.basename(img_file_name[0])
                save_path, os.path.basename(img_file_name[0] + ".tif")
            )
            cv2.imwrite(output_path, res)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import glob
import logging
import os
import random
import torch
from dataset import DatasetImageMaskContourDist
from losses import LossF
from models import Field

from tensorboardX import SummaryWriter
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import visualize, create_train_arg_parser, evaluate
# from torchsummary import summary
from sklearn.model_selection import train_test_split


def define_loss(loss_type, weights=[1, 1, 1]):
    if loss_type == "field":
        return LossF(weights)
    raise ValueError("Unknown loss_type: {}".format(loss_type))


def build_model(model_type):
    if model_type == "field":
        return Field(num_classes=2)
    raise ValueError("Unknown model_type: {}".format(model_type))


def train_model(model, targets, model_type, criterion, optimizer):
    """Run one optimization step and return the loss tensor."""
    # NOTE: `inputs` is not a parameter; it is the module-level variable assigned
    # in the training loop under `if __name__ == "__main__":`.
    if model_type == "field":
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            # outputs = [mask, edge, distance]; targets follow the same order.
            loss = criterion(
                outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2]
            )
            loss.backward()
            optimizer.step()

    return loss


if __name__ == "__main__":

    args = create_train_arg_parser().parse_args()
    # args.pretrained_model_path = r"E:\zhaohang\data\onesoil_2m_pt\56.pt"
    # args.pretrained_model_path = r"E:\zhaohang\DeepL\模型训练pt文件_lmx机子\dikuai_all_ihave_ZH\60.pt"

    # Hard-coded settings override the command-line values; edit them for your setup.
    args.train_path = r"F:\PingAn\NX\Ningxia\output\train_image"
    args.model_type = 'field'
    args.save_path = r"F:\PingAn\NX\Ningxia\output\model_pt"

    CUDA_SELECT = "cuda:{}".format(args.cuda_no)
    log_path = os.path.join(args.save_path, "summary")
    writer = SummaryWriter(log_dir=log_path)

    # The filename is empty, so logging.basicConfig writes to stderr rather than to
    # a file; pass a path here to keep a log file.
    logging.basicConfig(
        filename="",
        filemode="a",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M",
        level=logging.INFO,
    )
    logging.info("")

    train_file_names = glob.glob(os.path.join(args.train_path, "*.tif"))
    random.shuffle(train_file_names)
    # Image IDs are file names without directory or extension; 20% go to validation.
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_file_names]
    train_file, val_file = train_test_split(img_ids, test_size=0.2, random_state=41)

    device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")
    print(device)
    model = build_model(args.model_type)

    if torch.cuda.device_count() > 2:  # originally 0
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)
    # summary(model, input_size=(3, 256, 256))

    epoch_start = "0"
    if args.use_pretrained:
        print("Loading Model {}".format(os.path.basename(args.pretrained_model_path)))
        model.load_state_dict(torch.load(args.pretrained_model_path, map_location=device))
        # Resume the epoch counter from the checkpoint name, e.g. "60.pt" -> "60".
        epoch_start = os.path.basename(args.pretrained_model_path).split(".")[0]
        print(epoch_start)
    print('train', args.use_pretrained)

    trainLoader = DataLoader(
        DatasetImageMaskContourDist(args.train_path, train_file),
        batch_size=args.batch_size, drop_last=False, shuffle=True
    )
    devLoader = DataLoader(
        DatasetImageMaskContourDist(args.train_path, val_file), drop_last=False,
    )
    displayLoader = DataLoader(
        DatasetImageMaskContourDist(args.train_path, val_file),
        batch_size=args.val_batch_size, drop_last=False, shuffle=True
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # T_max is effectively infinite, so the learning rate stays close to 1e-4.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(1e10), eta_min=1e-5)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, 0.1)
    criterion = define_loss(args.model_type)

    for epoch in tqdm(
        range(int(epoch_start) + 1, int(epoch_start) + 1 + args.num_epochs)
    ):

        global_step = epoch * len(trainLoader)
        running_loss = 0.0

        for i, (img_file_name, inputs, targets1, targets2, targets3) in enumerate(
            tqdm(trainLoader)
        ):

            model.train()

            inputs = inputs.to(device)
            targets1 = targets1.to(device)
            targets2 = targets2.to(device)
            targets3 = targets3.to(device)

            targets = [targets1, targets2, targets3]

            loss = train_model(model, targets, args.model_type, criterion, optimizer)

            # Log the per-iteration loss against the global iteration counter.
            writer.add_scalar("loss", loss.item(), global_step + i)

            running_loss += loss.item() * inputs.size(0)
            scheduler.step()

        # Mean loss per training sample (the validation split is excluded).
        epoch_loss = running_loss / len(trainLoader.dataset)
        print(epoch_loss)

        if epoch % 1 == 0:

            dev_loss, dev_time = evaluate(device, epoch, model, devLoader, writer)
            writer.add_scalar("loss_valid", dev_loss, epoch)
            visualize(device, epoch, model, displayLoader, writer, args.val_batch_size)
            print("Global Loss:{} Val Loss:{}".format(epoch_loss, dev_loss))
        else:
            print("Global Loss:{} ".format(epoch_loss))

        logging.info("epoch:{} train_loss:{} ".format(epoch, epoch_loss))
        # Save a checkpoint every 5 epochs, named "<epoch>.pt".
        if epoch % 5 == 0:
            torch.save(
                model.state_dict(), os.path.join(args.save_path, str(epoch) + ".pt")
            )

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
from tqdm import tqdm
import numpy as np
import torchvision
from torch.nn import functional as F
import time
import argparse
from losses import *


def evaluate(device, epoch, model, data_loader, writer):
    """Return the mean validation loss of the mask output and the elapsed time."""
    model.eval()
    losses = []
    criterion = BCEDiceLoss()  # only the mask output (outputs[0]) is scored
    start = time.perf_counter()
    with torch.no_grad():
        for _, data in enumerate(tqdm(data_loader)):
            _, inputs, targets, _, _ = data
            # _, inputs, targets, _ = data
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs[0], targets)
            # loss = F.nll_loss(outputs[0], targets.squeeze(1))
            losses.append(loss.item())

    writer.add_scalar("Dev_Loss", np.mean(losses), epoch)

    return np.mean(losses), time.perf_counter() - start


def visualize(device, epoch, model, data_loader, writer, val_batch_size, train=False):
    """Log one batch of targets and binarized mask predictions to TensorBoard."""
    def save_image(image, tag, val_batch_size):
        # Rescale to [0, 1]; the clamp avoids dividing by zero on an all-zero batch.
        image = image - image.min()
        image = image / image.max().clamp(min=1e-8)
        grid = torchvision.utils.make_grid(
            image, nrow=int(np.sqrt(val_batch_size)), pad_value=0, padding=25
        )
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for _, data in enumerate(tqdm(data_loader)):
            _, inputs, targets, _, _ = data
            # _, inputs, targets, _ = data

            inputs = inputs.to(device)

            targets = targets.to(device)
            outputs = model(inputs)

            output_mask = outputs[0].detach().cpu().numpy()
            output_mask[output_mask > 0.5] = 255
            output_mask[output_mask <= 0.5] = 0
            # output_final = np.argmax(output_mask, axis=1).astype(float)
            output_final = torch.from_numpy(output_mask)

            # Tag names distinguish training batches from validation batches.
            if train:
                save_image(targets.float(), "Target_train", val_batch_size)
                save_image(output_final, "Prediction_train", val_batch_size)
            else:
                save_image(targets.float(), "Target", val_batch_size)
                save_image(output_final, "Prediction", val_batch_size)

            break  # only the first batch is visualized


def create_train_arg_parser():

    parser = argparse.ArgumentParser(description="train setup for segmentation")
    parser.add_argument("--train_path", type=str, help="path to img tif files")
    parser.add_argument(
        "--val_path",
        type=str,
        help="path to validation img tif files (unused; train.py splits --train_path)",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        help="select model type: field",
    )
    parser.add_argument("--object_type", type=str, help="Dataset.")
    parser.add_argument(
        "--distance_type",
        type=str,
        default="dist_contour",
        help="select distance transform type - dist_mask, dist_contour, dist_contour_tif",
    )
    parser.add_argument("--batch_size", type=int, default=8, help="train batch size")
    parser.add_argument(
        "--val_batch_size", type=int, default=8, help="validation batch size"
    )
    parser.add_argument("--num_epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--cuda_no", type=int, default=0, help="cuda number")
    # An on/off switch: type=bool would treat any non-empty string, including
    # "False", as True.
    parser.add_argument(
        "--use_pretrained", action="store_true", help="Load pretrained checkpoint."
    )
    parser.add_argument(
        "--pretrained_model_path",
        type=str,
        default=None,
        help="If --use_pretrained is set, path to the checkpoint.",
    )
    parser.add_argument("--save_path", type=str, help="Model save path.")

    return parser


def create_validation_arg_parser():

    parser = argparse.ArgumentParser(description="test setup for segmentation")
    parser.add_argument(
        "--model_type",
        type=str,
        help="select model type: field",
    )
    parser.add_argument("--test_path", type=str, help="path to img tif files")
    parser.add_argument("--model_file", type=str, help="path to the trained checkpoint (.pt)")
    parser.add_argument("--save_path", type=str, help="results save path.")
    parser.add_argument("--cuda_no", type=int, default=0, help="cuda number")

    return parser

--------------------------------------------------------------------------------
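A note on boolean command-line switches such as `--use_pretrained`: with argparse, `type=bool` simply calls `bool()` on the raw string, and `bool("False")` is `True`, so any non-empty value turns the flag on. The stdlib-only sketch below contrasts that behavior with `action="store_true"`. The `build_parser` function and the `--flag_type_bool` option are hypothetical and exist only for this demonstration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser contrasting two ways of declaring a boolean flag."""
    parser = argparse.ArgumentParser(description="bool flag demo")
    # Pitfall: type=bool converts the string with bool(), and bool("False") is True.
    parser.add_argument("--flag_type_bool", type=bool, default=False)
    # Idiomatic: False unless the flag appears on the command line.
    parser.add_argument("--use_pretrained", action="store_true")
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args(["--flag_type_bool", "False"])
    print(args.flag_type_bool)  # True, even though "False" was passed
    print(args.use_pretrained)  # False, because the flag was omitted
    args = parser.parse_args(["--use_pretrained"])
    print(args.use_pretrained)  # True
```

With `store_true`, the switch is enabled by passing the bare flag (`--use_pretrained`) and disabled by omitting it; no value follows the flag.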