├── predict.py ├── other_model ├── HSCNN_Plus.py ├── HDNet.py ├── MPRNet.py └── HRNET.py ├── README.md ├── exp.py ├── data.py ├── test.py ├── utils.py ├── batch_predict_single.py ├── miou.py ├── batch_predict.py ├── train.py └── model.py /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from model import MST_Plus_Plus 5 | from utils import outi 6 | from exp import outi_pers 7 | 8 | def create_model(): 9 | model = MST_Plus_Plus(in_channels=3, out_channels=4, n_feat=4) 10 | return model 11 | 12 | 13 | def parse_args(): 14 | import argparse 15 | parser = argparse.ArgumentParser(description="Predict") 16 | parser.add_argument("--outf", type=str, default='./pred/', help='path MSI files') 17 | parser.add_argument('--pretrained_model_path', type=str, default='./model_zoo/net_12epoch.pth') 18 | parser.add_argument('--rgb_path', type=str, default='./hy_seg/1.jpg') 19 | 20 | args = parser.parse_args() 21 | 22 | return args 23 | 24 | 25 | def main(args): 26 | 27 | model = create_model() 28 | pretrained_model_path = args.pretrained_model_path 29 | 30 | 31 | if pretrained_model_path is not None: 32 | print(f'load model from {pretrained_model_path}') 33 | checkpoint = torch.load(pretrained_model_path, map_location=torch.device('cpu')) 34 | 35 | model.load_state_dict(checkpoint['state_dict']) 36 | 37 | rgb_path = args.rgb_path 38 | rgb_data = cv2.imread(rgb_path) # uint8 (340, 340, 3) 39 | rgb_data = np.transpose(rgb_data, [2, 0, 1]) # uint8 (3, 340, 340) 40 | rgb_data = np.float32(rgb_data) 41 | rgb_data = torch.tensor(rgb_data) # 转tensor 42 | rgb_data = rgb_data.unsqueeze(0) 43 | print(rgb_data.dtype) 44 | 45 | 46 | MSI = model(rgb_data) 47 | 48 | print(MSI.shape) 49 | outi(MSI, args.outf, "hy_1") 50 | 51 | outi_pers(MSI, args.outf, "hy_1") 52 | 53 | 54 | 55 | if __name__ == '__main__': 56 | args = parse_args() 57 | main(args) -------------------------------------------------------------------------------- /other_model/HSCNN_Plus.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | class dfus_block(nn.Module): 4 | def __init__(self, dim): 5 | super(dfus_block, self).__init__() 6 | self.conv1 = nn.Conv2d(dim, 128, 1, 1, 0, bias=False) 7 | 8 | self.conv_up1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False) 9 | self.conv_up2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False) 10 | 11 | self.conv_down1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False) 12 | self.conv_down2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False) 13 | 14 | self.conv_fution = nn.Conv2d(96, 32, 1, 1, 0, bias=False) 15 | 16 | #### activation function 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | """ 21 | x: [b,c,h,w] 22 | return out:[b,c,h,w] 23 | """ 24 | feat = self.relu(self.conv1(x)) 25 | feat_up1 = self.relu(self.conv_up1(feat)) 26 | feat_up2 = self.relu(self.conv_up2(feat_up1)) 27 | feat_down1 = self.relu(self.conv_down1(feat)) 28 | feat_down2 = self.relu(self.conv_down2(feat_down1)) 29 | feat_fution = torch.cat([feat_up1,feat_up2,feat_down1,feat_down2],dim=1) 30 | feat_fution = self.relu(self.conv_fution(feat_fution)) 31 | out = torch.cat([x, feat_fution], dim=1) 32 | return out 33 | 34 | class ddfn(nn.Module): 35 | def __init__(self, dim, num_blocks=78): 36 | super(ddfn, self).__init__() 37 | 38 | self.conv_up1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False) 39 | self.conv_up2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False) 40 | 41 | self.conv_down1 = nn.Conv2d(dim, 
32, 3, 1, 1, bias=False)
42 |         self.conv_down2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)
43 | 
44 |         dfus_blocks = [dfus_block(dim=128+32*i) for i in range(num_blocks)]
45 |         self.dfus_blocks = nn.Sequential(*dfus_blocks)
46 | 
47 |         #### activation function
48 |         self.relu = nn.ReLU(inplace=True)
49 | 
50 |     def forward(self, x):
51 |         """
52 |         x: [b,c,h,w]
53 |         return out:[b,c,h,w]
54 |         """
55 |         feat_up1 = self.relu(self.conv_up1(x))
56 |         feat_up2 = self.relu(self.conv_up2(feat_up1))
57 |         feat_down1 = self.relu(self.conv_down1(x))
58 |         feat_down2 = self.relu(self.conv_down2(feat_down1))
59 |         feat_fution = torch.cat([feat_up1,feat_up2,feat_down1,feat_down2],dim=1)
60 |         out = self.dfus_blocks(feat_fution)
61 |         return out
62 | 
63 | class HSCNN_Plus(nn.Module):
64 |     # Notice
65 |     def __init__(self, in_channels=3, out_channels=4, num_blocks=3):
66 |         super(HSCNN_Plus, self).__init__()
67 | 
68 |         self.ddfn = ddfn(dim=in_channels, num_blocks=num_blocks)
69 |         self.conv_out = nn.Conv2d(128+32*num_blocks, out_channels, 1, 1, 0, bias=False)
70 | 
71 |     def forward(self, x):
72 |         """
73 |         x: [b,c,h,w]
74 |         return out:[b,c,h,w]
75 |         """
76 |         fea = self.ddfn(x)
77 |         out = self.conv_out(fea)
78 |         return out
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VegSegment_SR
2 | *Phenotype segmentation method based on spectral reconstruction for UAV field vegetation*
3 | 
4 | Field phenotyping can show plant growth in real time, and segmented vegetation data facilitate the visualization and analysis of vegetation information by reducing the influence of background noise. Field phenotype segmentation is considered a difficult task due to the extremely complex environment. Currently, spectral-based deep-learning methods for field vegetation segmentation are being developed to cope with complex environments and poor generalization, but two limitations remain.
5 | One is that the equipment for real-field data collection is extremely expensive. The other is that field datasets are scarce, because data annotation is time-consuming and laborious. To address these problems, this study proposes a weakly supervised field vegetation segmentation method that introduces multispectral images as prior information and draws on the theory of the vegetation index (VI). In detail, we first use visible-light images to reconstruct the corresponding multispectral images and then perform feature fusion on the reconstructed images. Furthermore, a VI-based threshold segmentation is applied to the reconstructed multispectral images to obtain the field vegetation segmentation map.
6 | In addition, we provide an unmanned aerial vehicle (UAV) RGB-multispectral image dataset containing 2358 pairs of RGB and multispectral images. We introduce a variety of spectral reconstruction (SR) methods as baseline models and train them on our dataset; applied to different crop vegetation types and environments, they show satisfactory experimental results.
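As a conceptual illustration of the VI-based thresholding step on a reconstructed 4-band image (the band order gre/red/reg/nir follows the comments in `exp.py` and `batch_predict.py`; the file name and the -0.2 cut-off are illustrative values taken from the scripts' commented VI branch), a minimal sketch is:

```python
import numpy as np
from osgeo import gdal

# reconstructed multispectral image, shape (4, H, W): 0 gre, 1 red, 2 reg, 3 nir
msi = gdal.Open('reconstructed_out.TIF').ReadAsArray().astype(np.float32)

# NDVI-style index from the green and NIR bands, as in the commented VI branch of gen_seg()
vi = (msi[0] - msi[3]) / (msi[0] + msi[3] + 1e-6)

# threshold into a binary vegetation mask (255 = vegetation); -0.2 mirrors the script default
mask = np.where(vi < -0.2, 255, 0).astype(np.uint8)
```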
7 | 8 | # Env 9 | 10 | ```shell 11 | conda create --name rtm python=3.8 -y 12 | conda activate rtm 13 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 14 | conda install gdal=3.4.1 15 | conda install scikit-image=0.19.2 16 | conda install matplotlib=3.5.2 17 | 18 | pip install opencv-python==4.6.0.66 19 | pip install einops==0.4.1 20 | ``` 21 | 22 | 23 | 24 | # Data 25 | 26 | - Data will be released soon. 27 | - Download the dataset. 28 | - Place the dataset folder to `/dataset/`. 29 | 30 | ```shell 31 | ├─dataset 32 | ├─data_split.py 33 | └─MSI&RGB 34 | ``` 35 | 36 | 37 | 38 | - Run `python data_split.py`. 39 | 40 | ```shell 41 | ├─dataset 42 | ├─data_split.py 43 | ├─MSI&RGB 44 | ├─split_txt 45 | │ ├─train_list.txt 46 | │ └─val_list.txt 47 | ├─Train_MSI 48 | │ ├─xxx0001.tif 49 | │ ├─xxx0002.tif 50 | │ : 51 | │ └─xxx0009.tif 52 | ├─Train_RGB 53 | │ ├─xxx0001.jpg 54 | │ ├─xxx0002.jpg 55 | │ : 56 | │ └─xxx0009.jpg 57 | ├─Val_MSI 58 | │ ├─xxx1000.tif 59 | │ ├─xxx2000.tif 60 | │ : 61 | │ └─xxx9000.tif 62 | └─Val_RGB 63 | ├─xxx1000.jpg 64 | ├─xxx2000.jpg 65 | : 66 | └─xxx9000.jpg 67 | ``` 68 | -------------------------------------------------------------------------------- /exp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | 8 | from osgeo import gdal 9 | import os 10 | 11 | import numpy as np 12 | from skimage.metrics import peak_signal_noise_ratio as psnr 13 | from skimage.metrics import structural_similarity as ssim 14 | 15 | 16 | import cv2 17 | 18 | 19 | def outi_pers(fakei, dir, name): 20 | fakei = fakei.detach().cpu().numpy() 21 | 22 | fakei = fakei[0, ...] 23 | 24 | fake_gre = fakei[0, ...] 25 | y = fake_gre.shape[0] 26 | x = fake_gre.shape[1] 27 | 28 | result = gdal.GetDriverByName('GTiff').Create(os.path.join(dir, f'out{name}_gre0.TIF'), xsize=x, ysize=y, bands=1, eType=gdal.GDT_Byte) 29 | result.GetRasterBand(1).WriteArray(fake_gre) 30 | 31 | 32 | fake_red = fakei[1, ...] 33 | result = gdal.GetDriverByName('GTiff').Create(os.path.join(dir, f'out{name}_red1.TIF'), xsize=x, ysize=y, bands=1, eType=gdal.GDT_Byte) 34 | result.GetRasterBand(1).WriteArray(fake_red) 35 | 36 | fake_reg = fakei[2, ...] 37 | result = gdal.GetDriverByName('GTiff').Create(os.path.join(dir, f'out{name}_reg2.TIF'), xsize=x, ysize=y, bands=1, eType=gdal.GDT_Byte) 38 | result.GetRasterBand(1).WriteArray(fake_reg) 39 | 40 | fake_nir = fakei[3, ...] 
41 | result = gdal.GetDriverByName('GTiff').Create(os.path.join(dir, f'out{name}_nir3.TIF'), xsize=x, ysize=y, bands=1, eType=gdal.GDT_Byte) 42 | result.GetRasterBand(1).WriteArray(fake_nir) 43 | 44 | 45 | def gen_seg(): 46 | msi_data = gdal.Open("./pred/batch_pre_res/IMG_220611_073703_0038__01out.tif").ReadAsArray() # uint8 47 | msi_data = torch.tensor(msi_data) # 转tensor 48 | 49 | msi_data = msi_data.int() 50 | print(msi_data.dtype) 51 | 52 | s1 = 0 53 | 54 | # 0:gre 1:red 2:reg 3:nir 55 | x1, x2, x3, x4 = 1, 0, -1, -1 56 | s1 = x1 * msi_data[0,:,:] + x2 * msi_data[1,:,:] + x3 * msi_data[2,:,:] + x4 * msi_data[3,:,:] 57 | 58 | 59 | im_gray1 = np.array(s1.numpy()) 60 | # 利用图像像素均值二值化 61 | avg_gray = np.average(im_gray1) 62 | 63 | if avg_gray < 0 : 64 | avg_gray = avg_gray * 1.25 65 | else: 66 | avg_gray = avg_gray * 0.75 67 | 68 | print(avg_gray) 69 | 70 | im_gray2 = np.where(im_gray1 < avg_gray, 255, 0) 71 | 72 | 73 | data = np.array(im_gray2, dtype='uint8') 74 | plt.imshow(data) 75 | plt.show() 76 | # cv2.imwrite("./pred/IMG_220611_073743_0046__04seg.png", data) 77 | 78 | 79 | def show_seg(): 80 | # imgfile = './pred/avg_test1_12pth/test2.jpg' 81 | # pngfile = './pred/avg_test1_12pth/2+3.png' 82 | # 83 | # img = cv2.imread(imgfile, 1) 84 | # mask = cv2.imread(pngfile, 0) 85 | # 86 | # contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 87 | # cv2.drawContours(img, contours, -1, (0, 0, 255), 1) 88 | # 89 | # img = img[:, :, ::-1] 90 | # img[..., 2] = np.where(mask == 1, 255, img[..., 2]) 91 | # 92 | # plt.imshow(img) 93 | # plt.savefig("./pred/avg_test1_12pth/result_2+3.png") 94 | # plt.show() 95 | image1 = Image.open("./pred/avg_out4/test.jpg") 96 | image2 = Image.open("./pred/avg_out4/0-2-3.png") 97 | 98 | plt.figure() 99 | 100 | plt.subplot(221) 101 | plt.imshow(image1) 102 | 103 | plt.subplot(222) 104 | plt.imshow(image2) 105 | 106 | plt.subplot(223) 107 | plt.imshow(image1) 108 | plt.imshow(image2, alpha=0.5) 109 | 110 | plt.savefig("./pred/avg_out4/3.png") 111 | plt.show() 112 | 113 | 114 | 115 | def compute_miou(pred, target, nclass): 116 | mini = 1 117 | 118 | pred = np.array(pred) 119 | target = np.array(target) 120 | # 计算公共区域 121 | intersection = pred * (pred == target) 122 | 123 | # 直方图 124 | area_inter, _ = np.histogram(intersection, bins=2, range=(mini, nclass)) 125 | area_pred, _ = np.histogram(pred, bins=2, range=(mini, nclass)) 126 | area_target, _ = np.histogram(target, bins=2, range=(mini, nclass)) 127 | area_union = area_pred + area_target - area_inter 128 | 129 | # 交集已经小于并集 130 | assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area" 131 | 132 | rate = max(area_inter) / max(area_union) 133 | return rate 134 | 135 | 136 | 137 | 138 | 139 | if __name__ == '__main__': 140 | # nclass = 1 141 | # # target 142 | # target = [[0,0,0], 143 | # [0,1,1], 144 | # [0,1,1]] 145 | # 146 | # # pred 147 | # pred = [[1,1,0], 148 | # [1,1,0], 149 | # [0,0,0]] 150 | # 151 | # # 计算miou 152 | # rate = compute_miou(pred, target, nclass) 153 | # print(rate) 154 | 155 | 156 | gen_seg() 157 | # show_seg() 158 | 159 | 160 | 161 | 162 | 163 | 164 | # 0 1 2 3 165 | # 01 02 03 166 | # 12 13 167 | # 23 168 | # 012 013 023 123 169 | # 0123 170 | 171 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from osgeo import gdal 3 | import cv2 4 | import random 5 | import numpy as np 6 | 7 | import torch 
8 | from torch.utils.data import Dataset 9 | 10 | 11 | class TrainDataset(Dataset): 12 | def __init__(self, data_root, arg = True): 13 | self.root = data_root 14 | self.arg = arg 15 | self.rgb = [] 16 | self.msi = [] 17 | 18 | rgb_tr_path = f'{data_root}/Train_RGB/' 19 | msi_tr_path = f'{data_root}/Train_MSI/' 20 | 21 | with open(f'{data_root}/split_txt/train_list.txt', 'r') as li: 22 | self.rgb_list = [line.replace('\n', '.jpg') for line in li] 23 | self.msi_list = [line.replace('jpg', 'tif') for line in self.rgb_list] 24 | self.rgb_list.sort() 25 | self.msi_list.sort() 26 | print(f'len(train_rgb) dataset:{len(self.rgb_list)-1}') 27 | print(f'len(train_multispectral) dataset:{len(self.msi_list)-1}') 28 | 29 | for i in range(len(self.rgb_list)): 30 | rgb_path = os.path.join(rgb_tr_path, self.rgb_list[i]) 31 | rgb_data = cv2.imread(rgb_path) # uint8 (340, 340, 3) 32 | rgb_data = np.transpose(rgb_data, [2, 0, 1]) # uint8 (3, 340, 340) 33 | rgb_data = torch.tensor(rgb_data) # 转tensor 34 | self.rgb.append(rgb_data) 35 | print("\rRead [{}] processing [{}/{}]".format(self.rgb_list[i], i, len(self.rgb_list)-1), end="") # RGB 36 | print() 37 | for i in range(len(self.msi_list)): 38 | msi_path = os.path.join(msi_tr_path, self.msi_list[i]) 39 | msi_data = gdal.Open(msi_path).ReadAsArray() 40 | msi_data = torch.tensor(msi_data) # 转tensor 41 | self.msi.append(msi_data) # uint8 (4, 340, 340) 42 | print("\rRead [{}] processing [{}/{}]".format(self.msi_list[i], i, len(self.msi_list)-1), end="") # MSI 43 | print() 44 | 45 | def arguement(self, img, rotTimes, vFlip, hFlip): 46 | # tensor -> array 47 | img = img.numpy() 48 | 49 | # Random rotation 50 | for j in range(rotTimes): 51 | img = np.rot90(img.copy(), axes=(1, 2)) 52 | # Random vertical Flip 53 | for j in range(vFlip): 54 | img = img[:, :, ::-1].copy() 55 | # Random horizontal Flip 56 | for j in range(hFlip): 57 | img = img[:, ::-1, :].copy() 58 | 59 | # array -> tensor 60 | return torch.from_numpy(img.copy()) 61 | 62 | def __getitem__(self, idx): 63 | 64 | rgb = self.rgb[idx] 65 | msi = self.msi[idx] 66 | 67 | 68 | random.seed(0) 69 | 70 | rotTimes = random.randint(0, 3) 71 | vFlip = random.randint(0, 1) 72 | hFlip = random.randint(0, 1) 73 | if self.arg: 74 | rgb = self.arguement(rgb, rotTimes, vFlip, hFlip) 75 | msi = self.arguement(msi, rotTimes, vFlip, hFlip) 76 | 77 | # 查看转换后的图像 78 | # cv2.imshow("rgb", np.transpose(rgb.numpy(), [1, 2, 0])) 79 | # cv2.imshow("msi", np.transpose(msi.numpy(), [1, 2, 0])) 80 | # cv2.waitKey(0) 81 | return rgb, msi 82 | 83 | def __len__(self): 84 | return len(self.rgb_list) 85 | 86 | 87 | class ValidDataset(Dataset): 88 | def __init__(self, data_root): 89 | self.root = data_root 90 | self.rgb = [] 91 | self.msi = [] 92 | 93 | rgb_tr_path = f'{data_root}/Val_RGB/' 94 | msi_tr_path = f'{data_root}/Val_MSI/' 95 | 96 | with open(f'{data_root}/split_txt/val_list.txt', 'r') as li: 97 | self.rgb_list = [line.replace('\n', '.jpg') for line in li] 98 | self.msi_list = [line.replace('jpg', 'tif') for line in self.rgb_list] 99 | self.rgb_list.sort() 100 | self.msi_list.sort() 101 | print(f'len(val_rgb) dataset:{len(self.rgb_list)-1}') 102 | print(f'len(val_multispectral) dataset:{len(self.msi_list)-1}') 103 | 104 | for i in range(len(self.rgb_list)): 105 | rgb_path = os.path.join(rgb_tr_path, self.rgb_list[i]) 106 | rgb_data = cv2.imread(rgb_path) # uint8 (340, 340, 3) 107 | rgb_data = np.transpose(rgb_data, [2, 0, 1]) # uint8 (3, 340, 340) 108 | rgb_data = torch.tensor(rgb_data) 109 | self.rgb.append(rgb_data) 110 | 
print("\rRead [{}] processing [{}/{}]".format(self.rgb_list[i], i, len(self.rgb_list)-1), end="") # RGB 111 | print() 112 | for i in range(len(self.msi_list)): 113 | msi_path = os.path.join(msi_tr_path, self.msi_list[i]) 114 | msi_data = gdal.Open(msi_path).ReadAsArray() 115 | msi_data = torch.tensor(msi_data) 116 | 117 | self.msi.append(msi_data) # uint8 (4, 340, 340) 118 | print("\rRead [{}] processing [{}/{}]".format(self.msi_list[i], i, len(self.msi_list)-1), end="") # MSI 119 | print() 120 | 121 | def __getitem__(self, idx): 122 | 123 | rgb = self.rgb[idx] 124 | msi = self.msi[idx] 125 | 126 | return rgb, msi 127 | 128 | def __len__(self): 129 | return len(self.rgb_list) 130 | 131 | 132 | 133 | if __name__ == '__main__': 134 | t = TrainDataset('./dataset') 135 | v = ValidDataset('./dataset') 136 | # print(t[1]) 137 | # print(v[1]) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py 3 | """ 4 | import numpy 5 | import numpy as np 6 | import os 7 | import cv2 8 | 9 | __all__ = ['SegmentationMetric'] 10 | 11 | """ 12 | confusionMetric # 注意:此处横着代表预测值,竖着代表真实值,与之前介绍的相反 13 | P\L P N 14 | P TP FP 15 | N FN TN 16 | """ 17 | 18 | 19 | class SegmentationMetric(object): 20 | def __init__(self, numClass): 21 | self.numClass = numClass 22 | self.confusionMatrix = np.zeros((self.numClass,) * 2) 23 | 24 | def pixelAccuracy(self): 25 | # return all class overall pixel accuracy 26 | # PA = acc = (TP + TN) / (TP + TN + FP + TN) 27 | acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum() 28 | return acc 29 | 30 | def classPixelAccuracy(self): 31 | # return each category pixel accuracy(A more accurate way to call it precision) 32 | # acc = (TP) / TP + FP 33 | classAcc = np.diag(np.transpose(self.confusionMatrix)) / np.transpose(self.confusionMatrix).sum(axis=1) 34 | return classAcc # 返回的是一个列表值,如:[0.90, 0.80, 0.96],表示类别1 2 3各类别的预测准确率 35 | 36 | def meanPixelAccuracy(self): 37 | classAcc = self.classPixelAccuracy() 38 | meanAcc = np.nanmean(classAcc) # np.nanmean 求平均值,nan表示遇到Nan类型,其值取为0 39 | return meanAcc # 返回单个值,如:np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 = 0.89 40 | 41 | def meanIntersectionOverUnion(self): 42 | # Intersection = TP Union = TP + FP + FN 43 | # IoU = TP / (TP + FP + FN) 44 | intersection = np.diag(self.confusionMatrix) # 取对角元素的值,返回列表 45 | union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag( 46 | self.confusionMatrix) # axis = 1表示混淆矩阵行的值,返回列表; axis = 0表示取混淆矩阵列的值,返回列表 47 | IoU = intersection / union # 返回列表,其值为各个类别的IoU 48 | mIoU = np.nanmean(IoU) # 求各类别IoU的平均 49 | return mIoU 50 | 51 | def genConfusionMatrix(self, imgPredict, imgLabel): # 同FCN中score.py的fast_hist()函数 52 | # remove classes from unlabeled pixels in gt image and predict 53 | mask = (imgLabel >= 0) & (imgLabel < self.numClass) 54 | label = self.numClass * imgLabel[mask] + imgPredict[mask] 55 | # num_class * gt + pred 56 | # [ 0 4 10 0 5 11 10 5 15] 57 | count = np.bincount(label, minlength=self.numClass ** 2) 58 | confusionMatrix = count.reshape(self.numClass, self.numClass) 59 | return confusionMatrix 60 | 61 | def Frequency_Weighted_Intersection_over_Union(self): 62 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)] 63 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 64 | iu = 
np.diag(self.confusion_matrix) / ( 65 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 66 | np.diag(self.confusion_matrix)) 67 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 68 | return FWIoU 69 | 70 | def addBatch(self, imgPredict, imgLabel): 71 | assert imgPredict.shape == imgLabel.shape 72 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel) 73 | 74 | def reset(self): 75 | self.confusionMatrix = np.zeros((self.numClass, self.numClass)) 76 | 77 | 78 | if __name__ == '__main__': 79 | 80 | 81 | 82 | metric = SegmentationMetric(2) # 类别数 83 | 84 | 85 | # imgPredict = np.array([[0,1,2,2], 86 | # [0,1,2,2], 87 | # [0,1,2,2], 88 | # [0,1,2,2]]) # 可直接换成预测图片 89 | # 90 | # imgLabel = np.array([[0,1,2,2], 91 | # [1,2,0,0], 92 | # [0,1,2,2], 93 | # [0,1,2,2]]) # 可直接换成标注图片 94 | 95 | imgPredict = cv2.imread("./1.png") 96 | imgPredict = np.transpose(imgPredict, [2, 0, 1]) # uint8 (3, 340, 340) 97 | imgPredict = imgPredict[:][:][0] 98 | # print(pred_data.dtype) 99 | 100 | imgLabel = cv2.imread("./2.png") 101 | imgLabel = np.transpose(imgLabel, [2, 0, 1]) # uint8 (3, 340, 340) 102 | imgLabel = imgLabel[:][:][0] 103 | 104 | imgPredict = np.where(imgPredict > 2, 1, 0) 105 | imgLabel = np.where(imgLabel > 2, 1, 0) 106 | 107 | 108 | metric.addBatch(imgPredict, imgLabel) 109 | 110 | print('ConfusionMatrix :') 111 | print(metric.confusionMatrix) # numpy.transpose() 矩阵转置 112 | 113 | print('Add:') 114 | print(numpy.sum(metric.confusionMatrix, axis=0)) 115 | 116 | print('%:') 117 | print(metric.confusionMatrix / numpy.sum(metric.confusionMatrix, axis=0)) 118 | # print('ConfusionMatrix :') 119 | # print(numpy.transpose(metric.confusionMatrix)) 120 | 121 | pa = metric.pixelAccuracy() 122 | cpa = metric.classPixelAccuracy() 123 | mpa = metric.meanPixelAccuracy() 124 | mIoU = metric.meanIntersectionOverUnion() 125 | print('pa is : %f' % pa) 126 | 127 | print('cpa is :') # 列表 128 | print(cpa) 129 | 130 | print('mpa is : %f' % mpa) 131 | print('mIoU is : %f' % mIoU) 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | # GT 141 | # 0 1 2 3 142 | # 0[[2. 1. 0. 0.] 143 | # 1 [0. 2. 0. 0.] 144 | # P 2 [0. 0. 2. 0.] 145 | # 3 [0. 0. 1. 1.]] 146 | # pa is : 0.777778 147 | # cpa is : 148 | # [1. 0.66666667 0.66666667 1. 
] 149 | # mpa is : 0.833333 150 | # mIoU is : 0.625000 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | import numpy as np 5 | import os 6 | 7 | from osgeo import gdal 8 | from skimage.metrics import peak_signal_noise_ratio as psnr 9 | from skimage.metrics import structural_similarity as ssim 10 | 11 | 12 | class AverageMeter(object): 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | 30 | # 初始化log文件 31 | def initialize_logger(file_dir): 32 | logger = logging.getLogger() 33 | fhandler = logging.FileHandler(filename=file_dir, mode='a') 34 | formatter = logging.Formatter('%(asctime)s - %(message)s', "%Y-%m-%d %H:%M:%S") 35 | fhandler.setFormatter(formatter) 36 | logger.addHandler(fhandler) 37 | logger.setLevel(logging.INFO) 38 | return logger 39 | 40 | 41 | # 保存权重 42 | def save_checkpoint(model_path, epoch, model, optimizer): 43 | state = { 44 | 'epoch': epoch, 45 | 'state_dict': model.state_dict(), 46 | 'optimizer': optimizer.state_dict(), 47 | } 48 | 49 | torch.save(state, os.path.join(model_path, 'net_%depoch.pth' % epoch)) 50 | 51 | # 计算损失-平均绝对误差 52 | class Loss_MRAE(nn.Module): 53 | def __init__(self): 54 | super(Loss_MRAE, self).__init__() 55 | 56 | def forward(self, outputs, label): 57 | assert outputs.shape == label.shape 58 | error = torch.abs(outputs - label) / label 59 | mrae = torch.mean(error) # .contiguous().view(-1) 60 | return mrae 61 | 62 | # 计算损失-均方根误差 63 | class Loss_RMSE(nn.Module): 64 | def __init__(self): 65 | super(Loss_RMSE, self).__init__() 66 | 67 | def forward(self, outputs, label): 68 | assert outputs.shape == label.shape 69 | error = outputs-label 70 | sqrt_error = torch.pow(error,2) 71 | rmse = torch.sqrt(torch.mean(sqrt_error.contiguous().view(-1))) 72 | return rmse 73 | 74 | # 计算损失-峰值信噪比 75 | class Loss_PSNR(nn.Module): 76 | def __init__(self): 77 | super(Loss_PSNR, self).__init__() 78 | 79 | def forward(self, im_true, im_fake, data_range=255): 80 | N = im_true.size()[0] 81 | C = im_true.size()[1] 82 | H = im_true.size()[2] 83 | W = im_true.size()[3] 84 | Itrue = im_true.clamp(0., 1.).mul_(data_range).reshape(N, C * H * W) 85 | Ifake = im_fake.clamp(0., 1.).mul_(data_range).reshape(N, C * H * W) 86 | mse = nn.MSELoss(reduction='none') 87 | err = mse(Itrue, Ifake).sum(dim=1, keepdim=True).div_(C * H * W) 88 | psnr = 10. * torch.log((data_range ** 2) / err) / np.log(10.) 
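# (descriptive note) this is 10*log10(data_range^2 / MSE) per sample; the batch mean is returned below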
89 | return torch.mean(psnr) 90 | 91 | def time2file_name(time): 92 | year = time[0:4] 93 | month = time[5:7] 94 | day = time[8:10] 95 | hour = time[11:13] 96 | minute = time[14:16] 97 | second = time[17:19] 98 | time_filename = year + '_' + month + '_' + day + '_' + hour + '_' + minute + '_' + second 99 | return time_filename 100 | 101 | def record_loss(loss_csv, epoch, iteration, epoch_time, lr, train_loss, test_loss): 102 | """ Record many results.""" 103 | loss_csv.write('{},{},{},{},{},{}\n'.format(epoch, iteration, epoch_time, lr, train_loss, test_loss)) 104 | loss_csv.flush() 105 | loss_csv.close 106 | 107 | 108 | def try_gpu(i=0): 109 | if torch.cuda.device_count() >= i+1: 110 | return torch.device(f'cuda:{i}') 111 | return torch.device('cpu') 112 | 113 | 114 | def try_all_gpus(): 115 | devices = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())] 116 | return devices if devices else [torch.device('cpu')] 117 | 118 | # psnr 119 | class T_Loss_PSNR(nn.Module): 120 | def __init__(self): 121 | super(T_Loss_PSNR, self).__init__() 122 | 123 | def forward(self, im_true, im_fake): 124 | im_true = (im_true.detach().cpu().numpy().transpose(0, 2, 3, 1)).astype(np.uint8) # 转numpy (b,c,h,w) 125 | im_fake = (im_fake.detach().cpu().numpy().transpose(0, 2, 3, 1)).astype(np.uint8) # 转numpy (b,c,h,w) 126 | 127 | # print(im_true.dtype, im_fake.shape[0]) 128 | i_psnr = 0 129 | for i in range(im_true.shape[0]): 130 | i_t = im_true[i, ...] 131 | i_f = im_fake[i, ...] 132 | p = psnr(i_t, i_f) 133 | i_psnr += p 134 | 135 | m_psnr = i_psnr / im_true.shape[0] 136 | 137 | return m_psnr 138 | 139 | # ssim 140 | class T_Loss_SSIM(nn.Module): 141 | def __init__(self): 142 | super(T_Loss_SSIM, self).__init__() 143 | 144 | def forward(self, im_true, im_fake): 145 | im_true = (im_true.detach().cpu().numpy().transpose(0, 2, 3, 1)).astype(np.uint8) # 转numpy (b,c,h,w) 146 | im_fake = (im_fake.detach().cpu().numpy().transpose(0, 2, 3, 1)).astype(np.uint8) # 转numpy (b,c,h,w) 147 | 148 | # print(im_true.dtype, im_fake.shape[0]) 149 | 150 | i_ssim = 0 151 | for i in range(im_true.shape[0]): 152 | i_t = im_true[i,...] 153 | i_f = im_fake[i,...] 154 | s = ssim(i_t, i_f, multichannel=True) 155 | i_ssim += s 156 | 157 | m_ssim = i_ssim / im_true.shape[0] 158 | 159 | return m_ssim 160 | 161 | 162 | def outi(fakei, dir, name): 163 | fakei = fakei.detach().cpu().numpy() 164 | 165 | fakei = fakei[0, ...] 166 | 167 | fake_gre = fakei[0, ...] 168 | fake_red = fakei[1, ...] 169 | fake_reg = fakei[2, ...] 170 | fake_nir = fakei[3, ...] 
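# (descriptive note) band order matches the convention used throughout this repo: 0 gre, 1 red, 2 reg, 3 nir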
171 | 172 | y = fake_gre.shape[0] 173 | x = fake_gre.shape[1] 174 | 175 | savepath = os.path.join(dir, f'{name}out.TIF') # 生成图信息 176 | result = gdal.GetDriverByName('GTiff').Create(savepath, xsize=x, ysize=y, bands=4, eType=gdal.GDT_Byte) 177 | result.GetRasterBand(1).WriteArray(fake_gre) 178 | result.GetRasterBand(2).WriteArray(fake_red) 179 | result.GetRasterBand(3).WriteArray(fake_reg) 180 | result.GetRasterBand(4).WriteArray(fake_nir) 181 | print("save: " + savepath) -------------------------------------------------------------------------------- /batch_predict_single.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import os 4 | import numpy as np 5 | from model import MST_Plus_Plus 6 | from other_model import AWAN, HRNET, HSCNN_Plus, MIRNet, HDNet, MPRNet 7 | from utils import outi 8 | from exp import outi_pers 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | from osgeo import gdal 12 | from utils import initialize_logger 13 | 14 | 15 | def create_model(name): 16 | if name == 'MST++': 17 | model = MST_Plus_Plus(in_channels=3, out_channels=4, n_feat=4) 18 | elif name == 'HRnet': 19 | model = HRNET.SGN() # 31689284 final // batch_size = 1 20 | elif name == 'HSCNN++': 21 | model = HSCNN_Plus.HSCNN_Plus() # 299584 final // batch_size = 2 22 | elif name == 'HDnet': 23 | model = HDNet.HDNet() # 2647552 final 24 | elif name == 'MPRnet': 25 | model = MPRNet.MPRNet() # 60349 final 26 | else: 27 | print(f'Method {name} is not defined !!!!') 28 | 29 | return model 30 | 31 | 32 | def parse_args(): 33 | import argparse 34 | parser = argparse.ArgumentParser(description="Predict") 35 | # parser.add_argument('--pretrained_model_path', type=str, default='./pred/single/HDnet/model/net_5epoch.pth') 36 | # parser.add_argument("--outf", type=str, default='./res/single/HDnet/005/', help='path MSI files') 37 | parser.add_argument('--rgb_dir', type=str, default='./pred/single/Val_RGB') 38 | 39 | # model_name : MST++ HRnet HSCNN++ HDnet MPRnet 40 | parser.add_argument("--model_name", type=str, default='MST++', help='model name') 41 | 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | def gen_seg(tif_file, name, out_path): 48 | msi_data = gdal.Open(tif_file).ReadAsArray() 49 | msi_data = torch.tensor(msi_data) # 转tensor 50 | 51 | # total_sum = torch.sum(msi_data[2,...]) 52 | # print(total_sum) 53 | msi_data = msi_data.int() 54 | 55 | # ============fusion based on weight============ 56 | s1 = 0 57 | # 0:gre 1:red 2:reg 3:nir 58 | x1, x2, x3, x4 = -1.0, 0.0, 1.0, 1.0 59 | s1 = x1 * msi_data[0, :, :] + x2 * msi_data[1, :, :] + x3 * msi_data[2, :, :] + x4 * msi_data[3, :, :] 60 | im_gray1 = np.array(s1.numpy()) 61 | avg_gray = np.average(im_gray1) 62 | 63 | print(f'b{avg_gray}') 64 | 65 | # 聚合 66 | if avg_gray < 145: 67 | avg_gray = avg_gray * 1.15 68 | elif avg_gray > 150: 69 | avg_gray = avg_gray * 0.9 70 | 71 | print(avg_gray) 72 | im_gray2 = np.where(im_gray1 > avg_gray, 255, 0) 73 | 74 | # ============fusion based on VI============ 75 | # vi = (msi_data[0, :, :] - msi_data[3, :, :])/(msi_data[0, :, :] + msi_data[3, :, :]) 76 | # # vi = ((2 * msi_data[0, :, :]) - (msi_data[2, :, :] + msi_data[3, :, :]))/((2 * msi_data[0, :, :]) + (msi_data[2, :, :] + msi_data[3, :, :])) 77 | # # print(vi) 78 | # im_gray2 = np.where(vi < -0.2, 255, 0) 79 | 80 | # ============save result============ 81 | data = np.array(im_gray2, dtype='uint8') 82 | cv2.imwrite(os.path.join(out_path, f'{name}.png'), data) 83 | 
print("save: " + os.path.join(out_path, f'{name}.png')) 84 | 85 | 86 | def show_seg(rgb, seg, name, out_path): 87 | # ==============no.1============== 88 | # imgfile = './pred/avg_test1_12pth/test2.jpg' 89 | # pngfile = './pred/avg_test1_12pth/2+3.png' 90 | # 91 | # img = cv2.imread(imgfile, 1) 92 | # mask = cv2.imread(pngfile, 0) 93 | # 94 | # contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 95 | # cv2.drawContours(img, contours, -1, (0, 0, 255), 1) 96 | # 97 | # img = img[:, :, ::-1] 98 | # img[..., 2] = np.where(mask == 1, 255, img[..., 2]) 99 | # 100 | # plt.imshow(img) 101 | # plt.savefig("./pred/avg_test1_12pth/result_2+3.png") 102 | # plt.show() 103 | 104 | # ==============no.2============== 105 | image1 = Image.open(rgb) 106 | image2 = Image.open(seg) 107 | 108 | plt.figure() 109 | 110 | plt.subplot(221) 111 | plt.imshow(image1) 112 | 113 | plt.subplot(222) 114 | plt.imshow(image2) 115 | 116 | plt.subplot(223) 117 | plt.imshow(image1) 118 | plt.imshow(image2, alpha=0.5) 119 | 120 | plt.savefig(os.path.join(out_path, f'{name}_c.png')) 121 | 122 | plt.close() 123 | # plt.show() 124 | 125 | # ==============no.3============== 126 | # image1 = Image.open(rgb) 127 | # image2 = Image.open(seg) 128 | # 129 | # image1 = image1.convert('RGBA') 130 | # image2 = image2.convert('RGBA') 131 | # 132 | # # 两幅图像进行合并时,按公式:blended_img = img1 * (1 – alpha) + img2* alpha 进行 133 | # image = Image.blend(image1, image2, 0.3) 134 | # # image.save(os.path.join(out_path, f'{name}_c.png')) 135 | # image.show() 136 | 137 | 138 | def main(args, pretrained_model_path, outf): 139 | model = create_model(args.model_name) 140 | 141 | if pretrained_model_path is not None: 142 | print(f'load model from {pretrained_model_path}') 143 | # If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') 144 | checkpoint = torch.load(pretrained_model_path, map_location=torch.device('cpu')) 145 | model.load_state_dict(checkpoint['state_dict']) 146 | 147 | # # logging 148 | # log_dir = os.path.join(outf, f'{args.model_name}{pretrained_model_path}.log') 149 | # logger = initialize_logger(log_dir) 150 | 151 | rgb_dir = args.rgb_dir 152 | 153 | for rgb in os.listdir(rgb_dir): 154 | name = rgb.split(".")[0] 155 | 156 | rgb_data = cv2.imread(os.path.join(rgb_dir, rgb)) # uint8 (340, 340, 3) 157 | rgb_data = np.transpose(rgb_data, [2, 0, 1]) # uint8 (3, 340, 340) 158 | rgb_data = np.float32(rgb_data) 159 | rgb_data = torch.tensor(rgb_data) # 转tensor 160 | rgb_data = rgb_data.unsqueeze(0) # (1, 3, 340, 340) 161 | # print(rgb_data.dtype) 162 | 163 | MSI = model(rgb_data) 164 | 165 | outi(MSI, outf, rgb.split('.')[0]) 166 | # outi_pers(MSI, outf, rgb.split('.')[0]) 167 | 168 | gen_seg(os.path.join(outf, f'{name}out.TIF'), name, outf) 169 | 170 | show_seg(os.path.join(rgb_dir, rgb), os.path.join(outf, f'{name}.png'), name, outf) 171 | 172 | 173 | if __name__ == '__main__': 174 | args = parse_args() 175 | 176 | for epoch in range(0, 100, 5): 177 | pdm_path = f'./pred/single/{args.model_name}/model/net_{epoch}epoch.pth' 178 | outf = f'./res/single/{args.model_name}/{epoch}/' 179 | 180 | if not os.path.exists(outf): 181 | os.makedirs(outf) 182 | 183 | main(args, pdm_path, outf) 184 | 185 | for epoch in range(100, 300, 20): 186 | pdm_path = f'./pred/single/{args.model_name}/model/net_{epoch}epoch.pth' 187 | outf = f'./res/single/{args.model_name}/{epoch}/' 188 | 189 | if not os.path.exists(outf): 190 | os.makedirs(outf) 191 | 192 | main(args, pdm_path, outf) 
-------------------------------------------------------------------------------- /miou.py: -------------------------------------------------------------------------------- 1 | """ 2 | refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py 3 | """ 4 | import numpy 5 | import numpy as np 6 | import os 7 | import cv2 8 | 9 | __all__ = ['SegmentationMetric'] 10 | 11 | """ 12 | confusionMetric # 注意:此处横着代表预测值,竖着代表真实值,与之前介绍的相反 13 | P\L P N 14 | P TP FP 15 | N FN TN 16 | """ 17 | 18 | 19 | class SegmentationMetric(object): 20 | def __init__(self, numClass): 21 | self.numClass = numClass 22 | self.confusionMatrix = np.zeros((self.numClass,) * 2) 23 | 24 | def pixelAccuracy(self): 25 | # return all class overall pixel accuracy 26 | # PA = acc = (TP + TN) / (TP + TN + FP + TN) 27 | acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum() 28 | return acc 29 | 30 | def classPixelAccuracy(self): 31 | # return each category pixel accuracy(A more accurate way to call it precision) 32 | # acc = (TP) / TP + FP 33 | classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1) 34 | return classAcc # 返回的是一个列表值,如:[0.90, 0.80, 0.96],表示类别1 2 3各类别的预测准确率 35 | 36 | def meanPixelAccuracy(self): 37 | classAcc = self.classPixelAccuracy() 38 | meanAcc = np.nanmean(classAcc) # np.nanmean 求平均值,nan表示遇到Nan类型,其值取为0 39 | return meanAcc # 返回单个值,如:np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 = 0.89 40 | 41 | def meanIntersectionOverUnion(self): 42 | # Intersection = TP Union = TP + FP + FN 43 | # IoU = TP / (TP + FP + FN) 44 | intersection = np.diag(self.confusionMatrix) # 取对角元素的值,返回列表 45 | union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag( 46 | self.confusionMatrix) # axis = 1表示混淆矩阵行的值,返回列表; axis = 0表示取混淆矩阵列的值,返回列表 47 | IoU = intersection / union # 返回列表,其值为各个类别的IoU 48 | mIoU = np.nanmean(IoU) # 求各类别IoU的平均 49 | return mIoU 50 | 51 | def genConfusionMatrix(self, imgPredict, imgLabel): # 同FCN中score.py的fast_hist()函数 52 | # remove classes from unlabeled pixels in gt image and predict 53 | mask = (imgLabel >= 0) & (imgLabel < self.numClass) 54 | label = self.numClass * imgLabel[mask] + imgPredict[mask] 55 | # num_class * gt + pred 56 | # [ 0 4 10 0 5 11 10 5 15] 57 | count = np.bincount(label, minlength=self.numClass ** 2) 58 | confusionMatrix = count.reshape(self.numClass, self.numClass) 59 | return confusionMatrix 60 | 61 | def Frequency_Weighted_Intersection_over_Union(self): 62 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)] 63 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 64 | iu = np.diag(self.confusion_matrix) / ( 65 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 66 | np.diag(self.confusion_matrix)) 67 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 68 | return FWIoU 69 | 70 | def addBatch(self, imgPredict, imgLabel): 71 | assert imgPredict.shape == imgLabel.shape 72 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel) 73 | 74 | def reset(self): 75 | self.confusionMatrix = np.zeros((self.numClass, self.numClass)) 76 | 77 | 78 | if __name__ == '__main__': 79 | 80 | with open('./all_vallist.txt', 'r') as li: 81 | data_list = [line.replace('\n', '.png') for line in li] 82 | data_list.sort() 83 | print(f'The number of val_data:{len(data_list)}') 84 | 85 | model = ['MST++', 'MPRnet', 'HSCNN++', 'HRnet', 'HDnet'] 86 | file_path = './res/vi low -0.2/all/' 87 | epoch = 20 88 | 89 | for mod in model: 90 | 
print(f'----------now: {mod}----------') 91 | metric = SegmentationMetric(2) # 类别数 92 | 93 | for i in range(len(data_list)): 94 | pred_data = cv2.imread(os.path.join(file_path, f'{mod}/{epoch}', data_list[i])) 95 | pred_data = np.transpose(pred_data, [2, 0, 1]) # uint8 (3, 340, 340) 96 | pred_data = pred_data[:][:][0] 97 | # print(pred_data.dtype) 98 | 99 | gt_data = cv2.imread(os.path.join(f'./miou/all/Val_GT', data_list[i])) 100 | gt_data = np.transpose(gt_data, [2, 0, 1]) # uint8 (3, 340, 340) 101 | gt_data = gt_data[:][:][0] 102 | # print(rgb_data.dtype) 103 | 104 | # print(f'{data_list[i]}') 105 | # imgPredict = np.array([[0,0,2], 106 | # [0,1,3], 107 | # [2,1,3]]) # 可直接换成预测图片 108 | # imgLabel = np.array([[0,1,2], 109 | # [0,1,2], 110 | # [2,1,3]]) # 可直接换成标注图片 111 | imgPredict = np.where(pred_data > 2, 1, 0) 112 | imgLabel = np.where(gt_data > 2, 1, 0) 113 | metric.addBatch(imgPredict, imgLabel) 114 | 115 | print('ConfusionMatrix :') 116 | print(numpy.transpose(metric.confusionMatrix)) # numpy.transpose() 矩阵转置 117 | 118 | print('Add:') 119 | print(numpy.sum(numpy.transpose(metric.confusionMatrix), axis=0)) 120 | 121 | print('%:') 122 | print(numpy.transpose(metric.confusionMatrix) / numpy.sum(numpy.transpose(metric.confusionMatrix), axis=0)) 123 | # print('ConfusionMatrix :') 124 | # print(numpy.transpose(metric.confusionMatrix)) 125 | 126 | pa = metric.pixelAccuracy() 127 | cpa = metric.classPixelAccuracy() 128 | mpa = metric.meanPixelAccuracy() 129 | mIoU = metric.meanIntersectionOverUnion() 130 | print('pa is : %f' % pa) 131 | 132 | print('cpa is :') # 列表 133 | print(cpa) 134 | 135 | print('mpa is : %f' % mpa) 136 | print('mIoU is : %f' % mIoU) 137 | 138 | with open(os.path.join(file_path, f'{mod}/{epoch}', "res.txt"), mode='w', encoding='utf-8') as f: 139 | f.write(os.path.join(file_path, f'{mod}/{epoch}')) 140 | f.write('\n') 141 | f.write('ConfusionMatrix :\n') 142 | f.write(str(numpy.transpose(metric.confusionMatrix))) 143 | f.write('\n') 144 | f.write('%:\n') 145 | f.write(str(numpy.transpose(metric.confusionMatrix) / numpy.sum(numpy.transpose(metric.confusionMatrix), axis=0))) 146 | f.write('\n') 147 | f.write('pa is : %f\n' % pa) 148 | f.write('cpa is :\n') # 列表 149 | f.write(str(cpa)) 150 | f.write('\n') 151 | f.write('mpa is : %f\n' % mpa) 152 | f.write('mIoU is : %f\n' % mIoU) 153 | 154 | 155 | 156 | 157 | 158 | 159 | # GT 160 | # 0 1 2 3 161 | # 0[[2. 1. 0. 0.] 162 | # 1 [0. 2. 0. 0.] 163 | # P 2 [0. 0. 2. 0.] 164 | # 3 [0. 0. 1. 1.]] 165 | # pa is : 0.777778 166 | # cpa is : 167 | # [1. 0.66666667 0.66666667 1. 
] 168 | # mpa is : 0.833333 169 | # mIoU is : 0.625000 -------------------------------------------------------------------------------- /batch_predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import os 4 | import numpy as np 5 | from model import MST_Plus_Plus 6 | from other_model import HRNET, HSCNN_Plus, HDNet, MPRNet 7 | from utils import outi 8 | from exp import outi_pers 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | from osgeo import gdal 12 | from utils import initialize_logger 13 | 14 | 15 | def create_model(name): 16 | if name == 'MST++': 17 | model = MST_Plus_Plus(in_channels=3, out_channels=4, n_feat=4) 18 | elif name == 'HRnet': 19 | model = HRNET.SGN() # 31689284 final // batch_size = 1 20 | elif name == 'HSCNN++': 21 | model = HSCNN_Plus.HSCNN_Plus() # 299584 final // batch_size = 2 22 | elif name == 'HDnet': 23 | model = HDNet.HDNet() # 2647552 final 24 | elif name == 'MPRnet': 25 | model = MPRNet.MPRNet() # 60349 final 26 | else: 27 | print(f'Method {name} is not defined !!!!') 28 | 29 | return model 30 | 31 | 32 | def parse_args(): 33 | import argparse 34 | parser = argparse.ArgumentParser(description="Predict") 35 | # parser.add_argument('--pretrained_model_path', type=str, default='./pred/all/HDnet/model/net_5epoch.pth') 36 | # parser.add_argument("--outf", type=str, default='./res/all/HDnet/005/', help='path MSI files') 37 | parser.add_argument('--rgb_dir', type=str, default='./hy_seg1/') 38 | 39 | # model_name : MST++ HRnet HSCNN++ HDnet MPRnet 40 | parser.add_argument("--model_name", type=str, default='HSCNN++', help='model name') 41 | 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | def gen_seg(tif_file, name, out_path): 48 | msi_data = gdal.Open(tif_file).ReadAsArray() 49 | msi_data = torch.tensor(msi_data) # 转tensor 50 | 51 | # total_sum = torch.sum(msi_data[2,...]) 52 | # print(total_sum) 53 | msi_data = msi_data.int() 54 | 55 | # ============fusion based on weight============ 56 | s1 = 0 57 | # 0:gre 1:red 2:reg 3:nir 58 | x1, x2, x3, x4 = -0.5, -1, 1, 1 59 | s1 = x1 * msi_data[0, :, :] + x2 * msi_data[1, :, :] + x3 * msi_data[2, :, :] + x4 * msi_data[3, :, :] 60 | im_gray1 = np.array(s1.numpy()) 61 | avg_gray = np.average(im_gray1) 62 | 63 | print(f'b{avg_gray}') 64 | 65 | # 聚合 66 | if avg_gray < 50: 67 | avg_gray = avg_gray * (-(1 + (55-avg_gray)/50) if avg_gray<0 else (1 + (55-avg_gray)/50)) 68 | elif avg_gray > 105: 69 | avg_gray = avg_gray * 0.95 70 | 71 | print(avg_gray) 72 | im_gray2 = np.where(im_gray1 > avg_gray, 255, 0) 73 | 74 | # ============fusion based on VI============ 75 | # vi = (msi_data[0, :, :] - msi_data[3, :, :])/(msi_data[0, :, :] + msi_data[3, :, :]) 76 | # # vi = ((2 * msi_data[0, :, :]) - (msi_data[2, :, :] + msi_data[3, :, :]))/((2 * msi_data[0, :, :]) + (msi_data[2, :, :] + msi_data[3, :, :])) 77 | # # print(vi) 78 | # im_gray2 = np.where(vi < -0.2, 255, 0) 79 | 80 | # ============save result============ 81 | data = np.array(im_gray2, dtype='uint8') 82 | cv2.imwrite(os.path.join(out_path, f'{name}.png'), data) 83 | print("save: " + os.path.join(out_path, f'{name}.png')) 84 | 85 | 86 | def show_seg(rgb, seg, name, out_path): 87 | # ==============no.1============== 88 | # imgfile = './pred/avg_test1_12pth/test2.jpg' 89 | # pngfile = './pred/avg_test1_12pth/2+3.png' 90 | # 91 | # img = cv2.imread(imgfile, 1) 92 | # mask = cv2.imread(pngfile, 0) 93 | # 94 | # contours, _ = cv2.findContours(mask, cv2.RETR_TREE, 
cv2.CHAIN_APPROX_SIMPLE) 95 | # cv2.drawContours(img, contours, -1, (0, 0, 255), 1) 96 | # 97 | # img = img[:, :, ::-1] 98 | # img[..., 2] = np.where(mask == 1, 255, img[..., 2]) 99 | # 100 | # plt.imshow(img) 101 | # plt.savefig("./pred/avg_test1_12pth/result_2+3.png") 102 | # plt.show() 103 | 104 | # ==============no.2============== 105 | image1 = Image.open(rgb) 106 | image2 = Image.open(seg) 107 | 108 | plt.figure() 109 | 110 | plt.subplot(221) 111 | plt.imshow(image1) 112 | 113 | plt.subplot(222) 114 | plt.imshow(image2) 115 | 116 | plt.subplot(223) 117 | plt.imshow(image1) 118 | plt.imshow(image2, alpha=0.5) 119 | 120 | plt.savefig(os.path.join(out_path, f'{name}_c.png')) 121 | 122 | plt.close() 123 | # plt.show() 124 | 125 | # ==============no.3============== 126 | # image1 = Image.open(rgb) 127 | # image2 = Image.open(seg) 128 | # 129 | # image1 = image1.convert('RGBA') 130 | # image2 = image2.convert('RGBA') 131 | # 132 | # # 两幅图像进行合并时,按公式:blended_img = img1 * (1 – alpha) + img2* alpha 进行 133 | # image = Image.blend(image1, image2, 0.3) 134 | # # image.save(os.path.join(out_path, f'{name}_c.png')) 135 | # image.show() 136 | 137 | 138 | def main(args, pretrained_model_path, outf): 139 | model = create_model(args.model_name) 140 | 141 | if pretrained_model_path is not None: 142 | print(f'load model from {pretrained_model_path}') 143 | # If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') 144 | checkpoint = torch.load(pretrained_model_path, map_location=torch.device('cpu')) 145 | model.load_state_dict(checkpoint['state_dict']) 146 | 147 | # # logging 148 | # log_dir = os.path.join(outf, f'{args.model_name}{pretrained_model_path}.log') 149 | # logger = initialize_logger(log_dir) 150 | 151 | rgb_dir = args.rgb_dir 152 | 153 | for rgb in os.listdir(rgb_dir): 154 | name = rgb.split(".")[0] 155 | 156 | rgb_data = cv2.imread(os.path.join(rgb_dir, rgb)) # uint8 (340, 340, 3) 157 | rgb_data = np.transpose(rgb_data, [2, 0, 1]) # uint8 (3, 340, 340) 158 | rgb_data = np.float32(rgb_data) 159 | rgb_data = torch.tensor(rgb_data) # 转tensor 160 | rgb_data = rgb_data.unsqueeze(0) # (1, 3, 340, 340) 161 | # print(rgb_data.dtype) 162 | 163 | MSI = model(rgb_data) 164 | 165 | outi(MSI, outf, rgb.split('.')[0]) 166 | # outi_pers(MSI, outf, rgb.split('.')[0]) 167 | 168 | gen_seg(os.path.join(outf, f'{name}out.TIF'), name, outf) 169 | 170 | show_seg(os.path.join(rgb_dir, rgb), os.path.join(outf, f'{name}.png'), name, outf) 171 | 172 | 173 | if __name__ == '__main__': 174 | args = parse_args() 175 | 176 | pdm_path = './model_zoo/hscnn+_10epoch.pth' 177 | outf = './hy_seg1/' 178 | 179 | if not os.path.exists(outf): 180 | os.makedirs(outf) 181 | 182 | main(args, pdm_path, outf) 183 | 184 | # for epoch in range(0, 100, 5): 185 | # pdm_path = f'./pred/all/{args.model_name}/model/net_{epoch}epoch.pth' 186 | # outf = f'./res/all/{args.model_name}test1/{epoch}/' 187 | # 188 | # if not os.path.exists(outf): 189 | # os.makedirs(outf) 190 | # 191 | # main(args, pdm_path, outf) 192 | 193 | # for epoch in range(100, 300, 20): 194 | # pdm_path = f'./pred/all/{args.model_name}/model/net_{epoch}epoch.pth' 195 | # outf = f'./res/all/{args.model_name}test1/{epoch}/' 196 | # 197 | # if not os.path.exists(outf): 198 | # os.makedirs(outf) 199 | # 200 | # main(args, pdm_path, outf) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from data import 
TrainDataset, ValidDataset 2 | from model import MST_Plus_Plus 3 | from utils import Loss_MRAE, Loss_RMSE, Loss_PSNR, AverageMeter, initialize_logger, time2file_name, save_checkpoint, T_Loss_PSNR, T_Loss_SSIM, outi 4 | from other_model import AWAN, HRNET, HSCNN_Plus, MIRNet, HDNet, MPRNet 5 | from other_model import MST_o, MPRNet_o, HSCNN_Plus_o, HDNet_o, HRNET_o 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import datetime 10 | import os 11 | 12 | 13 | def create_model(name): 14 | 15 | if name == 'mst++': 16 | model = MST_Plus_Plus(in_channels=3, out_channels=4, n_feat=4) 17 | elif name == 'hrnet': 18 | model = HRNET.SGN() # 31689284 final // batch_size = 1 19 | elif name == 'hscnn_plus': 20 | model = HSCNN_Plus.HSCNN_Plus() # 299584 final // batch_size = 2 21 | elif name == 'hdnet': 22 | model = HDNet.HDNet() # 2647552 final 23 | elif name == 'mprnet': 24 | model = MPRNet.MPRNet() # 60349 final 25 | 26 | elif name == 'mprnet_o': 27 | model = MPRNet_o.MPRNet() # 60349 final 28 | elif name == 'hscnn_plus_o': 29 | model = HSCNN_Plus_o.HSCNN_Plus() # 60349 final 30 | elif name == 'hdnet_o': 31 | model = HDNet_o.HDNet() # 60349 final 32 | elif name == 'hrnet_o': 33 | model = HRNET_o.SGN() # 60349 final 34 | elif name == 'mst_o': 35 | model = MST_o.MST_Plus_Plus() # 60349 final 36 | else: 37 | print(f'Method {name} is not defined !!!!') 38 | 39 | return model 40 | 41 | 42 | def try_gpu(i=0): 43 | if torch.cuda.device_count() >= i+1: 44 | cudnn.benchmark = True # 匹配高效算法 增加运行效率 45 | return torch.device(f'cuda:{i}') 46 | return torch.device('cpu') 47 | 48 | 49 | def mk_outfile_dir(): 50 | # output path 51 | date_time = str(datetime.datetime.now()) 52 | date_time = time2file_name(date_time) 53 | args.outf = args.outf + date_time 54 | if not os.path.exists(args.outf): 55 | os.makedirs(args.outf) 56 | 57 | 58 | def main(args): 59 | 60 | # 初始化log 61 | mk_outfile_dir() 62 | # logging 63 | log_dir = os.path.join(args.outf, 'train.log') 64 | logger = initialize_logger(log_dir) 65 | 66 | # load数据 67 | train_dataset = TrainDataset(args.data_root) 68 | val_dataset = ValidDataset(args.data_root) 69 | 70 | epochs = args.end_epoch 71 | batch_size = args.batch_size 72 | num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) 73 | 74 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 75 | batch_size=batch_size, 76 | num_workers=0, 77 | shuffle=True, 78 | pin_memory=True # 提高数据从cpu到gpu的传输效率 79 | ) 80 | 81 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 82 | batch_size=1, 83 | num_workers=0, 84 | pin_memory=True 85 | ) 86 | 87 | 88 | # create model 89 | model = create_model(args.model_name).to(try_gpu()) 90 | print('Parameters number is ', sum(param.numel() for param in model.parameters())) # 计算参数量 91 | 92 | # create optimizer and scheduler 93 | optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, betas=(0.9, 0.999)) 94 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*epochs, eta_min=1e-6) 95 | 96 | # recode general info 97 | logger.info(f'Model:{args.model_name}, Batch_size:{batch_size}, Epoch:{epochs}, Dataset:{args.data_root}, Parameters number is {sum(param.numel() for param in model.parameters())}') 98 | 99 | # recode criterion 100 | recode = 100 101 | 102 | for epoch in range(0, epochs): 103 | 104 | losses = AverageMeter() 105 | 106 | criterion_mrae = Loss_MRAE() 107 | criterion_psnr = T_Loss_PSNR() 108 | criterion_ssim = T_Loss_SSIM() 109 | 110 | # ============================= train 
============================= 111 | model.train() 112 | print() 113 | for i, (image, target) in enumerate(train_loader): 114 | image = image.to(torch.float32).to(try_gpu()) 115 | target = target.to(torch.float32).to(try_gpu()) 116 | 117 | # 优化策略 Adam 118 | lr = optimizer.param_groups[0]['lr'] 119 | optimizer.zero_grad() # 重置梯度 120 | output = model(image) # input进模型 121 | loss_p = criterion_psnr(target, output) # 计算psnr 122 | loss_s = criterion_ssim(target, output) # 计算ssim 123 | 124 | loss = criterion_mrae(output, target) + (1 - loss_s) # 计算loss mrae + (1-ssim) 125 | loss.backward() # 反向传播 126 | 127 | optimizer.step() # 优化网络参数,如权重等 128 | scheduler.step() # 优化学习率等参数 129 | losses.update(loss.data) 130 | 131 | print(f'[train epoch:{epoch + 1}/{epochs}] batch:{i + 1}/{len(train_loader)}, device:{image.device}, lr:{lr:.9f}, loss:{losses.avg:.9f}, ssim:{loss_s:.9f}, psnr:{loss_p:.9f}') 132 | 133 | 134 | # ============================= val ============================= 135 | model.eval() 136 | print() 137 | losses_mrae = AverageMeter() 138 | losses_psnr = AverageMeter() 139 | losses_ssim = AverageMeter() 140 | for i, (image, target) in enumerate(val_loader): 141 | image = image.to(torch.float32).to(try_gpu()) 142 | target = target.to(torch.float32).to(try_gpu()) 143 | with torch.no_grad(): 144 | output = model(image) 145 | 146 | # save a output img in valid 147 | # if i == 0 : 148 | # outi(output, args.outf, epoch) 149 | 150 | loss_mrae = criterion_mrae(output, target) # 计算mrae 151 | loss_ssim = criterion_ssim(target, output) # 计算ssim 152 | loss_psnr = criterion_psnr(target, output) # 计算psnr 153 | 154 | losses_mrae.update(loss_mrae.data) 155 | losses_ssim.update(loss_ssim) 156 | losses_psnr.update(loss_psnr) 157 | 158 | # Save model 159 | if (epoch % 5 == 0) or (losses_mrae.avg < recode): 160 | print(f'Saving to {args.outf}') 161 | save_checkpoint(args.outf, epoch, model, optimizer) 162 | if losses_mrae.avg < recode: 163 | recode = losses_mrae.avg 164 | 165 | # logging loss 166 | print(f'[valid epoch:{epoch + 1}/{epochs}] device:{image.device}, lr:{lr:.9f},Train Loss:{losses.avg:.9f}, Test mrae:{losses_mrae.avg:.9f}, Test ssim:{losses_ssim.avg:.9f}, Test psnr:{losses_psnr.avg:.9f}') 167 | logger.info(f'[valid epoch:{epoch + 1}/{epochs}] device:{image.device}, lr:{lr:.9f},Train Loss:{losses.avg:.9f}, Test mrae:{losses_mrae.avg:.9f}, Test ssim:{losses_ssim.avg:.9f}, Test psnr:{losses_psnr.avg:.9f}') 168 | 169 | def parse_args(): 170 | import argparse 171 | parser = argparse.ArgumentParser(description="RGB to Multispectral") 172 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 173 | parser.add_argument("--end_epoch", type=int, default=50, help="number of epochs") 174 | parser.add_argument("--init_lr", type=float, default=4e-4, help="initial learning rate") 175 | parser.add_argument("--outf", type=str, default='./log/', help='path log files') 176 | parser.add_argument("--data_root", type=str, default='./dataset/') 177 | parser.add_argument("--model_name", type=str, default='mst_o', help='model name') 178 | 179 | args = parser.parse_args() 180 | 181 | return args 182 | 183 | 184 | # 'mprnet_o' 185 | # 'hscnn_plus_o'* not enough memory 186 | # 'hdnet_o' 187 | # 'hrnet_o' 188 | # 'mst_o' 189 | 190 | if __name__ == '__main__': 191 | args = parse_args() 192 | main(args) 193 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 
import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | import math 6 | import warnings 7 | from torch.nn.init import _calculate_fan_in_and_fan_out 8 | 9 | 10 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 11 | def norm_cdf(x): 12 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 13 | 14 | if (mean < a - 2 * std) or (mean > b + 2 * std): 15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 16 | "The distribution of values may be incorrect.", 17 | stacklevel=2) 18 | with torch.no_grad(): 19 | l = norm_cdf((a - mean) / std) 20 | u = norm_cdf((b - mean) / std) 21 | tensor.uniform_(2 * l - 1, 2 * u - 1) 22 | tensor.erfinv_() 23 | tensor.mul_(std * math.sqrt(2.)) 24 | tensor.add_(mean) 25 | tensor.clamp_(min=a, max=b) 26 | return tensor 27 | 28 | 29 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 30 | # type: (Tensor, float, float, float, float) -> Tensor 31 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 32 | 33 | 34 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 35 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 36 | if mode == 'fan_in': 37 | denom = fan_in 38 | elif mode == 'fan_out': 39 | denom = fan_out 40 | elif mode == 'fan_avg': 41 | denom = (fan_in + fan_out) / 2 42 | variance = scale / denom 43 | if distribution == "truncated_normal": 44 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) 45 | elif distribution == "normal": 46 | tensor.normal_(std=math.sqrt(variance)) 47 | elif distribution == "uniform": 48 | bound = math.sqrt(3 * variance) 49 | tensor.uniform_(-bound, bound) 50 | else: 51 | raise ValueError(f"invalid distribution {distribution}") 52 | 53 | 54 | def lecun_normal_(tensor): 55 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 56 | 57 | 58 | class PreNorm(nn.Module): 59 | def __init__(self, dim, fn): 60 | super().__init__() 61 | self.fn = fn 62 | self.norm = nn.LayerNorm(dim) 63 | 64 | def forward(self, x, *args, **kwargs): 65 | x = self.norm(x) 66 | return self.fn(x, *args, **kwargs) 67 | 68 | 69 | class GELU(nn.Module): 70 | def forward(self, x): 71 | return F.gelu(x) 72 | 73 | def conv(in_channels, out_channels, kernel_size, bias=False, padding = 1, stride = 1): 74 | return nn.Conv2d( 75 | in_channels, out_channels, kernel_size, 76 | padding=(kernel_size//2), bias=bias, stride=stride) 77 | 78 | 79 | def shift_back(inputs,step=2): # input [bs,28,256,310] output [bs, 28, 256, 256] 80 | [bs, nC, row, col] = inputs.shape # bs:batch_size nC:num_channel 81 | down_sample = 256//row # 256//256 82 | step = float(step)/float(down_sample*down_sample) 83 | out_col = row 84 | for i in range(nC): 85 | inputs[:,i,:,:out_col] = \ 86 | inputs[:,i,:,int(step*i):int(step*i)+out_col] 87 | return inputs[:, :, :, :out_col] 88 | 89 | # MS-MSA -> MST 90 | # S-MSA -> MST++ 91 | # 此方法已将MM去掉 92 | class MS_MSA(nn.Module): 93 | def __init__( 94 | self, 95 | dim, 96 | dim_head, 97 | heads, 98 | ): 99 | super().__init__() 100 | self.num_heads = heads 101 | self.dim_head = dim_head 102 | self.to_q = nn.Linear(dim, dim_head * heads, bias=False) 103 | self.to_k = nn.Linear(dim, dim_head * heads, bias=False) 104 | self.to_v = nn.Linear(dim, dim_head * heads, bias=False) 105 | self.rescale = nn.Parameter(torch.ones(heads, 1, 1)) 106 | self.proj = nn.Linear(dim_head * heads, dim, bias=True) 107 | self.pos_emb = nn.Sequential( 108 | nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), 109 | GELU(), 110 | nn.Conv2d(dim, dim, 3, 1, 
1, bias=False, groups=dim), 111 | ) 112 | self.dim = dim 113 | 114 | def forward(self, x_in): 115 | """ 116 | x_in: [b,h,w,c] 117 | return out: [b,h,w,c] 118 | """ 119 | b, h, w, c = x_in.shape 120 | x = x_in.reshape(b,h*w,c) 121 | q_inp = self.to_q(x) 122 | k_inp = self.to_k(x) 123 | v_inp = self.to_v(x) 124 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), 125 | (q_inp, k_inp, v_inp)) 126 | v = v 127 | # q: b,heads,hw,c 128 | q = q.transpose(-2, -1) 129 | k = k.transpose(-2, -1) 130 | v = v.transpose(-2, -1) 131 | q = F.normalize(q, dim=-1, p=2) 132 | k = F.normalize(k, dim=-1, p=2) 133 | attn = (k @ q.transpose(-2, -1)) # A = K^T*Q 134 | attn = attn * self.rescale 135 | attn = attn.softmax(dim=-1) 136 | x = attn @ v # b,heads,d,hw 137 | x = x.permute(0, 3, 1, 2) # Transpose 138 | x = x.reshape(b, h * w, self.num_heads * self.dim_head) 139 | out_c = self.proj(x).view(b, h, w, c) 140 | out_p = self.pos_emb(v_inp.reshape(b,h,w,c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 141 | out = out_c + out_p 142 | 143 | return out 144 | 145 | 146 | class FeedForward(nn.Module): 147 | def __init__(self, dim, mult=4): 148 | super().__init__() 149 | self.net = nn.Sequential( 150 | nn.Conv2d(dim, dim * mult, 1, 1, bias=False), 151 | GELU(), 152 | nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult), 153 | GELU(), 154 | nn.Conv2d(dim * mult, dim, 1, 1, bias=False), 155 | ) 156 | 157 | def forward(self, x): 158 | """ 159 | x: [b,h,w,c] 160 | return out: [b,h,w,c] 161 | """ 162 | out = self.net(x.permute(0, 3, 1, 2)) 163 | return out.permute(0, 2, 3, 1) 164 | 165 | 166 | class MSAB(nn.Module): 167 | def __init__( 168 | self, 169 | dim, 170 | dim_head, 171 | heads, 172 | num_blocks, 173 | ): 174 | super().__init__() 175 | self.blocks = nn.ModuleList([]) 176 | for _ in range(num_blocks): 177 | self.blocks.append(nn.ModuleList([ 178 | MS_MSA(dim=dim, dim_head=dim_head, heads=heads), 179 | PreNorm(dim, FeedForward(dim=dim)) 180 | ])) 181 | 182 | def forward(self, x): 183 | """ 184 | x: [b,c,h,w] 185 | return out: [b,c,h,w] 186 | """ 187 | x = x.permute(0, 2, 3, 1) # b h w c 188 | for (attn, ff) in self.blocks: 189 | x = attn(x) + x 190 | x = ff(x) + x 191 | out = x.permute(0, 3, 1, 2) # b c h w 192 | return out 193 | 194 | # stage(N1,N2):SAB + DownSample // UpSample + SAB 195 | # stage = 2 : (SAB + DownSample)*2 // (UpSample + SAB)*2 196 | class MST(nn.Module): 197 | 198 | def __init__(self, in_dim=4, out_dim=4, dim=4, stage=2, num_blocks=[2,4,4]): # in_dim=31, out_dim=31, dim=31 199 | super(MST, self).__init__() 200 | self.dim = dim 201 | self.stage = stage 202 | 203 | # Input projection 204 | self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False) 205 | 206 | # Encoder 207 | self.encoder_layers = nn.ModuleList([]) 208 | dim_stage = dim 209 | for i in range(stage): 210 | self.encoder_layers.append(nn.ModuleList([ 211 | MSAB( 212 | dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim), 213 | nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False), 214 | ])) 215 | dim_stage *= 2 216 | 217 | # Bottleneck 218 | self.bottleneck = MSAB( 219 | dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1]) 220 | 221 | # Decoder 222 | self.decoder_layers = nn.ModuleList([]) 223 | for i in range(stage): 224 | self.decoder_layers.append(nn.ModuleList([ 225 | nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0), 226 | nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False), 
227 | MSAB( 228 | dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim, 229 | heads=(dim_stage // 2) // dim), 230 | ])) 231 | dim_stage //= 2 232 | 233 | # Output projection 234 | self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False) 235 | 236 | #### activation function 237 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 238 | self.apply(self._init_weights) 239 | 240 | def _init_weights(self, m): 241 | if isinstance(m, nn.Linear): 242 | trunc_normal_(m.weight, std=.02) 243 | if isinstance(m, nn.Linear) and m.bias is not None: 244 | nn.init.constant_(m.bias, 0) 245 | elif isinstance(m, nn.LayerNorm): 246 | nn.init.constant_(m.bias, 0) 247 | nn.init.constant_(m.weight, 1.0) 248 | 249 | def forward(self, x): 250 | """ 251 | x: [b,c,h,w] 252 | return out:[b,c,h,w] 253 | """ 254 | 255 | # Embedding 256 | fea = self.embedding(x) 257 | 258 | # Encoder 259 | fea_encoder = [] 260 | for (MSAB, FeaDownSample) in self.encoder_layers: 261 | fea = MSAB(fea) 262 | fea_encoder.append(fea) 263 | fea = FeaDownSample(fea) 264 | 265 | # Bottleneck 266 | fea = self.bottleneck(fea) 267 | 268 | # Decoder 269 | for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers): 270 | fea = FeaUpSample(fea) 271 | fea = Fution(torch.cat([fea, fea_encoder[self.stage-1-i]], dim=1)) 272 | fea = LeWinBlcok(fea) 273 | 274 | # Mapping 275 | out = self.mapping(fea) + x 276 | 277 | return out 278 | 279 | 280 | class MST_Plus_Plus(nn.Module): 281 | def __init__(self, in_channels=3, out_channels=4, n_feat=4, stage=3): # in_channels=3, out_channels=31, n_feat=31 282 | super(MST_Plus_Plus, self).__init__() 283 | self.stage = stage # stage:num of SST 284 | 285 | self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=3, padding=(3 - 1) // 2,bias=False) 286 | 287 | modules_body = [MST(dim=4, stage=2, num_blocks=[1,1,1]) for _ in range(stage)] # MST(dim=31, ...) 
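        # NOTE: MST defaults to in_dim=out_dim=dim=4, so the hard-coded dim=4 above assumes
        # n_feat == 4 (the 4-band configuration used throughout this repository); the inline
        # comment records that the upstream MST++ uses dim=31 here for its 31-band setting.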
288 | self.body = nn.Sequential(*modules_body) 289 | 290 | self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=3, padding=(3 - 1) // 2,bias=False) 291 | 292 | def forward(self, x): 293 | """ 294 | x: [b,c,h,w] 295 | return out:[b,c,h,w] 296 | """ 297 | b, c, h_inp, w_inp = x.shape 298 | hb, wb = 8, 8 299 | pad_h = (hb - h_inp % hb) % hb 300 | pad_w = (wb - w_inp % wb) % wb 301 | x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect') 302 | x = self.conv_in(x) 303 | h = self.body(x) 304 | h = self.conv_out(h) 305 | h += x 306 | return h[:, :, :h_inp, :w_inp] 307 | 308 | -------------------------------------------------------------------------------- /other_model/HDNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 4 | return nn.Conv2d( 5 | in_channels, out_channels, kernel_size, 6 | padding=(kernel_size//2), bias=bias) 7 | 8 | class MeanShift(nn.Conv2d): 9 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 10 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 11 | std = torch.Tensor(rgb_std) 12 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 13 | self.weight.data.div_(std.view(3, 1, 1, 1)) 14 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 15 | self.bias.data.div_(std) 16 | self.requires_grad = False 17 | 18 | class BasicBlock(nn.Sequential): 19 | def __init__( 20 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 21 | bn=True, act=nn.ReLU(True)): 22 | 23 | m = [nn.Conv2d( 24 | in_channels, out_channels, kernel_size, 25 | padding=(kernel_size//2), stride=stride, bias=bias) 26 | ] 27 | if bn: m.append(nn.BatchNorm2d(out_channels)) 28 | if act is not None: m.append(act) 29 | super(BasicBlock, self).__init__(*m) 30 | 31 | class ResBlock(nn.Module): 32 | def __init__( 33 | self, conv=default_conv, n_feat=31, kernel_size=3, 34 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 35 | 36 | super(ResBlock, self).__init__() 37 | m = [] 38 | for i in range(2): 39 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 40 | if bn: m.append(nn.BatchNorm2d(n_feat)) 41 | if i == 0: m.append(act) 42 | 43 | self.body = nn.Sequential(*m) 44 | self.res_scale = res_scale 45 | 46 | def forward(self, x): 47 | res = self.body(x).mul(self.res_scale) 48 | res += x 49 | 50 | return res 51 | 52 | class Upsampler(nn.Sequential): 53 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 54 | 55 | m = [] 56 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
57 | for _ in range(int(math.log(scale, 2))): 58 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 59 | m.append(nn.PixelShuffle(2)) 60 | if bn: m.append(nn.BatchNorm2d(n_feat)) 61 | if act: m.append(act()) 62 | elif scale == 3: 63 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 64 | m.append(nn.PixelShuffle(3)) 65 | if bn: m.append(nn.BatchNorm2d(n_feat)) 66 | if act: m.append(act()) 67 | else: 68 | raise NotImplementedError 69 | 70 | super(Upsampler, self).__init__(*m) 71 | 72 | ## add SELayer 73 | class SELayer(nn.Module): 74 | def __init__(self, channel, reduction=16): 75 | super(SELayer, self).__init__() 76 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 77 | self.conv_du = nn.Sequential( 78 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 81 | nn.Sigmoid() 82 | ) 83 | 84 | def forward(self, x): 85 | y = self.avg_pool(x) 86 | y = self.conv_du(y) 87 | return x * y 88 | 89 | ## add SEResBlock 90 | class SEResBlock(nn.Module): 91 | def __init__( 92 | self, conv, n_feat, kernel_size, reduction, 93 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 94 | 95 | super(SEResBlock, self).__init__() 96 | modules_body = [] 97 | for i in range(2): 98 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 99 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 100 | if i == 0: modules_body.append(act) 101 | modules_body.append(SELayer(n_feat, reduction)) 102 | self.body = nn.Sequential(*modules_body) 103 | self.res_scale = res_scale 104 | 105 | def forward(self, x): 106 | res = self.body(x) 107 | #res = self.body(x).mul(self.res_scale) 108 | res += x 109 | 110 | return res 111 | 112 | 113 | _NORM_BONE = False 114 | 115 | def constant_init(module, val, bias=0): 116 | if hasattr(module, 'weight') and module.weight is not None: 117 | nn.init.constant_(module.weight, val) 118 | if hasattr(module, 'bias') and module.bias is not None: 119 | nn.init.constant_(module.bias, bias) 120 | 121 | 122 | def kaiming_init(module, 123 | a=0, 124 | mode='fan_out', 125 | nonlinearity='relu', 126 | bias=0, 127 | distribution='normal'): 128 | assert distribution in ['uniform', 'normal'] 129 | if distribution == 'uniform': 130 | nn.init.kaiming_uniform_( 131 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 132 | else: 133 | nn.init.kaiming_normal_( 134 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 135 | if hasattr(module, 'bias') and module.bias is not None: 136 | nn.init.constant_(module.bias, bias) 137 | 138 | # depthwise-separable convolution (DSC) 139 | class DSC(nn.Module): 140 | 141 | def __init__(self, nin: int) -> None: 142 | super(DSC, self).__init__() 143 | self.conv_dws = nn.Conv2d( 144 | nin, nin, kernel_size=1, stride=1, padding=0, groups=nin 145 | ) 146 | self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9) 147 | self.relu_dws = nn.ReLU(inplace=False) 148 | 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 150 | 151 | self.conv_point = nn.Conv2d( 152 | nin, 1, kernel_size=1, stride=1, padding=0, groups=1 153 | ) 154 | self.bn_point = nn.BatchNorm2d(1, momentum=0.9) 155 | self.relu_point = nn.ReLU(inplace=False) 156 | 157 | self.softmax = nn.Softmax(dim=2) 158 | 159 | def forward(self, x: torch.Tensor) -> torch.Tensor: 160 | out = self.conv_dws(x) 161 | out = self.bn_dws(out) 162 | out = self.relu_dws(out) 163 | 164 | out = self.maxpool(out) 165 | 166 | out = self.conv_point(out) 167 | out = self.bn_point(out) 168 | out = self.relu_point(out) 169 
| 170 | m, n, p, q = out.shape 171 | out = self.softmax(out.view(m, n, -1)) 172 | out = out.view(m, n, p, q) 173 | 174 | out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3]) 175 | 176 | out = torch.mul(out, x) 177 | 178 | out = out + x 179 | 180 | return out 181 | 182 | # Efficient Feature Fusion(EFF) 183 | class EFF(nn.Module): 184 | def __init__(self, nin: int, nout: int, num_splits: int) -> None: 185 | super(EFF, self).__init__() 186 | 187 | assert nin % num_splits == 0 188 | 189 | self.nin = nin 190 | self.nout = nout 191 | self.num_splits = num_splits 192 | self.subspaces = nn.ModuleList( 193 | [DSC(int(self.nin / self.num_splits)) for i in range(self.num_splits)] 194 | ) 195 | 196 | def forward(self, x: torch.Tensor) -> torch.Tensor: 197 | sub_feat = torch.chunk(x, self.num_splits, dim=1) 198 | out = [] 199 | for idx, l in enumerate(self.subspaces): 200 | out.append(self.subspaces[idx](sub_feat[idx])) 201 | out = torch.cat(out, dim=1) 202 | 203 | return out 204 | 205 | 206 | # spatial-spectral domain attention learning(SDL) 207 | class SDL_attention(nn.Module): 208 | def __init__(self, inplanes, planes, kernel_size=1, stride=1): 209 | super(SDL_attention, self).__init__() 210 | 211 | self.inplanes = inplanes 212 | self.inter_planes = planes // 2 213 | self.planes = planes 214 | self.kernel_size = kernel_size 215 | self.stride = stride 216 | self.padding = (kernel_size-1)//2 217 | 218 | self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False) 219 | self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) 220 | self.conv_up = nn.Conv2d(self.inter_planes, self.planes, kernel_size=1, stride=1, padding=0, bias=False) 221 | self.softmax_right = nn.Softmax(dim=2) 222 | self.sigmoid = nn.Sigmoid() 223 | 224 | self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) #g 225 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 226 | self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) #theta 227 | self.softmax_left = nn.Softmax(dim=2) 228 | 229 | self.reset_parameters() 230 | 231 | def reset_parameters(self): 232 | kaiming_init(self.conv_q_right, mode='fan_in') 233 | kaiming_init(self.conv_v_right, mode='fan_in') 234 | kaiming_init(self.conv_q_left, mode='fan_in') 235 | kaiming_init(self.conv_v_left, mode='fan_in') 236 | 237 | self.conv_q_right.inited = True 238 | self.conv_v_right.inited = True 239 | self.conv_q_left.inited = True 240 | self.conv_v_left.inited = True 241 | # HR spatial attention 242 | def spatial_attention(self, x): 243 | input_x = self.conv_v_right(x) 244 | batch, channel, height, width = input_x.size() 245 | 246 | input_x = input_x.view(batch, channel, height * width) 247 | context_mask = self.conv_q_right(x) 248 | context_mask = context_mask.view(batch, 1, height * width) 249 | context_mask = self.softmax_right(context_mask) 250 | 251 | context = torch.matmul(input_x, context_mask.transpose(1,2)) 252 | context = context.unsqueeze(-1) 253 | context = self.conv_up(context) 254 | 255 | mask_ch = self.sigmoid(context) 256 | 257 | out = x * mask_ch 258 | 259 | return out 260 | # HR spectral attention 261 | def spectral_attention(self, x): 262 | 263 | g_x = self.conv_q_left(x) 264 | batch, channel, height, width = g_x.size() 265 | 266 | avg_x = self.avg_pool(g_x) 267 | batch, channel, avg_x_h, avg_x_w = avg_x.size() 268 | 269 | avg_x = avg_x.view(batch, 
channel, avg_x_h * avg_x_w).permute(0, 2, 1) 270 | theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width) 271 | context = torch.matmul(avg_x, theta_x) 272 | context = self.softmax_left(context) 273 | context = context.view(batch, 1, height, width) 274 | 275 | mask_sp = self.sigmoid(context) 276 | 277 | out = x * mask_sp 278 | 279 | return out 280 | 281 | def forward(self, x): 282 | context_spectral = self.spectral_attention(x) 283 | context_spatial = self.spatial_attention(x) 284 | out = context_spatial + context_spectral 285 | return out 286 | 287 | 288 | class HDNet(nn.Module): 289 | 290 | def __init__(self, in_ch=3, out_ch=4, conv=default_conv): 291 | super(HDNet, self).__init__() 292 | 293 | n_resblocks = 4 294 | n_feats = 8 295 | kernel_size = 3 296 | act = nn.ReLU(True) 297 | 298 | # define head module 299 | m_head = [conv(in_ch, n_feats, kernel_size)] 300 | 301 | # define body module 302 | m_body = [ 303 | ResBlock( 304 | conv, n_feats, kernel_size, act=act, res_scale= 1 305 | ) for _ in range(n_resblocks) 306 | ] 307 | m_body.append(SDL_attention(inplanes = n_feats, planes = n_feats)) 308 | m_body.append(EFF(nin=n_feats, nout=n_feats, num_splits=4)) 309 | 310 | for i in range(1, n_resblocks): 311 | m_body.append(ResBlock( 312 | conv, n_feats, kernel_size, act=act, res_scale= 1 313 | )) 314 | 315 | m_body.append(conv(n_feats, n_feats, kernel_size)) 316 | 317 | m_tail = [conv(n_feats, out_ch, kernel_size)] 318 | 319 | self.head = nn.Sequential(*m_head) 320 | self.body = nn.Sequential(*m_body) 321 | self.tail = nn.Sequential(*m_tail) 322 | 323 | def forward(self, x): 324 | x = self.head(x) 325 | 326 | res = self.body(x) 327 | res += x 328 | 329 | x = self.tail(res) 330 | 331 | return x 332 | 333 | # frequency domain learning (FDL) 334 | class FDL(nn.Module): 335 | def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False): 336 | super(FDL, self).__init__() 337 | self.loss_weight = loss_weight 338 | self.alpha = alpha 339 | self.patch_factor = patch_factor 340 | self.ave_spectrum = ave_spectrum 341 | self.log_matrix = log_matrix 342 | self.batch_matrix = batch_matrix 343 | 344 | def tensor2freq(self, x): 345 | patch_factor = self.patch_factor 346 | _, _, h, w = x.shape 347 | assert h % patch_factor == 0 and w % patch_factor == 0, ( 348 | 'Image height and width must be divisible by the patch factor') 349 | patch_list = [] 350 | patch_h = h // patch_factor 351 | patch_w = w // patch_factor 352 | for i in range(patch_factor): 353 | for j in range(patch_factor): 354 | patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w]) 355 | 356 | y = torch.stack(patch_list, 1) 357 | 358 | return torch.view_as_real(torch.fft.fft2(y, norm='ortho')) # torch.rfft was removed in PyTorch 1.8+; fft2 + view_as_real is the equivalent orthonormalized full FFT 359 | 360 | def loss_formulation(self, recon_freq, real_freq, matrix=None): 361 | if matrix is not None: 362 | weight_matrix = matrix.detach() 363 | else: 364 | matrix_tmp = (recon_freq - real_freq) ** 2 365 | matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha 366 | if self.log_matrix: 367 | matrix_tmp = torch.log(matrix_tmp + 1.0) 368 | 369 | if self.batch_matrix: 370 | matrix_tmp = matrix_tmp / matrix_tmp.max() 371 | else: 372 | matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None] 373 | 374 | matrix_tmp[torch.isnan(matrix_tmp)] = 0.0 375 | matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0) 376 | weight_matrix = matrix_tmp.clone().detach() 377 | 378 | assert
weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, ( 379 | 'The values of spectrum weight matrix should be in the range [0, 1], ' 380 | 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item())) 381 | 382 | tmp = (recon_freq - real_freq) ** 2 383 | freq_distance = tmp[..., 0] + tmp[..., 1] 384 | 385 | loss = weight_matrix * freq_distance 386 | return torch.mean(loss) 387 | 388 | def forward(self, pred, target, matrix=None, **kwargs): 389 | 390 | pred_freq = self.tensor2freq(pred) 391 | target_freq = self.tensor2freq(target) 392 | 393 | if self.ave_spectrum: 394 | pred_freq = torch.mean(pred_freq, 0, keepdim=True) 395 | target_freq = torch.mean(target_freq, 0, keepdim=True) 396 | 397 | return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight 398 | -------------------------------------------------------------------------------- /other_model/MPRNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ########################################################################## 6 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 7 | return nn.Conv2d( 8 | in_channels, out_channels, kernel_size, 9 | padding=(kernel_size//2), bias=bias, stride = stride) 10 | 11 | 12 | ########################################################################## 13 | ## Channel Attention Layer 14 | class CALayer(nn.Module): 15 | def __init__(self, channel, reduction=16, bias=False): 16 | super(CALayer, self).__init__() 17 | # global average pooling: feature --> point 18 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 19 | # feature channel downscale and upscale --> channel weight 20 | self.conv_du = nn.Sequential( 21 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 24 | nn.Sigmoid() 25 | ) 26 | 27 | def forward(self, x): 28 | y = self.avg_pool(x) 29 | y = self.conv_du(y) 30 | return x * y 31 | 32 | 33 | ########################################################################## 34 | ## Channel Attention Block (CAB) 35 | class CAB(nn.Module): 36 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 37 | super(CAB, self).__init__() 38 | modules_body = [] 39 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 40 | modules_body.append(act) 41 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 42 | 43 | self.CA = CALayer(n_feat, reduction, bias=bias) 44 | self.body = nn.Sequential(*modules_body) 45 | 46 | def forward(self, x): 47 | res = self.body(x) 48 | res = self.CA(res) 49 | res += x 50 | return res 51 | 52 | ########################################################################## 53 | ## Supervised Attention Module 54 | class SAM(nn.Module): 55 | def __init__(self, n_feat, kernel_size, bias): 56 | super(SAM, self).__init__() 57 | self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) 58 | self.conv2 = conv(n_feat, 4, kernel_size, bias=bias) 59 | self.conv3 = conv(4, n_feat, kernel_size, bias=bias) 60 | 61 | def forward(self, x, x_img): 62 | x1 = self.conv1(x) 63 | img = self.conv2(x) + x_img 64 | x2 = torch.sigmoid(self.conv3(img)) 65 | x1 = x1*x2 66 | x1 = x1+x 67 | return x1, img 68 | 69 | ########################################################################## 70 | ## U-Net 71 | 72 | class Encoder(nn.Module): 73 | def __init__(self, n_feat, 
kernel_size, reduction, act, bias, scale_unetfeats, csff): 74 | super(Encoder, self).__init__() 75 | 76 | self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 77 | self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 78 | self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 79 | 80 | self.encoder_level1 = nn.Sequential(*self.encoder_level1) 81 | self.encoder_level2 = nn.Sequential(*self.encoder_level2) 82 | self.encoder_level3 = nn.Sequential(*self.encoder_level3) 83 | 84 | self.down12 = DownSample(n_feat, scale_unetfeats) 85 | self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats) 86 | 87 | # Cross Stage Feature Fusion (CSFF) 88 | if csff: 89 | self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 90 | self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 91 | self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 92 | 93 | self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 94 | self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 95 | self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 96 | 97 | def forward(self, x, encoder_outs=None, decoder_outs=None): 98 | enc1 = self.encoder_level1(x) 99 | if (encoder_outs is not None) and (decoder_outs is not None): 100 | enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) 101 | 102 | x = self.down12(enc1) 103 | 104 | enc2 = self.encoder_level2(x) 105 | if (encoder_outs is not None) and (decoder_outs is not None): 106 | enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) 107 | 108 | x = self.down23(enc2) 109 | 110 | enc3 = self.encoder_level3(x) 111 | if (encoder_outs is not None) and (decoder_outs is not None): 112 | enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) 113 | 114 | return [enc1, enc2, enc3] 115 | 116 | class Decoder(nn.Module): 117 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): 118 | super(Decoder, self).__init__() 119 | 120 | self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 121 | self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 122 | self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 123 | 124 | self.decoder_level1 = nn.Sequential(*self.decoder_level1) 125 | self.decoder_level2 = nn.Sequential(*self.decoder_level2) 126 | self.decoder_level3 = nn.Sequential(*self.decoder_level3) 127 | 128 | self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) 129 | self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) 130 | 131 | self.up21 = SkipUpSample(n_feat, scale_unetfeats) 132 | self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats) 133 | 134 | def forward(self, outs): 135 | enc1, enc2, enc3 = outs 136 | dec3 = self.decoder_level3(enc3) 137 | 138 | x = self.up32(dec3, self.skip_attn2(enc2)) 139 | dec2 = self.decoder_level2(x) 140 | 141 | x = self.up21(dec2, self.skip_attn1(enc1)) 142 | dec1 = self.decoder_level1(x) 143 | 144 | return [dec1,dec2,dec3] 145 | 146 | 
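# A minimal shape-check sketch for the Encoder/Decoder pair above (not part of the original
# file): the three levels carry n_feat, n_feat + scale_unetfeats and n_feat + 2*scale_unetfeats
# channels at full, 1/2 and 1/4 resolution. The argument values below simply mirror the
# defaults MPRNet uses further down and are assumptions for illustration only.
def _encoder_decoder_shape_check():
    act = nn.PReLU()
    enc = Encoder(n_feat=4, kernel_size=3, reduction=1, act=act, bias=False,
                  scale_unetfeats=4, csff=False)
    dec = Decoder(n_feat=4, kernel_size=3, reduction=1, act=act, bias=False,
                  scale_unetfeats=4)
    feats = enc(torch.randn(1, 4, 64, 64))  # [(1,4,64,64), (1,8,32,32), (1,12,16,16)]
    outs = dec(feats)                       # decoder returns the same three scales
    return [tuple(f.shape) for f in feats], [tuple(o.shape) for o in outs]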
########################################################################## 147 | ##---------- Resizing Modules ---------- 148 | class DownSample(nn.Module): 149 | def __init__(self, in_channels,s_factor): 150 | super(DownSample, self).__init__() 151 | self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), 152 | nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False)) 153 | 154 | def forward(self, x): 155 | x = self.down(x) 156 | return x 157 | 158 | class UpSample(nn.Module): 159 | def __init__(self, in_channels,s_factor): 160 | super(UpSample, self).__init__() 161 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 162 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 163 | 164 | def forward(self, x): 165 | x = self.up(x) 166 | return x 167 | 168 | class SkipUpSample(nn.Module): 169 | def __init__(self, in_channels,s_factor): 170 | super(SkipUpSample, self).__init__() 171 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 172 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 173 | 174 | def forward(self, x, y): 175 | x = self.up(x) 176 | x = x + y 177 | return x 178 | 179 | ########################################################################## 180 | ## Original Resolution Block (ORB) 181 | class ORB(nn.Module): 182 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 183 | super(ORB, self).__init__() 184 | modules_body = [] 185 | modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)] 186 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 187 | self.body = nn.Sequential(*modules_body) 188 | 189 | def forward(self, x): 190 | res = self.body(x) 191 | res += x 192 | return res 193 | 194 | ########################################################################## 195 | class ORSNet(nn.Module): 196 | def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab): 197 | super(ORSNet, self).__init__() 198 | 199 | self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 200 | self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 201 | self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 202 | 203 | self.up_enc1 = UpSample(n_feat, scale_unetfeats) 204 | self.up_dec1 = UpSample(n_feat, scale_unetfeats) 205 | 206 | self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 207 | self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 208 | 209 | self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 210 | self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 211 | self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 212 | 213 | self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 214 | self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 215 | self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 216 | 217 | def forward(self, x, encoder_outs, decoder_outs): 218 | x = self.orb1(x) 219 | x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) 220 | 221 | x 
= self.orb2(x) 222 | x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1])) 223 | 224 | x = self.orb3(x) 225 | x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2])) 226 | 227 | return x 228 | 229 | 230 | ########################################################################## 231 | class MPRNet(nn.Module): 232 | def __init__(self, in_c=4, out_c=4, n_feat=4, scale_unetfeats=4, scale_orsnetfeats=4, num_cab=4, kernel_size=3, reduction=1, bias=False): 233 | super(MPRNet, self).__init__() 234 | 235 | self.conv_in = nn.Conv2d(3, in_c, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, 236 | bias=bias) 237 | 238 | act=nn.PReLU() 239 | self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 240 | self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 241 | self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 242 | 243 | # Cross Stage Feature Fusion (CSFF) 244 | self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False) 245 | self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 246 | 247 | self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True) 248 | self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 249 | 250 | self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab) 251 | 252 | self.sam12 = SAM(n_feat, kernel_size=1, bias=bias) 253 | self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) 254 | 255 | self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias) 256 | self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias) 257 | self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias) 258 | 259 | def forward(self, x3_img): 260 | b, c, h_inp, w_inp = x3_img.shape 261 | hb, wb = 8, 8 262 | pad_h = (hb - h_inp % hb) % hb 263 | pad_w = (wb - w_inp % wb) % wb 264 | x3_img = F.pad(x3_img, [0, pad_w, 0, pad_h], mode='reflect') 265 | x3_img = self.conv_in(x3_img) 266 | 267 | # Original-resolution Image for Stage 3 268 | H = x3_img.size(2) 269 | W = x3_img.size(3) 270 | 271 | # Multi-Patch Hierarchy: Split Image into four non-overlapping patches 272 | 273 | # Two Patches for Stage 2 274 | x2top_img = x3_img[:,:,0:int(H/2),:] 275 | x2bot_img = x3_img[:,:,int(H/2):H,:] 276 | 277 | # Four Patches for Stage 1 278 | x1ltop_img = x2top_img[:,:,:,0:int(W/2)] 279 | x1rtop_img = x2top_img[:,:,:,int(W/2):W] 280 | x1lbot_img = x2bot_img[:,:,:,0:int(W/2)] 281 | x1rbot_img = x2bot_img[:,:,:,int(W/2):W] 282 | 283 | ##------------------------------------------- 284 | ##-------------- Stage 1--------------------- 285 | ##------------------------------------------- 286 | ## Compute Shallow Features 287 | x1ltop = self.shallow_feat1(x1ltop_img) 288 | x1rtop = self.shallow_feat1(x1rtop_img) 289 | x1lbot = self.shallow_feat1(x1lbot_img) 290 | x1rbot = self.shallow_feat1(x1rbot_img) 291 | 292 | ## Process features of all 4 patches with Encoder of Stage 1 293 | feat1_ltop = self.stage1_encoder(x1ltop) 294 | feat1_rtop = self.stage1_encoder(x1rtop) 295 | feat1_lbot = self.stage1_encoder(x1lbot) 296 | feat1_rbot = 
self.stage1_encoder(x1rbot) 297 | 298 | ## Concat deep features 299 | feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)] 300 | feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)] 301 | 302 | ## Pass features through Decoder of Stage 1 303 | res1_top = self.stage1_decoder(feat1_top) 304 | res1_bot = self.stage1_decoder(feat1_bot) 305 | 306 | ## Apply Supervised Attention Module (SAM) 307 | x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) 308 | x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) 309 | 310 | ## Output image at Stage 1 311 | stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2) 312 | ##------------------------------------------- 313 | ##-------------- Stage 2--------------------- 314 | ##------------------------------------------- 315 | ## Compute Shallow Features 316 | x2top = self.shallow_feat2(x2top_img) 317 | x2bot = self.shallow_feat2(x2bot_img) 318 | 319 | ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 320 | x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1)) 321 | x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1)) 322 | 323 | ## Process features of both patches with Encoder of Stage 2 324 | feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) 325 | feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) 326 | 327 | ## Concat deep features 328 | feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)] 329 | 330 | ## Pass features through Decoder of Stage 2 331 | res2 = self.stage2_decoder(feat2) 332 | 333 | ## Apply SAM 334 | x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) 335 | 336 | 337 | ##------------------------------------------- 338 | ##-------------- Stage 3--------------------- 339 | ##------------------------------------------- 340 | ## Compute Shallow Features 341 | x3 = self.shallow_feat3(x3_img) 342 | 343 | ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 344 | x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1)) 345 | 346 | x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) 347 | 348 | stage3_img = self.tail(x3_cat) 349 | 350 | return (stage3_img + x3_img)[:, :, :h_inp, :w_inp] 351 | -------------------------------------------------------------------------------- /other_model/HRNET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | # ---------------------------------------- 6 | # Conv2d Block 7 | # ---------------------------------------- 8 | class Conv2dLayer(nn.Module): 9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero', 10 | activation='lrelu', norm='none', sn=False): 11 | super(Conv2dLayer, self).__init__() 12 | # Initialize the padding scheme 13 | if pad_type == 'reflect': 14 | self.pad = nn.ReflectionPad2d(padding) 15 | elif pad_type == 'replicate': 16 | self.pad = nn.ReplicationPad2d(padding) 17 | elif pad_type == 'zero': 18 | self.pad = nn.ZeroPad2d(padding) 19 | else: 20 | assert 0, "Unsupported padding type: {}".format(pad_type) 21 | 22 | # Initialize the normalization type 23 | if norm == 'bn': 24 | self.norm = nn.BatchNorm2d(out_channels) 25 | elif norm == 'in': 26 | self.norm = nn.InstanceNorm2d(out_channels) 27 | elif norm == 'ln': 28 | self.norm = LayerNorm(out_channels) 29 | elif norm == 'none': 30 | self.norm = None 31 | else: 32 | assert 0, 
"Unsupported normalization: {}".format(norm) 33 | 34 | # Initialize the activation funtion 35 | if activation == 'relu': 36 | self.activation = nn.ReLU(inplace=True) 37 | elif activation == 'lrelu': 38 | self.activation = nn.LeakyReLU(0.2, inplace=True) 39 | elif activation == 'prelu': 40 | self.activation = nn.PReLU() 41 | elif activation == 'selu': 42 | self.activation = nn.SELU(inplace=True) 43 | elif activation == 'tanh': 44 | self.activation = nn.Tanh() 45 | elif activation == 'sigmoid': 46 | self.activation = nn.Sigmoid() 47 | elif activation == 'none': 48 | self.activation = None 49 | else: 50 | assert 0, "Unsupported activation: {}".format(activation) 51 | 52 | # Initialize the convolution layers 53 | if sn: 54 | self.conv2d = SpectralNorm( 55 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)) 56 | else: 57 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation) 58 | 59 | def forward(self, x): 60 | x = self.pad(x) 61 | x = self.conv2d(x) 62 | if self.norm: 63 | x = self.norm(x) 64 | if self.activation: 65 | x = self.activation(x) 66 | return x 67 | 68 | 69 | class TransposeConv2dLayer(nn.Module): 70 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero', 71 | activation='lrelu', norm='none', sn=False, scale_factor=2): 72 | super(TransposeConv2dLayer, self).__init__() 73 | # Initialize the conv scheme 74 | self.scale_factor = scale_factor 75 | self.conv2d = Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, dilation, pad_type, 76 | activation, norm, sn) 77 | 78 | def forward(self, x): 79 | x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') 80 | x = self.conv2d(x) 81 | return x 82 | 83 | 84 | class ResConv2dLayer(nn.Module): 85 | def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero', activation='lrelu', 86 | norm='none', sn=False, scale_factor=2): 87 | super(ResConv2dLayer, self).__init__() 88 | # Initialize the conv scheme 89 | self.conv2d = nn.Sequential( 90 | Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, norm, 91 | sn), 92 | Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation='none', 93 | norm=norm, sn=sn) 94 | ) 95 | 96 | def forward(self, x): 97 | residual = x 98 | out = self.conv2d(x) 99 | out = 0.1 * out + residual 100 | return out 101 | 102 | 103 | class DenseConv2dLayer_5C(nn.Module): 104 | def __init__(self, in_channels, latent_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero', 105 | activation='lrelu', norm='none', sn=False): 106 | super(DenseConv2dLayer_5C, self).__init__() 107 | # dense convolutions 108 | self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 109 | activation, norm, sn) 110 | self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation, 111 | pad_type, activation, norm, sn) 112 | self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding, 113 | dilation, pad_type, activation, norm, sn) 114 | self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding, 115 | dilation, pad_type, activation, norm, sn) 116 | self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation, 117 | 
pad_type, activation, norm, sn) 118 | 119 | def forward(self, x): 120 | x1 = self.conv1(x) 121 | x2 = self.conv2(torch.cat((x, x1), 1)) 122 | x3 = self.conv3(torch.cat((x, x1, x2), 1)) 123 | x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) 124 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 125 | return x5 126 | 127 | 128 | class ResidualDenseBlock_5C(nn.Module): 129 | def __init__(self, in_channels, latent_channels, kernel_size=3, stride=1, padding=1, dilation=1, pad_type='zero', 130 | activation='lrelu', norm='none', sn=False): 131 | super(ResidualDenseBlock_5C, self).__init__() 132 | # dense convolutions 133 | self.conv1 = Conv2dLayer(in_channels, latent_channels, kernel_size, stride, padding, dilation, pad_type, 134 | activation, norm, sn) 135 | self.conv2 = Conv2dLayer(in_channels + latent_channels, latent_channels, kernel_size, stride, padding, dilation, 136 | pad_type, activation, norm, sn) 137 | self.conv3 = Conv2dLayer(in_channels + latent_channels * 2, latent_channels, kernel_size, stride, padding, 138 | dilation, pad_type, activation, norm, sn) 139 | self.conv4 = Conv2dLayer(in_channels + latent_channels * 3, latent_channels, kernel_size, stride, padding, 140 | dilation, pad_type, activation, norm, sn) 141 | self.conv5 = Conv2dLayer(in_channels + latent_channels * 4, in_channels, kernel_size, stride, padding, dilation, 142 | pad_type, activation, norm, sn) 143 | 144 | def forward(self, x): 145 | residual = x 146 | x1 = self.conv1(x) 147 | x2 = self.conv2(torch.cat((x, x1), 1)) 148 | x3 = self.conv3(torch.cat((x, x1, x2), 1)) 149 | x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) 150 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 151 | x5 = 0.1 * x5 + residual 152 | return x5 153 | 154 | 155 | # ---------------------------------------- 156 | # Layer Norm 157 | # ---------------------------------------- 158 | class LayerNorm(nn.Module): 159 | def __init__(self, num_features, eps=1e-8, affine=True): 160 | super(LayerNorm, self).__init__() 161 | self.num_features = num_features 162 | self.affine = affine 163 | self.eps = eps 164 | 165 | if self.affine: 166 | self.gamma = Parameter(torch.Tensor(num_features).uniform_()) 167 | self.beta = Parameter(torch.zeros(num_features)) 168 | 169 | def forward(self, x): 170 | # layer norm 171 | shape = [-1] + [1] * (x.dim() - 1) # for 4d input: [-1, 1, 1, 1] 172 | if x.size(0) == 1: 173 | # These two lines run much faster in pytorch 0.4 than the two lines listed below. 
174 | mean = x.view(-1).mean().view(*shape) 175 | std = x.view(-1).std().view(*shape) 176 | else: 177 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 178 | std = x.view(x.size(0), -1).std(1).view(*shape) 179 | x = (x - mean) / (std + self.eps) 180 | # if it is learnable 181 | if self.affine: 182 | shape = [1, -1] + [1] * (x.dim() - 2) # for 4d input: [1, -1, 1, 1] 183 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 184 | return x 185 | 186 | 187 | # ---------------------------------------- 188 | # Spectral Norm Block 189 | # ---------------------------------------- 190 | def l2normalize(v, eps=1e-12): 191 | return v / (v.norm() + eps) 192 | 193 | 194 | class SpectralNorm(nn.Module): 195 | def __init__(self, module, name='weight', power_iterations=1): 196 | super(SpectralNorm, self).__init__() 197 | self.module = module 198 | self.name = name 199 | self.power_iterations = power_iterations 200 | if not self._made_params(): 201 | self._make_params() 202 | 203 | def _update_u_v(self): 204 | u = getattr(self.module, self.name + "_u") 205 | v = getattr(self.module, self.name + "_v") 206 | w = getattr(self.module, self.name + "_bar") 207 | 208 | height = w.data.shape[0] 209 | for _ in range(self.power_iterations): 210 | v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) 211 | u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) 212 | 213 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 214 | sigma = u.dot(w.view(height, -1).mv(v)) 215 | setattr(self.module, self.name, w / sigma.expand_as(w)) 216 | 217 | def _made_params(self): 218 | try: 219 | u = getattr(self.module, self.name + "_u") 220 | v = getattr(self.module, self.name + "_v") 221 | w = getattr(self.module, self.name + "_bar") 222 | return True 223 | except AttributeError: 224 | return False 225 | 226 | def _make_params(self): 227 | w = getattr(self.module, self.name) 228 | 229 | height = w.data.shape[0] 230 | width = w.view(height, -1).data.shape[1] 231 | 232 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 233 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 234 | u.data = l2normalize(u.data) 235 | v.data = l2normalize(v.data) 236 | w_bar = Parameter(w.data) 237 | 238 | del self.module._parameters[self.name] 239 | 240 | self.module.register_parameter(self.name + "_u", u) 241 | self.module.register_parameter(self.name + "_v", v) 242 | self.module.register_parameter(self.name + "_bar", w_bar) 243 | 244 | def forward(self, *args): 245 | self._update_u_v() 246 | return self.module.forward(*args) 247 | 248 | 249 | # ---------------------------------------- 250 | # Non-local Block 251 | # ---------------------------------------- 252 | class Self_Attn(nn.Module): 253 | """ Self attention Layer for Feature Map dimension""" 254 | 255 | def __init__(self, in_dim, latent_dim=8): 256 | super(Self_Attn, self).__init__() 257 | self.channel_in = in_dim 258 | self.channel_latent = in_dim // latent_dim 259 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // latent_dim, kernel_size=1) 260 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // latent_dim, kernel_size=1) 261 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 262 | self.gamma = nn.Parameter(torch.zeros(1)) 263 | self.softmax = nn.Softmax(dim=-1) 264 | 265 | def forward(self, x): 266 | """ 267 | inputs : 268 | x : input feature maps(B X C X H X W) 269 | returns : 270 | out : self attention value + input 
feature 271 | attention: B X N X N (N is Height * Width) 272 | """ 273 | batchsize, C, height, width = x.size() 274 | # proj_query: reshape to B x N x c, N = H x W 275 | proj_query = self.query_conv(x).view(batchsize, -1, height * width).permute(0, 2, 1) 276 | # proj_query: reshape to B x c x N, N = H x W 277 | proj_key = self.key_conv(x).view(batchsize, -1, height * width) 278 | # transpose check, energy: B x N x N, N = H x W 279 | energy = torch.bmm(proj_query, proj_key) 280 | # attention: B x N x N, N = H x W 281 | attention = self.softmax(energy) 282 | # proj_value is normal convolution, B x C x N 283 | proj_value = self.value_conv(x).view(batchsize, -1, height * width) 284 | # out: B x C x N 285 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 286 | out = out.view(batchsize, C, height, width) 287 | 288 | out = self.gamma * out + x 289 | return out 290 | 291 | 292 | # ---------------------------------------- 293 | # Global Block 294 | # ---------------------------------------- 295 | class SELayer(nn.Module): 296 | def __init__(self, channel, reduction=16): 297 | super(SELayer, self).__init__() 298 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 299 | self.fc = nn.Sequential( 300 | nn.Linear(channel, channel // reduction, bias=False), 301 | nn.ReLU(inplace=True), 302 | nn.Linear(channel // reduction, channel // reduction, bias=False), 303 | nn.ReLU(inplace=True), 304 | nn.Linear(channel // reduction, channel, bias=False), 305 | nn.Sigmoid() 306 | ) 307 | 308 | def forward(self, x): 309 | b, c, _, _ = x.size() 310 | y = self.avg_pool(x).view(b, c) 311 | y = self.fc(y).view(b, c, 1, 1) 312 | return x * y.expand_as(x) 313 | 314 | 315 | class GlobalBlock(nn.Module): 316 | def __init__(self, in_channels, kernel_size, stride=1, padding=0, dilation=1, pad_type='zero', activation='lrelu', 317 | norm='none', sn=False, reduction=8): 318 | super(GlobalBlock, self).__init__() 319 | self.conv1 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, 320 | norm, sn) 321 | self.conv2 = Conv2dLayer(in_channels, in_channels, kernel_size, stride, padding, dilation, pad_type, activation, 322 | norm, sn) 323 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 324 | self.fc = nn.Sequential( 325 | nn.Linear(in_channels, in_channels // reduction, bias=False), 326 | nn.ReLU(inplace=True), 327 | nn.Linear(in_channels // reduction, in_channels // reduction, bias=False), 328 | nn.ReLU(inplace=True), 329 | nn.Linear(in_channels // reduction, in_channels, bias=False), 330 | nn.Sigmoid() 331 | ) 332 | 333 | def forward(self, x): 334 | # residual 335 | residual = x 336 | # Sequeeze-and-Excitation(SE) 337 | b, c, _, _ = x.size() 338 | x = self.conv1(x) 339 | y = self.avg_pool(x).view(b, c) 340 | y = self.fc(y).view(b, c, 1, 1) 341 | y = x * y.expand_as(x) 342 | y = self.conv2(x) 343 | # addition 344 | out = 0.1 * y + residual 345 | return out 346 | def pixel_unshuffle(input, downscale_factor): 347 | ''' 348 | input: batchSize * c * k*w * k*h 349 | downscale_factor: k 350 | batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h 351 | ''' 352 | c = input.shape[1] 353 | kernel = torch.zeros(size = [downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor], 354 | device = input.device) 355 | for y in range(downscale_factor): 356 | for x in range(downscale_factor): 357 | kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1 358 | return F.conv2d(input, kernel, stride = downscale_factor, groups = c) 359 | 360 | class 
PixelUnShuffle(nn.Module): 361 | def __init__(self, downscale_factor): 362 | super(PixelUnShuffle, self).__init__() 363 | self.downscale_factor = downscale_factor 364 | 365 | def forward(self, input): 366 | ''' 367 | input: batchSize * c * k*w * k*h 368 | downscale_factor: k 369 | batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h 370 | ''' 371 | return pixel_unshuffle(input, self.downscale_factor) 372 | 373 | # ---------------------------------------- 374 | # Initialize the networks 375 | # ---------------------------------------- 376 | def weights_init(net, init_type = 'normal', init_gain = 0.02): 377 | """Initialize network weights. 378 | Parameters: 379 | net (network) -- network to be initialized 380 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 381 | init_gain (float) -- scaling factor for normal, xavier and orthogonal 382 | In our paper, we choose the default setting: zero mean Gaussian distribution with a standard deviation of 0.02 383 | """ 384 | def init_func(m): 385 | classname = m.__class__.__name__ 386 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 387 | if init_type == 'normal': 388 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 389 | elif init_type == 'xavier': 390 | torch.nn.init.xavier_normal_(m.weight.data, gain = init_gain) 391 | elif init_type == 'kaiming': 392 | torch.nn.init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in') 393 | elif init_type == 'orthogonal': 394 | torch.nn.init.orthogonal_(m.weight.data, gain = init_gain) 395 | else: 396 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 397 | elif classname.find('BatchNorm2d') != -1: 398 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 399 | torch.nn.init.constant_(m.bias.data, 0.0) 400 | 401 | # apply the initialization function 402 | print('initialize network with %s type' % init_type) 403 | net.apply(init_func) 404 | 405 | # ---------------------------------------- 406 | # Generator 407 | # ---------------------------------------- 408 | class SGN(nn.Module): 409 | # Notice 410 | def __init__(self, in_channels=3, out_channels=4, start_channels=8, pad='zero', activ='lrelu', norm='none', ): 411 | super(SGN, self).__init__() 412 | # Top subnetwork, K = 3 413 | self.top1 = Conv2dLayer(in_channels * (4 ** 3), start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 414 | self.top21 = ResidualDenseBlock_5C(start_channels * (2 ** 3), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 415 | self.top22 = GlobalBlock(start_channels * (2 ** 3), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4) 416 | self.top3 = Conv2dLayer(start_channels * (2 ** 3), start_channels * (2 ** 3), 1, 1, 0, pad_type = pad, activation = activ, norm = norm) 417 | # Middle subnetwork, K = 2 418 | self.mid1 = Conv2dLayer(in_channels * (4 ** 2), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 419 | self.mid2 = Conv2dLayer(int(start_channels * (2 ** 2 + 2 ** 3 / 4)), start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 420 | self.mid31 = ResidualDenseBlock_5C(start_channels * (2 ** 2), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 421 | self.mid32 = GlobalBlock(start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4) 422 | self.mid4 = Conv2dLayer(start_channels * (2 ** 2), 
start_channels * (2 ** 2), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 423 | # Bottom subnetwork, K = 1 424 | self.bot1 = Conv2dLayer(in_channels * (4 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 425 | self.bot2 = Conv2dLayer(int(start_channels * (2 ** 1 + 2 ** 2 / 4)), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 426 | self.bot31 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 427 | self.bot32 = ResidualDenseBlock_5C(start_channels * (2 ** 1), start_channels * (2 ** 0), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 428 | self.bot33 = GlobalBlock(start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4) 429 | self.bot4 = Conv2dLayer(start_channels * (2 ** 1), start_channels * (2 ** 1), 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 430 | # Mainstream 431 | self.main1 = Conv2dLayer(in_channels, start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 432 | self.main2 = Conv2dLayer(int(start_channels * (2 ** 0 + 2 ** 1 / 4)), start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 433 | self.main31 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 434 | self.main32 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 435 | self.main33 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 436 | self.main34 = ResidualDenseBlock_5C(start_channels, start_channels // 2, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 437 | self.main35 = GlobalBlock(start_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm, sn = False, reduction = 4) 438 | self.main4 = Conv2dLayer(start_channels, out_channels, 3, 1, 1, pad_type = pad, activation = activ, norm = norm) 439 | 440 | def forward(self, x): 441 | b, c, h_inp, w_inp = x.shape 442 | hb, wb = 8, 8 443 | pad_h = (hb - h_inp % hb) % hb 444 | pad_w = (wb - w_inp % wb) % wb 445 | x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect') 446 | # PixelUnShuffle input: batch * 3 * 256 * 256 447 | x1 = pixel_unshuffle(x, 2) # out: batch * 12 * 128 * 128 448 | x2 = pixel_unshuffle(x, 4) # out: batch * 48 * 64 * 64 449 | x3 = pixel_unshuffle(x, 8) # out: batch * 192 * 32 * 32 450 | # Top subnetwork suppose the start_channels = 32 451 | x3 = self.top1(x3) # out: batch * 256 * 32 * 32 452 | x3 = self.top21(x3) # out: batch * 256 * 32 * 32 453 | x3 = self.top22(x3) # out: batch * 256 * 32 * 32 454 | x3 = self.top3(x3) # out: batch * 256 * 32 * 32 455 | x3 = F.pixel_shuffle(x3, 2) # out: batch * 64 * 64 * 64, ready to be concatenated 456 | # Middle subnetwork 457 | x2 = self.mid1(x2) # out: batch * 128 * 64 * 64 458 | x2 = torch.cat((x2, x3), 1) # out: batch * (128 + 64) * 64 * 64 459 | x2 = self.mid2(x2) # out: batch * 128 * 64 * 64 460 | x2 = self.mid31(x2) # out: batch * 128 * 64 * 64 461 | x2 = self.mid32(x2) # out: batch * 128 * 64 * 64 462 | x2 = self.mid4(x2) # out: batch * 128 * 64 * 64 463 | x2 = F.pixel_shuffle(x2, 2) # out: batch * 32 * 128 * 128, ready to be concatenated 464 | # Bottom subnetwork 465 | x1 = self.bot1(x1) # out: batch * 64 * 128 * 128 466 | x1 = torch.cat((x1, x2), 1) # out: batch * (64 + 32) * 128 * 128 467 | x1 = self.bot2(x1) # 
out: batch * 64 * 128 * 128 468 | x1 = self.bot31(x1) # out: batch * 64 * 128 * 128 469 | x1 = self.bot32(x1) # out: batch * 64 * 128 * 128 470 | x1 = self.bot33(x1) # out: batch * 64 * 128 * 128 471 | x1 = self.bot4(x1) # out: batch * 64 * 128 * 128 472 | x1 = F.pixel_shuffle(x1, 2) # out: batch * 16 * 256 * 256, ready to be concatenated 473 | # U-Net generator with skip connections from encoder to decoder 474 | x = self.main1(x) # out: batch * 32 * 256 * 256 475 | x = torch.cat((x, x1), 1) # out: batch * (32 + 16) * 256 * 256 476 | x = self.main2(x) # out: batch * 32 * 256 * 256 477 | x = self.main31(x) # out: batch * 32 * 256 * 256 478 | x = self.main32(x) # out: batch * 32 * 256 * 256 479 | x = self.main33(x) # out: batch * 32 * 256 * 256 480 | x = self.main34(x) # out: batch * 32 * 256 * 256 481 | x = self.main35(x) # out: batch * 32 * 256 * 256 482 | x = self.main4(x) # out: batch * 3 * 256 * 256 483 | 484 | return x[:, :, :h_inp, :w_inp] 485 | 486 | --------------------------------------------------------------------------------
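With the default arguments shown in their constructors, the networks above all map a 3-channel RGB input to a 4-band output. The following is a minimal smoke-test sketch, not part of the repository: the import paths assume the directory layout listed at the top (model.py at the root and the alternatives importable from an other_model package), and the 64x64 dummy input size is an arbitrary choice for illustration.

import torch

from model import MST_Plus_Plus          # MST++ variant used by predict.py
from other_model.HDNet import HDNet
from other_model.MPRNet import MPRNet
from other_model.HRNET import SGN


def smoke_test():
    x = torch.randn(1, 3, 64, 64)         # dummy RGB batch; spatial size is arbitrary
    models = {
        'MST++': MST_Plus_Plus(in_channels=3, out_channels=4, n_feat=4),
        'HDNet': HDNet(in_ch=3, out_ch=4),
        'MPRNet': MPRNet(),                # its conv_in hard-codes a 3-channel input
        'HRNET/SGN': SGN(in_channels=3, out_channels=4, start_channels=8),
    }
    with torch.no_grad():
        for name, net in models.items():
            y = net.eval()(x)
            print(f'{name}: {tuple(y.shape)}')   # expected (1, 4, 64, 64) for each


if __name__ == '__main__':
    smoke_test()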