├── README.md ├── data.py ├── metrics.py ├── models ├── GPPNN.py ├── modules.py ├── refine.py └── utils │ ├── CDC.py │ ├── Inv_modules.py │ ├── Inv_utils.py │ └── __pycache__ │ ├── CDC.cpython-38.pyc │ ├── Inv_modules.cpython-38.pyc │ └── Inv_utils.cpython-38.pyc ├── training └── train_GPPNN.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # MultulInform_PANsharpening 2 | The implementation of CVPR 2022 paper "mutual information-driven pan-sharpening" 3 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | ''' 4 | @Author: wjm 5 | @Date: 2019-10-23 14:57:22 6 | LastEditTime: 2021-01-19 20:57:29 7 | @Description: file content 8 | ''' 9 | import torch.utils.data as data 10 | import torch, random, os 11 | import numpy as np 12 | from os import listdir 13 | from os.path import join 14 | from PIL import Image, ImageOps 15 | from random import randrange 16 | import torch.nn.functional as F 17 | from torchvision.transforms import Compose, ToTensor 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in 22 | ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 'tif', 'TIF']) 23 | 24 | 25 | def load_img(filepath): 26 | img = Image.open(filepath) 27 | # img = Image.open(filepath) 28 | return img 29 | 30 | def transform(): 31 | return Compose([ 32 | ToTensor(), 33 | ]) 34 | 35 | def rescale_img(img_in, scale): 36 | size_in = img_in.size 37 | new_size_in = tuple([int(x * scale) for x in size_in]) 38 | img_in = img_in.resize(new_size_in, resample=Image.BICUBIC) 39 | return img_in 40 | 41 | 42 | def get_patch(ms_image, lms_image, pan_image, bms_image, patch_size, scale, ix=-1, iy=-1): 43 | (ih, iw) = lms_image.size 44 | (th, tw) = (scale * ih, scale * iw) 45 | 46 | patch_mult = scale # if len(scale) > 1 else 1 47 | tp = patch_mult * patch_size 48 | ip = tp // scale 49 | 50 | if ix == -1: 51 | ix = random.randrange(0, iw - ip + 1) 52 | if iy == -1: 53 | iy = random.randrange(0, ih - ip + 1) 54 | 55 | (tx, ty) = (scale * ix, scale * iy) 56 | 57 | lms_image = lms_image.crop((iy, ix, iy + ip, ix + ip)) 58 | ms_image = ms_image.crop((ty, tx, ty + tp, tx + tp)) 59 | pan_image = pan_image.crop((ty, tx, ty + tp, tx + tp)) 60 | bms_image = bms_image.crop((ty, tx, ty + tp, tx + tp)) 61 | 62 | info_patch = { 63 | 'ix': ix, 'iy': iy, 'ip': ip, 'tx': tx, 'ty': ty, 'tp': tp} 64 | 65 | return ms_image, lms_image, pan_image, bms_image, info_patch 66 | 67 | 68 | def augment(ms_image, lms_image, pan_image, bms_image, flip_h=True, rot=True): 69 | info_aug = {'flip_h': False, 'flip_v': False, 'trans': False} 70 | 71 | if random.random() < 0.5 and flip_h: 72 | ms_image = ImageOps.flip(ms_image) 73 | lms_image = ImageOps.flip(lms_image) 74 | pan_image = ImageOps.flip(pan_image) 75 | # bms_image = ImageOps.flip(bms_image) 76 | info_aug['flip_h'] = True 77 | 78 | if rot: 79 | if random.random() < 0.5: 80 | ms_image = ImageOps.mirror(ms_image) 81 | lms_image = ImageOps.mirror(lms_image) 82 | pan_image = ImageOps.mirror(pan_image) 83 | # bms_image = ImageOps.mirror(bms_image) 84 | info_aug['flip_v'] = True 85 | if random.random() < 0.5: 86 | ms_image = ms_image.rotate(180) 87 | lms_image = lms_image.rotate(180) 88 | pan_image = pan_image.rotate(180) 89 | # bms_image = pan_image.rotate(180) 90 | info_aug['trans'] = True 91 
| 92 | return ms_image, lms_image, pan_image, info_aug 93 | 94 | 95 | class Data(data.Dataset): 96 | def __init__(self, data_dir_ms, data_dir_pan, transform=transform(), upscale = 4): 97 | super(Data, self).__init__() 98 | 99 | self.ms_image_filenames = [join(data_dir_ms, x) for x in listdir(data_dir_ms) if is_image_file(x)] 100 | self.pan_image_filenames = [join(data_dir_pan, x) for x in listdir(data_dir_pan) if is_image_file(x)] 101 | 102 | # self.patch_size = cfg['data']['patch_size'] 103 | self.upscale_factor = upscale 104 | self.transform = transform 105 | # self.data_augmentation = cfg['data']['data_augmentation'] 106 | # self.normalize = cfg['data']['normalize'] 107 | # self.cfg = cfg 108 | 109 | def __getitem__(self, index): 110 | 111 | ms_image = load_img(self.ms_image_filenames[index]) 112 | pan_image = load_img(self.pan_image_filenames[index]) 113 | _, file = os.path.split(self.ms_image_filenames[index]) 114 | ms_image = ms_image.crop((0, 0, ms_image.size[0] // self.upscale_factor * self.upscale_factor, 115 | ms_image.size[1] // self.upscale_factor * self.upscale_factor)) 116 | lms_image = ms_image.resize( 117 | (int(ms_image.size[0] / self.upscale_factor), int(ms_image.size[1] / self.upscale_factor)), Image.BICUBIC) 118 | pan_image = pan_image.crop((0, 0, pan_image.size[0] // self.upscale_factor * self.upscale_factor, 119 | pan_image.size[1] // self.upscale_factor * self.upscale_factor)) 120 | # bms_image = rescale_img(lms_image, self.upscale_factor) 121 | 122 | # ms_image, lms_image, pan_image, bms_image, _ = get_patch(ms_image, lms_image, pan_image, bms_image, 123 | # self.patch_size, scale=self.upscale_factor) 124 | 125 | # if self.data_augmentation: 126 | # ms_image, lms_image, pan_image, _ = augment(ms_image, lms_image, pan_image) 127 | 128 | if self.transform: 129 | ms_image = self.transform(ms_image) 130 | lms_image = self.transform(lms_image) 131 | pan_image = self.transform(pan_image) 132 | # bms_image = self.transform(bms_image) 133 | 134 | # if self.normalize: 135 | # ms_image = ms_image * 2 - 1 136 | # lms_image = lms_image * 2 - 1 137 | # pan_image = pan_image * 2 - 1 138 | # bms_image = bms_image * 2 - 1 139 | 140 | return lms_image, pan_image, ms_image 141 | 142 | def __len__(self): 143 | return len(self.ms_image_filenames) 144 | 145 | 146 | class Data_test(data.Dataset): 147 | def __init__(self, data_dir_ms, data_dir_pan, transform=transform(), upscale = 4): 148 | super(Data_test, self).__init__() 149 | 150 | self.ms_image_filenames = [join(data_dir_ms, x) for x in listdir(data_dir_ms) if is_image_file(x)] 151 | self.pan_image_filenames = [join(data_dir_pan, x) for x in listdir(data_dir_pan) if is_image_file(x)] 152 | 153 | self.upscale_factor = upscale 154 | self.transform = transform 155 | # self.data_augmentation = cfg['data']['data_augmentation'] 156 | # self.normalize = cfg['data']['normalize'] 157 | # self.cfg = cfg 158 | 159 | def __getitem__(self, index): 160 | 161 | ms_image = load_img(self.ms_image_filenames[index]) 162 | pan_image = load_img(self.pan_image_filenames[index]) 163 | _, file = os.path.split(self.ms_image_filenames[index]) 164 | ms_image = ms_image.crop((0, 0, ms_image.size[0] // self.upscale_factor * self.upscale_factor, 165 | ms_image.size[1] // self.upscale_factor * self.upscale_factor)) 166 | lms_image = ms_image.resize( 167 | (int(ms_image.size[0] / self.upscale_factor), int(ms_image.size[1] / self.upscale_factor)), Image.BICUBIC) 168 | pan_image = pan_image.crop((0, 0, pan_image.size[0] // self.upscale_factor * self.upscale_factor, 
169 | pan_image.size[1] // self.upscale_factor * self.upscale_factor)) 170 | # bms_image = rescale_img(lms_image, self.upscale_factor) 171 | 172 | # if self.data_augmentation: 173 | # ms_image, lms_image, pan_image, _ = augment(ms_image, lms_image, pan_image) 174 | 175 | if self.transform: 176 | ms_image = self.transform(ms_image) 177 | lms_image = self.transform(lms_image) 178 | pan_image = self.transform(pan_image) 179 | # bms_image = self.transform(bms_image) 180 | 181 | # if self.normalize: 182 | # ms_image = ms_image * 2 - 1 183 | # lms_image = lms_image * 2 - 1 184 | # pan_image = pan_image * 2 - 1 185 | # bms_image = bms_image * 2 - 1 186 | 187 | return lms_image, pan_image,ms_image 188 | 189 | def __len__(self): 190 | return len(self.ms_image_filenames) 191 | 192 | 193 | # class Data_eval(data.Dataset): 194 | # def __init__(self, image_dir, upscale_factor, cfg, transform=None): 195 | # super(Data_eval, self).__init__() 196 | # 197 | # self.ms_image_filenames = [join(data_dir_ms, x) for x in listdir(data_dir_ms) if is_image_file(x)] 198 | # self.pan_image_filenames = [join(data_dir_pan, x) for x in listdir(data_dir_pan) if is_image_file(x)] 199 | # 200 | # self.upscale_factor = cfg['data']['upsacle'] 201 | # self.transform = transform 202 | # self.data_augmentation = cfg['data']['data_augmentation'] 203 | # # self.normalize = cfg['data']['normalize'] 204 | # self.cfg = cfg 205 | # 206 | # def __getitem__(self, index): 207 | # 208 | # lms_image = load_img(self.ms_image_filenames[index]) 209 | # pan_image = load_img(self.pan_image_filenames[index]) 210 | # _, file = os.path.split(self.ms_image_filenames[index]) 211 | # lms_image = lms_image.crop((0, 0, lms_image.size[0] // self.upscale_factor * self.upscale_factor, 212 | # lms_image.size[1] // self.upscale_factor * self.upscale_factor)) 213 | # pan_image = pan_image.crop((0, 0, pan_image.size[0] // self.upscale_factor * self.upscale_factor, 214 | # pan_image.size[1] // self.upscale_factor * self.upscale_factor)) 215 | # # bms_image = rescale_img(lms_image, self.upscale_factor) 216 | # 217 | # if self.data_augmentation: 218 | # lms_image, pan_image, bms_image, _ = augment(lms_image, pan_image, bms_image) 219 | # 220 | # if self.transform: 221 | # lms_image = self.transform(lms_image) 222 | # pan_image = self.transform(pan_image) 223 | # # bms_image = self.transform(bms_image) 224 | # 225 | # # if self.normalize: 226 | # # lms_image = lms_image * 2 - 1 227 | # # pan_image = pan_image * 2 - 1 228 | # # bms_image = bms_image * 2 - 1 229 | # 230 | # return lms_image, pan_image, file 231 | # 232 | # def __len__(self): 233 | # return len(self.ms_image_filenames) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from utils import psnr_loss, ssim, sam 4 | 5 | eps = torch.finfo(torch.float32).eps 6 | 7 | def get_metrics_reduced(img1, img2): 8 | # input: img1 {the pan-sharpened image}, img2 {the ground-truth image} 9 | # return: (larger better) psnr, ssim, scc, (smaller better) sam, ergas 10 | m1 = psnr_loss(img1, img2, 1.) 11 | m2 = ssim(img1, img2, 11, 'mean', 1.) 12 | m3 = cc(img1, img2) 13 | m4 = sam(img1, img2) 14 | m5 = ergas(img1, img2) 15 | return [m1.item(), m2.item(), m3.item(), m4.item(), m5.item()] 16 | 17 | def ergas(img_fake, img_real, scale=4): 18 | """ERGAS for (N, C, H, W) image; torch.float32 [0.,1.]. 
19 | scale = spatial resolution of PAN / spatial resolution of MUL, default 4.""" 20 | 21 | N,C,H,W = img_real.shape 22 | means_real = img_real.reshape(N,C,-1).mean(dim=-1) 23 | mses = ((img_fake - img_real)**2).reshape(N,C,-1).mean(dim=-1) 24 | # Warning: There is a small value in the denominator for numerical stability. 25 | # Since the default dtype of torch is float32, our result may be slightly different from matlab or numpy based ERGAS 26 | 27 | return 100 / scale * torch.sqrt((mses / (means_real**2 + eps)).mean()) 28 | 29 | def cc(img1, img2): 30 | """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" 31 | N,C,_,_ = img1.shape 32 | img1 = img1.reshape(N,C,-1) 33 | img2 = img2.reshape(N,C,-1) 34 | img1 = img1 - img1.mean(dim=-1, keepdim=True) 35 | img2 = img2 - img2.mean(dim=-1, keepdim=True) 36 | cc = torch.sum(img1 * img2, dim=-1) / ( eps + torch.sqrt(torch.sum(img1**2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)) ) 37 | cc = torch.clamp(cc, -1., 1.) 38 | return cc.mean(dim=-1) -------------------------------------------------------------------------------- /models/GPPNN.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import torch 3 | from models.utils.CDC import cdcconv 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from models.modules import InvertibleConv1x1 10 | from models.refine import Refine,CALayer 11 | import torch.nn.init as init 12 | 13 | 14 | 15 | def initialize_weights(net_l, scale=1): 16 | if not isinstance(net_l, list): 17 | net_l = [net_l] 18 | for net in net_l: 19 | for m in net.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | m.weight.data *= scale # for residual block 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.Linear): 26 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 27 | m.weight.data *= scale 28 | if m.bias is not None: 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | init.constant_(m.weight, 1) 32 | init.constant_(m.bias.data, 0.0) 33 | 34 | 35 | 36 | def initialize_weights_xavier(net_l, scale=1): 37 | if not isinstance(net_l, list): 38 | net_l = [net_l] 39 | for net in net_l: 40 | for m in net.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | init.xavier_normal_(m.weight) 43 | m.weight.data *= scale # for residual block 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | elif isinstance(m, nn.Linear): 47 | init.xavier_normal_(m.weight) 48 | m.weight.data *= scale 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | elif isinstance(m, nn.BatchNorm2d): 52 | init.constant_(m.weight, 1) 53 | init.constant_(m.bias.data, 0.0) 54 | 55 | 56 | class UNetConvBlock(nn.Module): 57 | def __init__(self, in_size, out_size, relu_slope=0.1, use_HIN=True): 58 | super(UNetConvBlock, self).__init__() 59 | self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0) 60 | 61 | self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True) 62 | self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False) 63 | self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True) 64 | self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False) 65 | 66 | if use_HIN: 67 | self.norm = nn.InstanceNorm2d(out_size // 2, affine=True) 68 | self.use_HIN = use_HIN 69 | 70 | def forward(self, x): 71 | out = self.conv_1(x) 72 | if 
self.use_HIN: 73 | out_1, out_2 = torch.chunk(out, 2, dim=1) 74 | out = torch.cat([self.norm(out_1), out_2], dim=1) 75 | out = self.relu_1(out) 76 | out = self.relu_2(self.conv_2(out)) 77 | out += self.identity(x) 78 | 79 | return out 80 | 81 | 82 | class DenseBlock(nn.Module): 83 | def __init__(self, channel_in, channel_out, init='xavier', gc=16, bias=True): 84 | super(DenseBlock, self).__init__() 85 | self.conv1 = UNetConvBlock(channel_in, gc) 86 | self.conv2 = UNetConvBlock(gc, gc) 87 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, channel_out, 3, 1, 1, bias=bias) 88 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 89 | 90 | if init == 'xavier': 91 | initialize_weights_xavier([self.conv1, self.conv2, self.conv3], 0.1) 92 | else: 93 | initialize_weights([self.conv1, self.conv2, self.conv3], 0.1) 94 | # initialize_weights(self.conv5, 0) 95 | 96 | def forward(self, x): 97 | x1 = self.lrelu(self.conv1(x)) 98 | x2 = self.lrelu(self.conv2(x1)) 99 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 100 | 101 | return x3 102 | 103 | class DenseBlockMscale(nn.Module): 104 | def __init__(self, channel_in, channel_out, init='xavier'): 105 | super(DenseBlockMscale, self).__init__() 106 | self.ops = DenseBlock(channel_in, channel_out, init) 107 | self.fusepool = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channel_out,channel_out,1,1,0),nn.LeakyReLU(0.1,inplace=True)) 108 | self.fc1 = nn.Sequential(nn.Conv2d(channel_out,channel_out,1,1,0),nn.LeakyReLU(0.1,inplace=True)) 109 | self.fc2 = nn.Sequential(nn.Conv2d(channel_out, channel_out, 1, 1, 0), nn.LeakyReLU(0.1, inplace=True)) 110 | self.fc3 = nn.Sequential(nn.Conv2d(channel_out, channel_out, 1, 1, 0), nn.LeakyReLU(0.1, inplace=True)) 111 | self.fuse = nn.Conv2d(3*channel_out,channel_out,1,1,0) 112 | 113 | def forward(self, x): 114 | x1 = x 115 | x2 = F.interpolate(x1, scale_factor=0.5, mode='bilinear') 116 | x3 = F.interpolate(x1, scale_factor=0.25, mode='bilinear') 117 | x1 = self.ops(x1) 118 | x2 = self.ops(x2) 119 | x3 = self.ops(x3) 120 | x2 = F.interpolate(x2, size=(x1.size()[2], x1.size()[3]), mode='bilinear') 121 | x3 = F.interpolate(x3, size=(x1.size()[2], x1.size()[3]), mode='bilinear') 122 | xattw = self.fusepool(x1+x2+x3) 123 | xattw1 = self.fc1(xattw) 124 | xattw2 = self.fc2(xattw) 125 | xattw3 = self.fc3(xattw) 126 | # x = x1*xattw1+x2*xattw2+x3*xattw3 127 | x = self.fuse(torch.cat([x1*xattw1,x2*xattw2,x3*xattw3],1)) 128 | 129 | return x 130 | 131 | 132 | 133 | def subnet(net_structure, init='xavier'): 134 | def constructor(channel_in, channel_out): 135 | if net_structure == 'DBNet': 136 | if init == 'xavier': 137 | return DenseBlockMscale(channel_in, channel_out, init) 138 | else: 139 | return DenseBlockMscale(channel_in, channel_out) 140 | # return UNetBlock(channel_in, channel_out) 141 | else: 142 | return None 143 | 144 | return constructor 145 | 146 | 147 | class InvBlock(nn.Module): 148 | def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=0.8): 149 | super(InvBlock, self).__init__() 150 | # channel_num: 3 151 | # channel_split_num: 1 152 | 153 | self.split_len1 = channel_split_num # 1 154 | self.split_len2 = channel_num - channel_split_num # 2 155 | 156 | self.clamp = clamp 157 | 158 | self.F = subnet_constructor(self.split_len2, self.split_len1) 159 | self.G = subnet_constructor(self.split_len1, self.split_len2) 160 | self.H = subnet_constructor(self.split_len1, self.split_len2) 161 | 162 | in_channels = channel_num 163 | self.invconv = InvertibleConv1x1(in_channels, 
LU_decomposed=True) 164 | self.flow_permutation = lambda z, logdet, rev: self.invconv(z, logdet, rev) 165 | 166 | def forward(self, x, rev=False): 167 | # if not rev: 168 | # invert1x1conv 169 | x, logdet = self.flow_permutation(x, logdet=0, rev=False) 170 | 171 | # split to 1 channel and 2 channel. 172 | x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) 173 | 174 | y1 = x1 + self.F(x2) # 1 channel 175 | self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1) 176 | y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel 177 | out = torch.cat((y1, y2), 1) 178 | 179 | 180 | return out 181 | 182 | 183 | 184 | class FeatureInteract(nn.Module): 185 | def __init__(self, channel_in, channel_split_num, subnet_constructor=subnet('DBNet'), block_num=4): 186 | super(FeatureInteract, self).__init__() 187 | operations = [] 188 | 189 | # current_channel = channel_in 190 | channel_num = channel_in 191 | 192 | for j in range(block_num): 193 | b = InvBlock(subnet_constructor, channel_num, channel_split_num) # one block is one flow step. 194 | operations.append(b) 195 | 196 | self.operations = nn.ModuleList(operations) 197 | self.fuse = nn.Conv2d((block_num-1)*channel_in,channel_in,1,1,0) 198 | 199 | self.initialize() 200 | 201 | def initialize(self): 202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | init.xavier_normal_(m.weight) 205 | m.weight.data *= 1. # for residual block 206 | if m.bias is not None: 207 | m.bias.data.zero_() 208 | elif isinstance(m, nn.Linear): 209 | init.xavier_normal_(m.weight) 210 | m.weight.data *= 1. 211 | if m.bias is not None: 212 | m.bias.data.zero_() 213 | elif isinstance(m, nn.BatchNorm2d): 214 | init.constant_(m.weight, 1) 215 | init.constant_(m.bias.data, 0.0) 216 | 217 | def forward(self, x, rev=False): 218 | out = x # x: [N,3,H,W] 219 | outfuse = out 220 | for i,op in enumerate(self.operations): 221 | out = op.forward(out, rev) 222 | if i == 1: 223 | outfuse = out 224 | elif i > 1: 225 | outfuse = torch.cat([outfuse,out],1) 226 | # if i < 4: 227 | # out = out+x 228 | outfuse = self.fuse(outfuse) 229 | 230 | return outfuse 231 | 232 | 233 | def upsample(x, h, w): 234 | return F.interpolate(x, size=[h, w], mode='bicubic', align_corners=True) 235 | 236 | 237 | class GPPNN(nn.Module): 238 | def __init__(self, 239 | ms_channels, 240 | pan_channels, 241 | n_feat): 242 | super(GPPNN, self).__init__() 243 | self.extract_pan = FeatureExtract(pan_channels,n_feat//2) 244 | self.extract_ms = FeatureExtract(ms_channels,n_feat//2) 245 | 246 | # self.mulfuse_pan = Multual_fuse(n_feat//2,n_feat//2) 247 | # self.mulfuse_ms = Multual_fuse(n_feat // 2, n_feat // 2) 248 | 249 | self.interact = FeatureInteract(n_feat, n_feat//2) 250 | self.refine = Refine(n_feat, ms_channels) 251 | 252 | def forward(self, ms, i, pan=None): 253 | # ms - low-resolution multi-spectral image [N,C,h,w] 254 | # pan - high-resolution panchromatic image [N,1,H,W] 255 | if type(pan) == torch.Tensor: 256 | pass 257 | elif pan == None: 258 | raise Exception('User does not provide pan image!') 259 | _, _, m, n = ms.shape 260 | _, _, M, N = pan.shape 261 | 262 | mHR = upsample(ms, M, N) 263 | 264 | panf = self.extract_pan(pan) 265 | mHRf = self.extract_ms(mHR) 266 | 267 | feature_save(panf, '/home/jieh/Projects/PAN_Sharp/PansharpingMul/GPPNN/training/logs/GPPNN2/panf', i) 268 | feature_save(mHRf, '/home/jieh/Projects/PAN_Sharp/PansharpingMul/GPPNN/training/logs/GPPNN2/mHRf', i) 269 | 270 | finput = torch.cat([panf, mHRf], dim=1) 271 | fmid = 
self.interact(finput) 272 | HR = self.refine(fmid)+mHR 273 | 274 | return HR, panf, mHRf 275 | 276 | 277 | 278 | import os 279 | import cv2 280 | 281 | def feature_save(tensor,name,i): 282 | # tensor = torchvision.utils.make_grid(tensor.transpose(1,0)) 283 | tensor = torch.mean(tensor,dim=1) 284 | inp = tensor.detach().cpu().numpy().transpose(1,2,0) 285 | inp = inp.squeeze(2) 286 | inp = (inp - np.min(inp)) / (np.max(inp) - np.min(inp)) 287 | if not os.path.exists(name): 288 | os.makedirs(name) 289 | # for i in range(tensor.shape[1]): 290 | # inp = tensor[:,i,:,:].detach().cpu().numpy().transpose(1,2,0) 291 | # inp = np.clip(inp,0,1) 292 | # # inp = (inp-np.min(inp))/(np.max(inp)-np.min(inp)) 293 | # 294 | # cv2.imwrite(str(name)+'/'+str(i)+'.png',inp*255.0) 295 | inp = cv2.applyColorMap(np.uint8(inp * 255.0),cv2.COLORMAP_JET) 296 | cv2.imwrite(name + '/' + str(i) + '.png', inp) 297 | 298 | 299 | class EdgeBlock(nn.Module): 300 | def __init__(self, channelin, channelout): 301 | super(EdgeBlock, self).__init__() 302 | self.process = nn.Conv2d(channelin,channelout,3,1,1) 303 | self.Res = nn.Sequential(nn.Conv2d(channelout,channelout,3,1,1), 304 | nn.ReLU(),nn.Conv2d(channelout, channelout, 3, 1, 1)) 305 | self.CDC = cdcconv(channelin, channelout) 306 | 307 | def forward(self,x): 308 | 309 | x = self.process(x) 310 | out = self.Res(x) + self.CDC(x) 311 | 312 | return out 313 | 314 | class FeatureExtract(nn.Module): 315 | def __init__(self, channelin, channelout): 316 | super(FeatureExtract, self).__init__() 317 | self.conv = nn.Conv2d(channelin,channelout,1,1,0) 318 | self.block1 = EdgeBlock(channelout,channelout) 319 | self.block2 = EdgeBlock(channelout, channelout) 320 | 321 | def forward(self,x): 322 | xf = self.conv(x) 323 | xf1 = self.block1(xf) 324 | xf2 = self.block2(xf1) 325 | 326 | return xf2 327 | 328 | 329 | from torch.distributions import Normal, Independent, kl 330 | from torch.autograd import Variable 331 | CE = torch.nn.BCELoss(reduction='sum') 332 | 333 | 334 | class Mutual_info_reg(nn.Module): 335 | def __init__(self, input_channels, channels, latent_size = 4): 336 | super(Mutual_info_reg, self).__init__() 337 | self.contracting_path = nn.ModuleList() 338 | self.input_channels = input_channels 339 | self.relu = nn.ReLU(inplace=True) 340 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 341 | # self.bn1 = nn.BatchNorm2d(channels) 342 | self.layer2 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 343 | # self.bn2 = nn.BatchNorm2d(channels) 344 | self.layer3 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) 345 | self.layer4 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) 346 | 347 | self.channel = channels 348 | 349 | # self.fc1_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 350 | # self.fc2_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 351 | # self.fc1_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 352 | # self.fc2_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 353 | # 354 | # self.fc1_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 355 | # self.fc2_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 356 | # self.fc1_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 357 | # self.fc2_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 358 | 359 | self.fc1_rgb3 = nn.Linear(channels * 1 * 32 * 32, latent_size) 360 | self.fc2_rgb3 = nn.Linear(channels * 1 * 32 * 32, latent_size) 361 | self.fc1_depth3 = nn.Linear(channels * 1 * 32 * 
32, latent_size) 362 | self.fc2_depth3 = nn.Linear(channels * 1 * 32 * 32, latent_size) 363 | 364 | self.leakyrelu = nn.LeakyReLU() 365 | self.tanh = torch.nn.Tanh() 366 | 367 | def kl_divergence(self, posterior_latent_space, prior_latent_space): 368 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space) 369 | return kl_div 370 | 371 | def reparametrize(self, mu, logvar): 372 | std = logvar.mul(0.5).exp_() 373 | eps = torch.cuda.FloatTensor(std.size()).normal_() 374 | eps = Variable(eps) 375 | return eps.mul(std).add_(mu) 376 | 377 | def forward(self, rgb_feat, depth_feat): 378 | rgb_feat = self.layer3(self.leakyrelu(self.layer1(rgb_feat))) 379 | depth_feat = self.layer4(self.leakyrelu(self.layer2(depth_feat))) 380 | # print(rgb_feat.size()) 381 | # print(depth_feat.size()) 382 | # if rgb_feat.shape[2] == 16: 383 | # rgb_feat = rgb_feat.view(-1, self.channel * 1 * 16 * 16) 384 | # depth_feat = depth_feat.view(-1, self.channel * 1 * 16 * 16) 385 | # 386 | # mu_rgb = self.fc1_rgb1(rgb_feat) 387 | # logvar_rgb = self.fc2_rgb1(rgb_feat) 388 | # mu_depth = self.fc1_depth1(depth_feat) 389 | # logvar_depth = self.fc2_depth1(depth_feat) 390 | # elif rgb_feat.shape[2] == 22: 391 | # rgb_feat = rgb_feat.view(-1, self.channel * 1 * 22 * 22) 392 | # depth_feat = depth_feat.view(-1, self.channel * 1 * 22 * 22) 393 | # mu_rgb = self.fc1_rgb2(rgb_feat) 394 | # logvar_rgb = self.fc2_rgb2(rgb_feat) 395 | # mu_depth = self.fc1_depth2(depth_feat) 396 | # logvar_depth = self.fc2_depth2(depth_feat) 397 | # else: 398 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 32 * 32) 399 | depth_feat = depth_feat.view(-1, self.channel * 1 * 32 * 32) 400 | mu_rgb = self.fc1_rgb3(rgb_feat) 401 | logvar_rgb = self.fc2_rgb3(rgb_feat) 402 | mu_depth = self.fc1_depth3(depth_feat) 403 | logvar_depth = self.fc2_depth3(depth_feat) 404 | 405 | mu_depth = self.tanh(mu_depth) 406 | mu_rgb = self.tanh(mu_rgb) 407 | logvar_depth = self.tanh(logvar_depth) 408 | logvar_rgb = self.tanh(logvar_rgb) 409 | z_rgb = self.reparametrize(mu_rgb, logvar_rgb) 410 | dist_rgb = Independent(Normal(loc=mu_rgb, scale=torch.exp(logvar_rgb)), 1) 411 | z_depth = self.reparametrize(mu_depth, logvar_depth) 412 | dist_depth = Independent(Normal(loc=mu_depth, scale=torch.exp(logvar_depth)), 1) 413 | bi_di_kld = torch.mean(self.kl_divergence(dist_rgb, dist_depth)) + torch.mean( 414 | self.kl_divergence(dist_depth, dist_rgb)) 415 | z_rgb_norm = torch.sigmoid(z_rgb) 416 | z_depth_norm = torch.sigmoid(z_depth) 417 | ce_rgb_depth = CE(z_rgb_norm,z_depth_norm.detach()) 418 | ce_depth_rgb = CE(z_depth_norm, z_rgb_norm.detach()) 419 | latent_loss = ce_rgb_depth+ce_depth_rgb-bi_di_kld 420 | # latent_loss = torch.abs(cos_sim(z_rgb,z_depth)).sum() 421 | 422 | return latent_loss 423 | 424 | 425 | 426 | 427 | ########################################################################################################### 428 | 429 | 430 | 431 | # class Multual_fuse(nn.Module): 432 | # def __init__(self, in_channels, channels): 433 | # super(Multual_fuse, self).__init__() 434 | # self.convx = nn.Conv2d(in_channels,channels,3,1,1) 435 | # self.fuse = CALayer(channels*2,4) 436 | # self.convout = nn.Conv2d(channels*2,channels,3,1,1) 437 | # 438 | # def tile(self, a, dim, n_title): 439 | # init_dim = a.size(dim) 440 | # repeat_idx = [1] * a.dim() 441 | # repeat_idx[dim] = n_title 442 | # a = a.repeat(*(repeat_idx)) 443 | # order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_title) + i for i in range(init_dim)])).cuda() 444 | # return 
torch.index_select(a,dim,order_index) 445 | # 446 | # 447 | # def forward(self,x,y): 448 | # x = self.convx(x) 449 | # y = torch.unsqueeze(y, 2) 450 | # y = self.tile(y, 2, x.shape[2]) 451 | # y = torch.unsqueeze(y, 3) 452 | # y = self.tile(y, 3, x.shape[3]) 453 | # fusef = self.fuse(torch.cat([x,y],1)) 454 | # out = self.convout(fusef) 455 | # 456 | # return out 457 | 458 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | def compute_same_pad(kernel_size, stride): 9 | if isinstance(kernel_size, int): 10 | kernel_size = [kernel_size] 11 | 12 | if isinstance(stride, int): 13 | stride = [stride] 14 | 15 | assert len(stride) == len( 16 | kernel_size 17 | ), "Pass kernel size and stride both as int, or both as equal length iterable" 18 | 19 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] 20 | 21 | 22 | def uniform_binning_correction(x, n_bits=8): 23 | """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). 24 | 25 | Args: 26 | x: 4-D Tensor of shape (NCHW) 27 | n_bits: optional. 28 | Returns: 29 | x: x ~ U(x, x + 1.0 / 256) 30 | objective: Equivalent to -q(x)*log(q(x)). 31 | """ 32 | b, c, h, w = x.size() 33 | n_bins = 2 ** n_bits 34 | chw = c * h * w 35 | x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) 36 | 37 | objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) 38 | return x, objective 39 | 40 | 41 | def split_feature(tensor, type="split"): 42 | """ 43 | type = ["split", "cross"] 44 | """ 45 | C = tensor.size(1) 46 | if type == "split": 47 | # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] 48 | return tensor[:, :1, ...], tensor[:,1:, ...] 49 | elif type == "cross": 50 | # return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 51 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
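A minimal usage sketch for the split_feature helper defined just above (assumption: the repository root is on PYTHONPATH so models.modules imports cleanly). "split" returns the first channel and the remaining C-1 channels, while "cross" interleaves even- and odd-indexed channels; the Split2d layer further below uses both modes when building its prior.

import torch
from models.modules import split_feature  # assumed import path from the repo root

x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)

# "split": first channel vs. the remaining channels
a, b = split_feature(x, "split")
print(a.shape, b.shape)   # torch.Size([2, 1, 4, 4]) torch.Size([2, 2, 4, 4])

# "cross": even-indexed channels vs. odd-indexed channels
c, d = split_feature(x, "cross")
print(c.shape, d.shape)   # torch.Size([2, 2, 4, 4]) torch.Size([2, 1, 4, 4])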
52 | 53 | 54 | 55 | def gaussian_p(mean, logs, x): 56 | """ 57 | lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } 58 | k = 1 (Independent) 59 | Var = logs ** 2 60 | """ 61 | c = math.log(2 * math.pi) 62 | return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c) 63 | 64 | 65 | def gaussian_likelihood(mean, logs, x): 66 | p = gaussian_p(mean, logs, x) 67 | return torch.sum(p, dim=[1, 2, 3]) 68 | 69 | 70 | def gaussian_sample(mean, logs, temperature=1): 71 | # Sample from Gaussian with temperature 72 | z = torch.normal(mean, torch.exp(logs) * temperature) 73 | 74 | return z 75 | 76 | 77 | def squeeze2d(input, factor): 78 | if factor == 1: 79 | return input 80 | 81 | B, C, H, W = input.size() 82 | 83 | assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" 84 | 85 | x = input.view(B, C, H // factor, factor, W // factor, factor) 86 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 87 | x = x.view(B, C * factor * factor, H // factor, W // factor) 88 | 89 | return x 90 | 91 | 92 | def unsqueeze2d(input, factor): 93 | if factor == 1: 94 | return input 95 | 96 | factor2 = factor ** 2 97 | 98 | B, C, H, W = input.size() 99 | 100 | assert C % (factor2) == 0, "C module factor squared is not 0" 101 | 102 | x = input.view(B, C // factor2, factor, factor, H, W) 103 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 104 | x = x.view(B, C // (factor2), H * factor, W * factor) 105 | 106 | return x 107 | 108 | 109 | class _ActNorm(nn.Module): 110 | """ 111 | Activation Normalization 112 | Initialize the bias and scale with a given minibatch, 113 | so that the output per-channel have zero mean and unit variance for that. 114 | 115 | After initialization, `bias` and `logs` will be trained as parameters. 116 | """ 117 | 118 | def __init__(self, num_features, scale=1.0): 119 | super().__init__() 120 | # register mean and scale 121 | size = [1, num_features, 1, 1] 122 | self.bias = nn.Parameter(torch.zeros(*size)) 123 | self.logs = nn.Parameter(torch.zeros(*size)) 124 | self.num_features = num_features 125 | self.scale = scale 126 | self.inited = False 127 | 128 | def initialize_parameters(self, input): 129 | if not self.training: 130 | raise ValueError("In Eval mode, but ActNorm not inited") 131 | 132 | with torch.no_grad(): 133 | bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) 134 | vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) 135 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) 136 | 137 | self.bias.data.copy_(bias.data) 138 | self.logs.data.copy_(logs.data) 139 | 140 | self.inited = True 141 | 142 | def _center(self, input, reverse=False): 143 | if reverse: 144 | return input - self.bias 145 | else: 146 | return input + self.bias 147 | 148 | def _scale(self, input, logdet=None, reverse=False): 149 | 150 | if reverse: 151 | input = input * torch.exp(-self.logs) 152 | else: 153 | input = input * torch.exp(self.logs) 154 | 155 | if logdet is not None: 156 | """ 157 | logs is log_std of `mean of channels` 158 | so we need to multiply by number of pixels 159 | """ 160 | b, c, h, w = input.shape 161 | 162 | dlogdet = torch.sum(self.logs) * h * w 163 | 164 | if reverse: 165 | dlogdet *= -1 166 | 167 | logdet = logdet + dlogdet 168 | 169 | return input, logdet 170 | 171 | def forward(self, input, logdet=None, reverse=False): 172 | self._check_input_dim(input) 173 | 174 | if not self.inited: 175 | self.initialize_parameters(input) 176 | 177 | if reverse: 178 | input, logdet = self._scale(input, logdet, reverse) 
179 | input = self._center(input, reverse) 180 | else: 181 | input = self._center(input, reverse) 182 | input, logdet = self._scale(input, logdet, reverse) 183 | 184 | return input, logdet 185 | 186 | 187 | class ActNorm2d(_ActNorm): 188 | def __init__(self, num_features, scale=1.0): 189 | super().__init__(num_features, scale) 190 | 191 | def _check_input_dim(self, input): 192 | assert len(input.size()) == 4 193 | assert input.size(1) == self.num_features, ( 194 | "[ActNorm]: input should be in shape as `BCHW`," 195 | " channels should be {} rather than {}".format( 196 | self.num_features, input.size() 197 | ) 198 | ) 199 | 200 | 201 | class LinearZeros(nn.Module): 202 | def __init__(self, in_channels, out_channels, logscale_factor=3): 203 | super().__init__() 204 | 205 | self.linear = nn.Linear(in_channels, out_channels) 206 | self.linear.weight.data.zero_() 207 | self.linear.bias.data.zero_() 208 | 209 | self.logscale_factor = logscale_factor 210 | 211 | self.logs = nn.Parameter(torch.zeros(out_channels)) 212 | 213 | def forward(self, input): 214 | output = self.linear(input) 215 | return output * torch.exp(self.logs * self.logscale_factor) 216 | 217 | 218 | class Conv2d(nn.Module): 219 | def __init__( 220 | self, 221 | in_channels, 222 | out_channels, 223 | kernel_size=(3, 3), 224 | stride=(1, 1), 225 | padding="same", 226 | do_actnorm=True, 227 | weight_std=0.05, 228 | ): 229 | super().__init__() 230 | 231 | if padding == "same": 232 | padding = compute_same_pad(kernel_size, stride) 233 | elif padding == "valid": 234 | padding = 0 235 | 236 | self.conv = nn.Conv2d( 237 | in_channels, 238 | out_channels, 239 | kernel_size, 240 | stride, 241 | padding, 242 | bias=(not do_actnorm), 243 | ) 244 | 245 | # init weight with std 246 | self.conv.weight.data.normal_(mean=0.0, std=weight_std) 247 | 248 | if not do_actnorm: 249 | self.conv.bias.data.zero_() 250 | else: 251 | self.actnorm = ActNorm2d(out_channels) 252 | 253 | self.do_actnorm = do_actnorm 254 | 255 | def forward(self, input): 256 | x = self.conv(input) 257 | if self.do_actnorm: 258 | x, _ = self.actnorm(x) 259 | return x 260 | 261 | 262 | class Conv2dZeros(nn.Module): 263 | def __init__( 264 | self, 265 | in_channels, 266 | out_channels, 267 | kernel_size=(3, 3), 268 | stride=(1, 1), 269 | padding="same", 270 | logscale_factor=3, 271 | ): 272 | super().__init__() 273 | 274 | if padding == "same": 275 | padding = compute_same_pad(kernel_size, stride) 276 | elif padding == "valid": 277 | padding = 0 278 | 279 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 280 | 281 | self.conv.weight.data.zero_() 282 | self.conv.bias.data.zero_() 283 | 284 | self.logscale_factor = logscale_factor 285 | self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1)) 286 | 287 | def forward(self, input): 288 | output = self.conv(input) 289 | return output * torch.exp(self.logs * self.logscale_factor) 290 | 291 | 292 | class Permute2d(nn.Module): 293 | def __init__(self, num_channels, shuffle): 294 | super().__init__() 295 | self.num_channels = num_channels 296 | self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long) 297 | self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long) 298 | 299 | for i in range(self.num_channels): 300 | self.indices_inverse[self.indices[i]] = i 301 | 302 | if shuffle: 303 | self.reset_indices() 304 | 305 | def reset_indices(self): 306 | shuffle_idx = torch.randperm(self.indices.shape[0]) 307 | self.indices = self.indices[shuffle_idx] 308 | 309 | for i in 
range(self.num_channels): 310 | self.indices_inverse[self.indices[i]] = i 311 | 312 | def forward(self, input, reverse=False): 313 | assert len(input.size()) == 4 314 | 315 | if not reverse: 316 | input = input[:, self.indices, :, :] 317 | return input 318 | else: 319 | return input[:, self.indices_inverse, :, :] 320 | 321 | 322 | class Split2d(nn.Module): 323 | def __init__(self, num_channels): 324 | super().__init__() 325 | self.conv = Conv2dZeros(num_channels // 2, num_channels) 326 | 327 | def split2d_prior(self, z): 328 | h = self.conv(z) 329 | return split_feature(h, "cross") 330 | 331 | def forward(self, input, logdet=0.0, reverse=False, temperature=None): 332 | if reverse: 333 | z1 = input 334 | mean, logs = self.split2d_prior(z1) 335 | z2 = gaussian_sample(mean, logs, temperature) 336 | z = torch.cat((z1, z2), dim=1) 337 | return z, logdet 338 | else: 339 | z1, z2 = split_feature(input, "split") 340 | mean, logs = self.split2d_prior(z1) 341 | logdet = gaussian_likelihood(mean, logs, z2) + logdet 342 | return z1, logdet 343 | 344 | 345 | class SqueezeLayer(nn.Module): 346 | def __init__(self, factor): 347 | super().__init__() 348 | self.factor = factor 349 | 350 | def forward(self, input, logdet=None, reverse=False): 351 | if reverse: 352 | output = unsqueeze2d(input, self.factor) 353 | else: 354 | output = squeeze2d(input, self.factor) 355 | 356 | return output, logdet 357 | 358 | 359 | class InvertibleConv1x1(nn.Module): 360 | def __init__(self, num_channels, LU_decomposed): 361 | super().__init__() 362 | w_shape = [num_channels, num_channels] 363 | w_init = torch.qr(torch.randn(*w_shape))[0] 364 | 365 | if not LU_decomposed: 366 | self.weight = nn.Parameter(torch.Tensor(w_init)) 367 | else: 368 | p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) 369 | s = torch.diag(upper) 370 | sign_s = torch.sign(s) 371 | log_s = torch.log(torch.abs(s)) 372 | upper = torch.triu(upper, 1) 373 | l_mask = torch.tril(torch.ones(w_shape), -1) 374 | eye = torch.eye(*w_shape) 375 | 376 | self.register_buffer("p", p) 377 | self.register_buffer("sign_s", sign_s) 378 | self.lower = nn.Parameter(lower) 379 | self.log_s = nn.Parameter(log_s) 380 | self.upper = nn.Parameter(upper) 381 | self.l_mask = l_mask 382 | self.eye = eye 383 | 384 | self.w_shape = w_shape 385 | self.LU_decomposed = LU_decomposed 386 | 387 | def get_weight(self, input, reverse): 388 | b, c, h, w = input.shape 389 | 390 | if not self.LU_decomposed: 391 | dlogdet = torch.slogdet(self.weight)[1] * h * w 392 | if reverse: 393 | weight = torch.inverse(self.weight) 394 | else: 395 | weight = self.weight 396 | else: 397 | self.l_mask = self.l_mask.to(input.device) 398 | self.eye = self.eye.to(input.device) 399 | 400 | lower = self.lower * self.l_mask + self.eye 401 | 402 | u = self.upper * self.l_mask.transpose(0, 1).contiguous() 403 | u += torch.diag(self.sign_s * torch.exp(self.log_s)) 404 | 405 | dlogdet = torch.sum(self.log_s) * h * w 406 | 407 | if reverse: 408 | u_inv = torch.inverse(u) 409 | l_inv = torch.inverse(lower) 410 | p_inv = torch.inverse(self.p) 411 | 412 | weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) 413 | else: 414 | weight = torch.matmul(self.p, torch.matmul(lower, u)) 415 | 416 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 417 | 418 | def forward(self, input, logdet=None, reverse=False): 419 | """ 420 | log-det = log|abs(|W|)| * pixels 421 | """ 422 | weight, dlogdet = self.get_weight(input, reverse) 423 | 424 | if not reverse: 425 | z = F.conv2d(input, weight) 426 | if logdet is 
not None: 427 | logdet = logdet + dlogdet 428 | return z, logdet 429 | else: 430 | z = F.conv2d(input, weight) 431 | if logdet is not None: 432 | logdet = logdet - dlogdet 433 | return z, logdet 434 | -------------------------------------------------------------------------------- /models/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.transforms as transforms 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import math 8 | from torch.nn import init 9 | import os 10 | import torchvision.transforms.functional as tf 11 | 12 | 13 | class DenseModule(nn.Module): 14 | def __init__(self, channel): 15 | super(DenseModule, self).__init__() 16 | self.conv1 = nn.Conv2d(channel,channel,3, 1, 1, bias=True) 17 | self.conv2 = nn.Conv2d(channel, channel, 3, 1, 1, bias=True) 18 | self.conv3 = nn.Conv2d(channel, channel, 3, 1, 1, bias=True) 19 | self.conv4 = nn.Conv2d(channel*4, channel, 1, 1, 0) 20 | self.act = nn.LeakyReLU(0.2,inplace=True) 21 | 22 | def forward(self, x): 23 | x1 = self.act(self.conv1(x)) 24 | x2 = self.act(self.conv2(x1)) 25 | x3 = self.act(self.conv3(x2)) 26 | x_final = self.conv4(torch.cat([x,x1,x2,x3],1)) 27 | 28 | return x_final 29 | 30 | class CALayer(nn.Module): 31 | def __init__(self, channel, reduction): 32 | super(CALayer, self).__init__() 33 | # global average pooling: feature --> point 34 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 35 | # feature channel downscale and upscale --> channel weight 36 | self.conv_du = nn.Sequential( 37 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 40 | nn.Sigmoid() 41 | ) 42 | self.process = nn.Sequential( 43 | nn.Conv2d(channel, channel, 3, stride=1, padding=1), 44 | nn.ReLU(), 45 | nn.Conv2d(channel, channel, 3, stride=1, padding=1) 46 | ) 47 | 48 | def forward(self, x): 49 | y = self.process(x) 50 | y = self.avg_pool(y) 51 | z = self.conv_du(y) 52 | return z * y + x 53 | 54 | 55 | 56 | class Refine(nn.Module): 57 | 58 | def __init__(self,n_feat,out_channels): 59 | super(Refine, self).__init__() 60 | 61 | self.conv_in = nn.Conv2d(n_feat, n_feat, 3, stride=1, padding=1) 62 | self.process = nn.Sequential( 63 | CALayer(n_feat,4), 64 | CALayer(n_feat,4)) 65 | self.conv_last = nn.Conv2d(in_channels=n_feat, out_channels=out_channels, kernel_size=3, stride=1, padding=1) 66 | 67 | 68 | def forward(self, x): 69 | 70 | out = self.conv_in(x) 71 | out = self.process(out) 72 | out = self.conv_last(out) 73 | 74 | return out 75 | 76 | 77 | class Refine1(nn.Module): 78 | 79 | def __init__(self,in_channels,panchannels,n_feat): 80 | super(Refine1, self).__init__() 81 | 82 | self.conv_in = nn.Conv2d(n_feat, n_feat, 3, stride=1, padding=1) 83 | self.process = nn.Sequential( 84 | # CALayer(n_feat,4), 85 | # CALayer(n_feat,4), 86 | CALayer(n_feat,4)) 87 | self.conv_last = nn.Conv2d(in_channels=n_feat, out_channels=in_channels-panchannels, kernel_size=3, stride=1, padding=1) 88 | 89 | 90 | def forward(self, x): 91 | 92 | out = self.conv_in(x) 93 | out = self.process(out) 94 | out = self.conv_last(out) 95 | 96 | return out 97 | 98 | 99 | -------------------------------------------------------------------------------- /models/utils/CDC.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as 
F 5 | import torch.utils.model_zoo as model_zoo 6 | from torch import nn 7 | from torch.nn import Parameter 8 | import pdb 9 | import numpy as np 10 | 11 | 12 | class cdc_vg(nn.Module): 13 | def __init__(self, mid_ch, theta=0.7): 14 | 15 | super(cdc_vg, self).__init__() 16 | 17 | self.cdc = Conv2d_cd(mid_ch, mid_ch, kernel_size=3, stride=1, padding=1, bias=False, theta= theta) 18 | self.cdc_bn = nn.BatchNorm2d(mid_ch) 19 | self.cdc_act = nn.PReLU() 20 | 21 | self.h_conv = Conv2d_Hori_Veri_Cross(in_channels=mid_ch, out_channels=mid_ch, kernel_size=3, stride=1, padding=1, bias=False, theta=theta) 22 | self.d_conv = Conv2d_Diag_Cross(in_channels=mid_ch, out_channels=mid_ch, kernel_size=3, stride=1, padding=1, bias=False, theta=theta) 23 | self.vg_bn = nn.BatchNorm2d(mid_ch) 24 | self.vg_act = nn.PReLU() 25 | 26 | # self.HP_branch = Parameter(torch.FloatTensor(1)) 27 | 28 | def forward(self, x): 29 | out_0 = self.cdc_act(self.cdc_bn(self.cdc(x))) 30 | 31 | out1 = self.h_conv(out_0) 32 | out2 = self.d_conv(out_0) 33 | out = self.vg_act(self.vg_bn(0.5 * out1 + 0.5 * out2)) 34 | # out = out1 + out2 35 | return out + x 36 | 37 | 38 | class ResBlock_cdc(nn.Module): 39 | def __init__( 40 | self, conv, n_feats, kernel_size, 41 | bias=True, bn=False, act=nn.PReLU(), res_scale=1, theta=0.8): 42 | 43 | super(ResBlock_cdc, self).__init__() 44 | m = [] 45 | for i in range(2): 46 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 47 | if bn: 48 | m.append(nn.BatchNorm2d(n_feats)) 49 | if i == 0: 50 | m.append(act) 51 | 52 | self.body = nn.Sequential(*m) 53 | self.res_scale = res_scale 54 | 55 | self.h_conv = Conv2d_Hori_Veri_Cross(in_channels=n_feats, out_channels=n_feats, kernel_size=3, 56 | stride=1, padding=1, bias=False, theta=theta) 57 | self.d_conv = Conv2d_Diag_Cross(in_channels=n_feats, out_channels=n_feats, kernel_size=3, stride=1, 58 | padding=1, bias=False, theta=theta) 59 | # self.HP_branch = Parameter(torch.FloatTensor(1)) 60 | 61 | def forward(self, x): 62 | res = self.body(x).mul(self.res_scale) 63 | res += x 64 | 65 | out1 = self.h_conv(x) 66 | out2 = self.d_conv(x) 67 | # out = torch.sigmoid(self.HP_branch) * out1 + (1 - torch.sigmoid(self.HP_branch)) * out2 68 | out = out1 + out2 69 | 70 | res += x + out 71 | 72 | return res 73 | 74 | 75 | class cdcconv(nn.Module): 76 | def __init__(self, in_channels, out_channels, theta=0.8): 77 | 78 | super(cdcconv, self).__init__() 79 | 80 | self.h_conv = Conv2d_Hori_Veri_Cross(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False, theta=theta) 81 | self.d_conv = Conv2d_Diag_Cross(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False, theta=theta) 82 | 83 | self.HP_branch = Parameter(torch.FloatTensor(1)) 84 | 85 | def forward(self, x): 86 | out1 = self.h_conv(x) 87 | out2 = self.d_conv(x) 88 | out = torch.sigmoid(self.HP_branch) * out1 + (1 - torch.sigmoid(self.HP_branch)) * out2 + x 89 | # out = out1 + out2 90 | return out 91 | 92 | 93 | class Conv2d_cd(nn.Module): 94 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 95 | padding=1, dilation=1, groups=1, bias=False, theta=0.7): 96 | 97 | super(Conv2d_cd, self).__init__() 98 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 99 | self.theta = theta 100 | 101 | def forward(self, x): 102 | out_normal = self.conv(x) 103 | 104 | if math.fabs(self.theta - 0.0) < 1e-8: 105 | return 
out_normal 106 | else: 107 | #pdb.set_trace() 108 | [C_out,C_in, kernel_size,kernel_size] = self.conv.weight.shape 109 | kernel_diff = self.conv.weight.sum(2).sum(2) 110 | kernel_diff = kernel_diff[:, :, None, None] 111 | out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, groups=self.conv.groups) 112 | 113 | return out_normal - self.theta * out_diff 114 | 115 | 116 | 117 | class Conv2d_Hori_Veri_Cross(nn.Module): 118 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 119 | padding=1, dilation=1, groups=1, bias=False, theta=0.7): 120 | 121 | super(Conv2d_Hori_Veri_Cross, self).__init__() 122 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 5), stride=stride, padding=padding, 123 | dilation=dilation, groups=groups, bias=bias) 124 | self.theta = theta 125 | 126 | def forward(self, x): 127 | [C_out, C_in, H_k, W_k] = self.conv.weight.shape 128 | tensor_zeros = torch.FloatTensor(C_out, C_in, 1).fill_(0).cuda() 129 | conv_weight = torch.cat((tensor_zeros, self.conv.weight[:, :, :, 0], tensor_zeros, self.conv.weight[:, :, :, 1], 130 | self.conv.weight[:, :, :, 2], self.conv.weight[:, :, :, 3], tensor_zeros, 131 | self.conv.weight[:, :, :, 4], tensor_zeros), 2) 132 | conv_weight = conv_weight.contiguous().view(C_out, C_in, 3, 3) 133 | 134 | out_normal = F.conv2d(input=x, weight=conv_weight, bias=self.conv.bias, stride=self.conv.stride, 135 | padding=self.conv.padding) 136 | 137 | if math.fabs(self.theta - 0.0) < 1e-8: 138 | return out_normal 139 | else: 140 | # pdb.set_trace() 141 | [C_out, C_in, kernel_size, kernel_size] = self.conv.weight.shape 142 | kernel_diff = self.conv.weight.sum(2).sum(2) 143 | kernel_diff = kernel_diff[:, :, None, None] 144 | out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, 145 | groups=self.conv.groups) 146 | 147 | return out_normal - self.theta * out_diff 148 | 149 | 150 | class Conv2d_Diag_Cross(nn.Module): 151 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 152 | padding=1, dilation=1, groups=1, bias=False, theta=0.7): 153 | 154 | super(Conv2d_Diag_Cross, self).__init__() 155 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 5), stride=stride, padding=padding, 156 | dilation=dilation, groups=groups, bias=bias) 157 | self.theta = theta 158 | 159 | def forward(self, x): 160 | 161 | [C_out, C_in, H_k, W_k] = self.conv.weight.shape 162 | tensor_zeros = torch.FloatTensor(C_out, C_in, 1).fill_(0).cuda() 163 | conv_weight = torch.cat((self.conv.weight[:, :, :, 0], tensor_zeros, self.conv.weight[:, :, :, 1], tensor_zeros, 164 | self.conv.weight[:, :, :, 2], tensor_zeros, self.conv.weight[:, :, :, 3], tensor_zeros, 165 | self.conv.weight[:, :, :, 4]), 2) 166 | conv_weight = conv_weight.contiguous().view(C_out, C_in, 3, 3) 167 | 168 | out_normal = F.conv2d(input=x, weight=conv_weight, bias=self.conv.bias, stride=self.conv.stride, 169 | padding=self.conv.padding) 170 | 171 | if math.fabs(self.theta - 0.0) < 1e-8: 172 | return out_normal 173 | else: 174 | # pdb.set_trace() 175 | [C_out, C_in, kernel_size, kernel_size] = self.conv.weight.shape 176 | kernel_diff = self.conv.weight.sum(2).sum(2) 177 | kernel_diff = kernel_diff[:, :, None, None] 178 | out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, 179 | groups=self.conv.groups) 180 | 181 | return out_normal - self.theta * out_diff 182 | 183 | 184 | def default_conv(in_channels, 
out_channels, kernel_size,stride=1, bias=True): 185 | return nn.Conv2d( 186 | in_channels, out_channels, kernel_size, 187 | padding=(kernel_size//2),stride=stride, bias=bias) 188 | -------------------------------------------------------------------------------- /models/utils/Inv_modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from models.utils.Inv_utils import split_feature, compute_same_pad 7 | 8 | 9 | def gaussian_p(mean, logs, x): 10 | """ 11 | lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } 12 | k = 1 (Independent) 13 | Var = logs ** 2 14 | """ 15 | c = math.log(2 * math.pi) 16 | return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c) 17 | 18 | 19 | def gaussian_likelihood(mean, logs, x): 20 | p = gaussian_p(mean, logs, x) 21 | return torch.sum(p, dim=[1, 2, 3]) 22 | 23 | 24 | def gaussian_sample(mean, logs, temperature=1): 25 | # Sample from Gaussian with temperature 26 | z = torch.normal(mean, torch.exp(logs) * temperature) 27 | 28 | return z 29 | 30 | 31 | def squeeze2d(input, factor): 32 | if factor == 1: 33 | return input 34 | 35 | B, C, H, W = input.size() 36 | 37 | assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" 38 | 39 | x = input.view(B, C, H // factor, factor, W // factor, factor) 40 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 41 | x = x.view(B, C * factor * factor, H // factor, W // factor) 42 | 43 | return x 44 | 45 | 46 | def unsqueeze2d(input, factor): 47 | if factor == 1: 48 | return input 49 | 50 | factor2 = factor ** 2 51 | 52 | B, C, H, W = input.size() 53 | 54 | assert C % (factor2) == 0, "C module factor squared is not 0" 55 | 56 | x = input.view(B, C // factor2, factor, factor, H, W) 57 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 58 | x = x.view(B, C // (factor2), H * factor, W * factor) 59 | 60 | return x 61 | 62 | 63 | class _ActNorm(nn.Module): 64 | """ 65 | Activation Normalization 66 | Initialize the bias and scale with a given minibatch, 67 | so that the output per-channel have zero mean and unit variance for that. 68 | 69 | After initialization, `bias` and `logs` will be trained as parameters. 
70 | """ 71 | 72 | def __init__(self, num_features, scale=1.0): 73 | super().__init__() 74 | # register mean and scale 75 | size = [1, num_features, 1, 1] 76 | self.bias = nn.Parameter(torch.zeros(*size)) 77 | self.logs = nn.Parameter(torch.zeros(*size)) 78 | self.num_features = num_features 79 | self.scale = scale 80 | self.inited = False 81 | 82 | def initialize_parameters(self, input): 83 | if not self.training: 84 | raise ValueError("In Eval mode, but ActNorm not inited") 85 | 86 | with torch.no_grad(): 87 | bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) 88 | vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) 89 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) 90 | 91 | self.bias.data.copy_(bias.data) 92 | self.logs.data.copy_(logs.data) 93 | 94 | self.inited = True 95 | 96 | def _center(self, input, reverse=False): 97 | if reverse: 98 | return input - self.bias 99 | else: 100 | return input + self.bias 101 | 102 | def _scale(self, input, logdet=None, reverse=False): 103 | 104 | if reverse: 105 | input = input * torch.exp(-self.logs) 106 | else: 107 | input = input * torch.exp(self.logs) 108 | 109 | if logdet is not None: 110 | """ 111 | logs is log_std of `mean of channels` 112 | so we need to multiply by number of pixels 113 | """ 114 | b, c, h, w = input.shape 115 | 116 | dlogdet = torch.sum(self.logs) * h * w 117 | 118 | if reverse: 119 | dlogdet *= -1 120 | 121 | logdet = logdet + dlogdet 122 | 123 | return input, logdet 124 | 125 | def forward(self, input, logdet=None, reverse=False): 126 | self._check_input_dim(input) 127 | 128 | if not self.inited: 129 | self.initialize_parameters(input) 130 | 131 | if reverse: 132 | input, logdet = self._scale(input, logdet, reverse) 133 | input = self._center(input, reverse) 134 | else: 135 | input = self._center(input, reverse) 136 | input, logdet = self._scale(input, logdet, reverse) 137 | 138 | return input, logdet 139 | 140 | 141 | class ActNorm2d(_ActNorm): 142 | def __init__(self, num_features, scale=1.0): 143 | super().__init__(num_features, scale) 144 | 145 | def _check_input_dim(self, input): 146 | assert len(input.size()) == 4 147 | assert input.size(1) == self.num_features, ( 148 | "[ActNorm]: input should be in shape as `BCHW`," 149 | " channels should be {} rather than {}".format( 150 | self.num_features, input.size() 151 | ) 152 | ) 153 | 154 | 155 | class LinearZeros(nn.Module): 156 | def __init__(self, in_channels, out_channels, logscale_factor=3): 157 | super().__init__() 158 | 159 | self.linear = nn.Linear(in_channels, out_channels) 160 | self.linear.weight.data.zero_() 161 | self.linear.bias.data.zero_() 162 | 163 | self.logscale_factor = logscale_factor 164 | 165 | self.logs = nn.Parameter(torch.zeros(out_channels)) 166 | 167 | def forward(self, input): 168 | output = self.linear(input) 169 | return output * torch.exp(self.logs * self.logscale_factor) 170 | 171 | 172 | class Conv2d(nn.Module): 173 | def __init__( 174 | self, 175 | in_channels, 176 | out_channels, 177 | kernel_size=(3, 3), 178 | stride=(1, 1), 179 | padding="same", 180 | do_actnorm=True, 181 | weight_std=0.05, 182 | ): 183 | super().__init__() 184 | 185 | if padding == "same": 186 | padding = compute_same_pad(kernel_size, stride) 187 | elif padding == "valid": 188 | padding = 0 189 | 190 | self.conv = nn.Conv2d( 191 | in_channels, 192 | out_channels, 193 | kernel_size, 194 | stride, 195 | padding, 196 | bias=(not do_actnorm), 197 | ) 198 | 199 | # init weight with std 200 | 
self.conv.weight.data.normal_(mean=0.0, std=weight_std) 201 | 202 | if not do_actnorm: 203 | self.conv.bias.data.zero_() 204 | else: 205 | self.actnorm = ActNorm2d(out_channels) 206 | 207 | self.do_actnorm = do_actnorm 208 | 209 | def forward(self, input): 210 | x = self.conv(input) 211 | if self.do_actnorm: 212 | x, _ = self.actnorm(x) 213 | return x 214 | 215 | 216 | class Conv2dZeros(nn.Module): 217 | def __init__( 218 | self, 219 | in_channels, 220 | out_channels, 221 | kernel_size=(3, 3), 222 | stride=(1, 1), 223 | padding="same", 224 | logscale_factor=3, 225 | ): 226 | super().__init__() 227 | 228 | if padding == "same": 229 | padding = compute_same_pad(kernel_size, stride) 230 | elif padding == "valid": 231 | padding = 0 232 | 233 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 234 | 235 | self.conv.weight.data.zero_() 236 | self.conv.bias.data.zero_() 237 | 238 | self.logscale_factor = logscale_factor 239 | self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1)) 240 | 241 | def forward(self, input): 242 | output = self.conv(input) 243 | return output * torch.exp(self.logs * self.logscale_factor) 244 | 245 | 246 | class Permute2d(nn.Module): 247 | def __init__(self, num_channels, shuffle): 248 | super().__init__() 249 | self.num_channels = num_channels 250 | self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long) 251 | self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long) 252 | 253 | for i in range(self.num_channels): 254 | self.indices_inverse[self.indices[i]] = i 255 | 256 | if shuffle: 257 | self.reset_indices() 258 | 259 | def reset_indices(self): 260 | shuffle_idx = torch.randperm(self.indices.shape[0]) 261 | self.indices = self.indices[shuffle_idx] 262 | 263 | for i in range(self.num_channels): 264 | self.indices_inverse[self.indices[i]] = i 265 | 266 | def forward(self, input, reverse=False): 267 | assert len(input.size()) == 4 268 | 269 | if not reverse: 270 | input = input[:, self.indices, :, :] 271 | return input 272 | else: 273 | return input[:, self.indices_inverse, :, :] 274 | 275 | 276 | class Split2d(nn.Module): 277 | def __init__(self, num_channels): 278 | super().__init__() 279 | self.conv = Conv2dZeros(num_channels // 2, num_channels) 280 | 281 | def split2d_prior(self, z): 282 | h = self.conv(z) 283 | return split_feature(h, "cross") 284 | 285 | def forward(self, input, logdet=0.0, reverse=False, temperature=None): 286 | if reverse: 287 | z1 = input 288 | mean, logs = self.split2d_prior(z1) 289 | z2 = gaussian_sample(mean, logs, temperature) 290 | z = torch.cat((z1, z2), dim=1) 291 | return z, logdet 292 | else: 293 | z1, z2 = split_feature(input, "split") 294 | mean, logs = self.split2d_prior(z1) 295 | logdet = gaussian_likelihood(mean, logs, z2) + logdet 296 | return z1, logdet 297 | 298 | 299 | class SqueezeLayer(nn.Module): 300 | def __init__(self, factor): 301 | super().__init__() 302 | self.factor = factor 303 | 304 | def forward(self, input, logdet=None, reverse=False): 305 | if reverse: 306 | output = unsqueeze2d(input, self.factor) 307 | else: 308 | output = squeeze2d(input, self.factor) 309 | 310 | return output, logdet 311 | 312 | 313 | class InvertibleConv1x1(nn.Module): 314 | def __init__(self, num_channels, LU_decomposed): 315 | super().__init__() 316 | w_shape = [num_channels, num_channels] 317 | w_init = torch.qr(torch.randn(*w_shape))[0] 318 | 319 | if not LU_decomposed: 320 | self.weight = nn.Parameter(torch.Tensor(w_init)) 321 | else: 322 | p, lower, upper = 
torch.lu_unpack(*torch.lu(w_init)) 323 | s = torch.diag(upper) 324 | sign_s = torch.sign(s) 325 | log_s = torch.log(torch.abs(s)) 326 | upper = torch.triu(upper, 1) 327 | l_mask = torch.tril(torch.ones(w_shape), -1) 328 | eye = torch.eye(*w_shape) 329 | 330 | self.register_buffer("p", p) 331 | self.register_buffer("sign_s", sign_s) 332 | self.lower = nn.Parameter(lower) 333 | self.log_s = nn.Parameter(log_s) 334 | self.upper = nn.Parameter(upper) 335 | self.l_mask = l_mask 336 | self.eye = eye 337 | 338 | self.w_shape = w_shape 339 | self.LU_decomposed = LU_decomposed 340 | 341 | def get_weight(self, input, reverse): 342 | b, c, h, w = input.shape 343 | 344 | if not self.LU_decomposed: 345 | dlogdet = torch.slogdet(self.weight)[1] * h * w 346 | if reverse: 347 | weight = torch.inverse(self.weight) 348 | else: 349 | weight = self.weight 350 | else: 351 | self.l_mask = self.l_mask.to(input.device) 352 | self.eye = self.eye.to(input.device) 353 | 354 | lower = self.lower * self.l_mask + self.eye 355 | 356 | u = self.upper * self.l_mask.transpose(0, 1).contiguous() 357 | u += torch.diag(self.sign_s * torch.exp(self.log_s)) 358 | 359 | dlogdet = torch.sum(self.log_s) * h * w 360 | 361 | if reverse: 362 | u_inv = torch.inverse(u) 363 | l_inv = torch.inverse(lower) 364 | p_inv = torch.inverse(self.p) 365 | 366 | weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) 367 | else: 368 | weight = torch.matmul(self.p, torch.matmul(lower, u)) 369 | 370 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 371 | 372 | def forward(self, input, logdet=None, reverse=False): 373 | """ 374 | log-det = log|abs(|W|)| * pixels 375 | """ 376 | weight, dlogdet = self.get_weight(input, reverse) 377 | 378 | if not reverse: 379 | z = F.conv2d(input, weight) 380 | if logdet is not None: 381 | logdet = logdet + dlogdet 382 | return z, logdet 383 | else: 384 | z = F.conv2d(input, weight) 385 | if logdet is not None: 386 | logdet = logdet - dlogdet 387 | return z, logdet 388 | -------------------------------------------------------------------------------- /models/utils/Inv_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def compute_same_pad(kernel_size, stride): 5 | if isinstance(kernel_size, int): 6 | kernel_size = [kernel_size] 7 | 8 | if isinstance(stride, int): 9 | stride = [stride] 10 | 11 | assert len(stride) == len( 12 | kernel_size 13 | ), "Pass kernel size and stride both as int, or both as equal length iterable" 14 | 15 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] 16 | 17 | 18 | def uniform_binning_correction(x, n_bits=8): 19 | """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). 20 | 21 | Args: 22 | x: 4-D Tensor of shape (NCHW) 23 | n_bits: optional. 24 | Returns: 25 | x: x ~ U(x, x + 1.0 / 256) 26 | objective: Equivalent to -q(x)*log(q(x)). 27 | """ 28 | b, c, h, w = x.size() 29 | n_bins = 2 ** n_bits 30 | chw = c * h * w 31 | x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) 32 | 33 | objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) 34 | return x, objective 35 | 36 | 37 | def split_feature(tensor, type="split"): 38 | """ 39 | type = ["split", "cross"] 40 | """ 41 | C = tensor.size(1) 42 | if type == "split": 43 | # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] 44 | return tensor[:, :1, ...], tensor[:,1:, ...] 45 | elif type == "cross": 46 | # return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 47 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
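

# Quick sanity-check sketch of the helpers above (illustrative only):
# "split" separates the first channel from the rest, "cross" interleaves even-
# and odd-indexed channels, and compute_same_pad mimics "same" padding.
if __name__ == "__main__":
    x = torch.arange(2 * 4 * 2 * 2, dtype=torch.float32).view(2, 4, 2, 2)
    a, b = split_feature(x, "split")
    print(a.shape, b.shape)        # torch.Size([2, 1, 2, 2]) torch.Size([2, 3, 2, 2])
    c, d = split_feature(x, "cross")
    print(c.shape, d.shape)        # torch.Size([2, 2, 2, 2]) torch.Size([2, 2, 2, 2])
    print(compute_same_pad(3, 1))  # [1]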
48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /models/utils/__pycache__/CDC.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Mutual-Information-driven-Pan-sharpening/5915c65f3f24ddfcf5d528d80d7b35b7f2ae9de3/models/utils/__pycache__/CDC.cpython-38.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/Inv_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Mutual-Information-driven-Pan-sharpening/5915c65f3f24ddfcf5d528d80d7b35b7f2ae9de3/models/utils/__pycache__/Inv_modules.cpython-38.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/Inv_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Mutual-Information-driven-Pan-sharpening/5915c65f3f24ddfcf5d528d80d7b35b7f2ae9de3/models/utils/__pycache__/Inv_utils.cpython-38.pyc -------------------------------------------------------------------------------- /training/train_GPPNN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | ------------------------------------------------------------------------------ 5 | Import packages 6 | ------------------------------------------------------------------------------ 7 | ''' 8 | 9 | import os 10 | # import xlwt 11 | import time 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from tensorboardX import SummaryWriter 19 | from torch.utils.data import DataLoader 20 | from scipy.io import savemat 21 | import cv2 22 | import sys 23 | 24 | sys.path.append("/home/jieh/Projects/PAN_Sharp/PansharpingMul/GPPNN/") 25 | from models import get_sat_param 26 | from models.GPPNN2 import GPPNN as GPPNN, Mutual_info_reg 27 | from metrics import get_metrics_reduced 28 | from utils import PSH5Datasetfu, PSDataset, prepare_data, normlization, save_param, psnr_loss, ssim, save_img 29 | from data import Data 30 | 31 | ''' 32 | ------------------------------------------------------------------------------ 33 | Configure our network 34 | ------------------------------------------------------------------------------ 35 | ''' 36 | 37 | model_str = 'GPPNN2' 38 | satellite_str = 'WV2' 39 | 40 | # . Get the parameters of your satellite 41 | # sat_param = get_sat_param(satellite_str) 42 | # if sat_param!=None: 43 | # ms_channels, pan_channels, scale = sat_param 44 | # else: 45 | # print('You should specify `ms_channels`, `pan_channels` and `scale`! ') 46 | ms_channels = 4 47 | pan_channels = 1 48 | scale = 4 49 | 50 | # . Set the hyper-parameters for training 51 | num_epochs = 1000 52 | lr = 1e-3 53 | weight_decay = 0 54 | batch_size = 4 55 | n_layer = 8 56 | n_feat = 8 57 | 58 | # . 
Get your model 59 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 60 | torch.cuda.set_device(0) 61 | net = GPPNN(ms_channels,pan_channels,n_feat).to(device) 62 | print(net) 63 | 64 | mutual = Mutual_info_reg(n_feat//2,n_feat//2).to(device) 65 | 66 | predir = '/home/jieh/Projects/PAN_Sharp/PansharpingMul/GPPNN/training/pretrain4' 67 | if os.path.exists(predir): 68 | net.load_state_dict(torch.load(os.path.join(predir, 'best_net.pth'))['net'],strict=False) 69 | print("load pretrained model successfully") 70 | 71 | # . Get your optimizer, scheduler and loss function 72 | optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay) 73 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5) 74 | loss_fn = nn.L1Loss().to(device) 75 | 76 | # . Create your data loaders 77 | # prepare_data_flag = False # set it to False, if you have prepared dataset 78 | # train_path = '../PS_data/%s/%s_train.h5'%(satellite_str,satellite_str) 79 | # train_path = '/home/jieh/Projects/PAN_Sharp/GPPNN-main/PS_data/data/%s/train.mat'%(satellite_str) 80 | # #validation_path = '../PS_data/%s/validation'%(satellite_str) 81 | # # validation_path = '/home/jieh/Projects/PAN_Sharp/GPPNN-main/PS_data/data/%s/test.mat'%(satellite_str) 82 | # test_path = '/home/jieh/Projects/PAN_Sharp/GPPNN-main/PS_data/data/%s/test.mat'%(satellite_str) 83 | 84 | 85 | # if prepare_data_flag is True: 86 | # prepare_data(data_path = '../PS_data/%s'%(satellite_str), 87 | # patch_size=32, aug_times=1, stride=32, synthetic=False, scale=scale, 88 | # file_name = train_path) 89 | data_dir_ms_train = '/home/jieh/Projects/PAN_Sharp/yaogan/WV2_data/train128/ms/' 90 | data_dir_pan_train = '/home/jieh/Projects/PAN_Sharp/yaogan/WV2_data/train128/pan/' 91 | # trainloader = DataLoader(PSH5Datasetfu(train_path), 92 | # batch_size=batch_size, 93 | # shuffle=True) #[N,C,K,H,W] 94 | trainloader = DataLoader(Data(data_dir_ms=data_dir_ms_train, data_dir_pan=data_dir_pan_train), 95 | batch_size=batch_size, 96 | shuffle=True) 97 | 98 | # validationloader = DataLoader(PSH5Datasetfu(validation_path), 99 | # batch_size=1) #[N,C,K,H,W] 100 | data_dir_ms_test = '/home/jieh/Projects/PAN_Sharp/yaogan/WV2_data/test128/ms/' 101 | data_dir_pan_test = '/home/jieh/Projects/PAN_Sharp/yaogan/WV2_data/test128/pan/' 102 | testloader = DataLoader(Data(data_dir_ms=data_dir_ms_test, data_dir_pan=data_dir_pan_test), 103 | batch_size=1) # [N,C,K,H,W] 104 | 105 | # validationloader = DataLoader(PSDataset(validation_path,scale), 106 | # batch_size=1) 107 | # testloader = DataLoader(PSDataset(test_path, scale), 108 | # batch_size=1) 109 | 110 | loader = {'train': trainloader, 111 | 'validation': testloader} 112 | 113 | # . 
Creat logger 114 | timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M") 115 | save_path = os.path.join( 116 | '/home/jieh/Projects/PAN_Sharp/PansharpingMul/GPPNN/training/logs/%s' % (model_str), 117 | timestamp + '_%s_layer%d_filter_%d' % (satellite_str, n_layer, n_feat) 118 | ) 119 | writer = SummaryWriter(save_path) 120 | params = {'model': model_str, 121 | 'satellite': satellite_str, 122 | 'epoch': num_epochs, 123 | 'lr': lr, 124 | 'batch_size': batch_size, 125 | 'n_feat': n_feat, 126 | 'n_layer': n_layer} 127 | save_param(params, 128 | os.path.join(save_path, 'param.json')) 129 | 130 | ''' 131 | ------------------------------------------------------------------------------ 132 | Train 133 | ------------------------------------------------------------------------------ 134 | ''' 135 | 136 | step = 0 137 | best_psnr_val, psnr_val, ssim_val = 0., 0., 0. 138 | torch.backends.cudnn.benchmark = True 139 | prev_time = time.time() 140 | 141 | def adjust(init, fin, step, fin_step): 142 | if fin_step == 0: 143 | return fin 144 | deta = fin - init 145 | adj = min(init + deta * step / fin_step, fin) 146 | return adj 147 | # 148 | for epoch in range(num_epochs): 149 | ''' train ''' 150 | for i, (ms, pan, gt) in enumerate(loader['train']): 151 | # 0. preprocess data 152 | ms, pan, gt = ms.cuda(), pan.cuda(), gt.cuda() 153 | # ms,_ = normlization(ms.cuda()) 154 | # pan,_ = normlization(pan.cuda()) 155 | # gt,_ = normlization(gt.cuda()) 156 | 157 | # 1. update 158 | net.train() 159 | net.zero_grad() 160 | optimizer.zero_grad() 161 | predHR, panf, mHRf = net(ms, pan) 162 | loss_forward = loss_fn(predHR, gt.detach())*300 163 | latentloss = torch.clip(mutual(panf,mHRf),-1,1) 164 | loss = loss_forward + latentloss * 0.2 * adjust(0, 1, epoch, num_epochs) 165 | loss.backward() 166 | optimizer.step() 167 | 168 | # 2. print 169 | # Determine approximate time left 170 | batches_done = epoch * len(loader['train']) + i 171 | batches_left = num_epochs * len(loader['train']) - batches_done 172 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 173 | prev_time = time.time() 174 | sys.stdout.write( 175 | "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] [forward loss: %f] [PSNR/Best: %.4f/%.4f] ETA: %s" 176 | % ( 177 | epoch, 178 | num_epochs, 179 | i, 180 | len(loader['train']), 181 | loss.item(), 182 | loss_forward.item(), 183 | # loss_backward.item(), 184 | psnr_val, 185 | best_psnr_val, 186 | time_left, 187 | ) 188 | ) 189 | 190 | # 3. Log the scalar values 191 | writer.add_scalar('loss', loss.item(), step) 192 | writer.add_scalar('learning rate', optimizer.state_dict()['param_groups'][0]['lr'], step) 193 | step += 1 194 | 195 | ''' validation ''' 196 | current_psnr_val = psnr_val 197 | # psnr_val = 0. 198 | # ssim_val = 0. 199 | # with torch.no_grad(): 200 | # net.eval() 201 | # for i, (ms, pan, gt) in enumerate(loader['validation']): 202 | # ms, pan, gt = ms.cuda(), pan.cuda(), gt.cuda() 203 | # # ms,_ = normlization(ms.cuda()) 204 | # # pan,_ = normlization(pan.cuda()) 205 | # # gt,_ = normlization(gt.cuda()) 206 | # imgf = net(ms, pan) 207 | # psnr_val += psnr_loss(imgf, gt, 1.) 208 | # ssim_val += ssim(imgf, gt, 11, 'mean', 1.) 209 | # psnr_val = float(psnr_val/loader['validation'].__len__()) 210 | # ssim_val = float(ssim_val/loader['validation'].__len__()) 211 | # writer.add_scalar('PSNR/val', psnr_val, epoch) 212 | # writer.add_scalar('SSIM/val', ssim_val, epoch) 213 | 214 | psnr_val = 0. 215 | ssim_val = 0. 
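    # Per-epoch evaluation: average PSNR/SSIM of the current model over the test
    # loader and log both curves to TensorBoard (the commented block above did the
    # same on a separate validation split).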
216 | metrics = torch.zeros(2, testloader.__len__()) 217 | with torch.no_grad(): 218 | net.eval() 219 | for i, (ms, pan, gt) in enumerate(testloader): 220 | ms, pan, gt = ms.cuda(), pan.cuda(), gt.cuda() 221 | # ms,_ = normlization(ms.cuda()) 222 | # pan,_ = normlization(pan.cuda()) 223 | # gt,_ = normlization(gt.cuda()) 224 | predHR, _,_ = net(ms, pan) 225 | metrics[:, i] = torch.Tensor(get_metrics_reduced(predHR, gt))[:2] 226 | psnr_val, ssim_val = metrics.mean(dim=1) 227 | writer.add_scalar('PSNR/test', psnr_val, epoch) 228 | writer.add_scalar('SSIM/test', ssim_val, epoch) 229 | 230 | ''' save model ''' 231 | # Save the best weight 232 | if best_psnr_val < psnr_val: 233 | best_psnr_val = psnr_val 234 | torch.save({'net': net.state_dict(), 235 | 'optimizer': optimizer.state_dict(), 236 | 'epoch': epoch}, 237 | os.path.join(save_path, 'best_net.pth')) 238 | # Save the current weight 239 | torch.save({'net': net.state_dict(), 240 | 'optimizer': optimizer.state_dict(), 241 | 'epoch': epoch}, 242 | os.path.join(save_path, 'last_net.pth')) 243 | 244 | ''' backtracking ''' 245 | if epoch > 0: 246 | if torch.isnan(loss): 247 | print(10 * '=' + 'Backtracking!' + 10 * '=') 248 | net.load_state_dict(torch.load(os.path.join(save_path, 'best_net.pth'))['net']) 249 | optimizer.load_state_dict(torch.load(os.path.join(save_path, 'best_net.pth'))['optimizer']) 250 | # 251 | # ''' 252 | # ------------------------------------------------------------------------------ 253 | # Test 254 | # ------------------------------------------------------------------------------ 255 | # ''' 256 | 257 | # 1. Load the best weight and create the dataloader for testing 258 | net.load_state_dict(torch.load(os.path.join(save_path,'best_net.pth'))['net']) 259 | 260 | # 2. Compute the metrics 261 | metrics = torch.zeros(5, testloader.__len__()) 262 | with torch.no_grad(): 263 | net.eval() 264 | for i, (ms, pan, gt) in enumerate(testloader): 265 | ms, pan, gt = ms.cuda(), pan.cuda(), gt.cuda() 266 | # ms,_ = normlization(ms.cuda()) 267 | # pan,_ = normlization(pan.cuda()) 268 | # gt,_ = normlization(gt.cuda()) 269 | predHR, _, _ = net(ms, i, pan) 270 | metrics[:, i] = torch.Tensor(get_metrics_reduced(predHR, gt)) 271 | # savemat(os.path.join(save_path, str(i)), 272 | # {'HR': imgf.squeeze().detach().cpu().numpy()}) 273 | # save_img(save_path, str(i)+'.tiff',predHR) 274 | 275 | list_PSNR = [] 276 | list_SSIM = [] 277 | list_CC = [] 278 | list_SAM = [] 279 | list_ERGAS = [] 280 | for n in range(testloader.__len__()): 281 | list_PSNR.append(metrics[0, n]) 282 | list_SSIM.append(metrics[1, n]) 283 | list_CC.append(metrics[2, n]) 284 | list_SAM.append(metrics[3, n]) 285 | list_ERGAS.append(metrics[4, n]) 286 | 287 | print("list_psnr_mean:", np.mean(list_PSNR)) 288 | print("list_ssim_mean:", np.mean(list_SSIM)) 289 | print("list_cc_mean:", np.mean(list_CC)) 290 | print("list_sam_mean:", np.mean(list_SAM)) 291 | print("list_ergas_mean:", np.mean(list_ERGAS)) 292 | 293 | # 3. 
Write the metrics 294 | # f = xlwt.Workbook() 295 | # sheet1 = f.add_sheet(u'sheet1',cell_overwrite_ok=True) 296 | # img_name = [i.split('\\')[-1].replace('.mat','') for i in testloader.dataset.files] 297 | # metric_name = ['PSNR','SSIM','CC','SAM','ERGAS'] 298 | # for i in range(len(metric_name)): 299 | # sheet1.write(i+1,0,metric_name[i]) 300 | # for j in range(len(img_name)): 301 | # sheet1.write(0,j+1,img_name[j]) 302 | # for i in range(len(metric_name)): 303 | # for j in range(len(img_name)): 304 | # sheet1.write(i+1,j+1,float(metrics[i,j])) 305 | # sheet1.write(0,len(img_name)+1,'Mean') 306 | # for i in range(len(metric_name)): 307 | # sheet1.write(i+1,len(img_name)+1,float(metrics.mean(1)[i])) 308 | # f.save(os.path.join(save_path,'test_result.xls')) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # ---------------------------------------------------------------------------- 4 | # Misc 5 | # ---------------------------------------------------------------------------- 6 | 7 | import os 8 | import json 9 | 10 | def mkdir(path): 11 | if os.path.exists(path) is False: 12 | os.makedirs(path) 13 | 14 | def save_param(input_dict, path): 15 | f = open(path, 'w') 16 | f.write(json.dumps(input_dict)) 17 | f.close() 18 | print("Hyper-Parameters have been saved!") 19 | 20 | 21 | # ---------------------------------------------------------------------------- 22 | # Dataset & Image Processing 23 | # ---------------------------------------------------------------------------- 24 | 25 | import os 26 | import h5py 27 | import torch 28 | 29 | from glob import glob 30 | import numpy as np 31 | import torch.utils.data as Data 32 | 33 | from scipy.io import loadmat 34 | 35 | 36 | def normlization(x): 37 | # x [N,C,H,W] 38 | N,C,H,W = x.shape 39 | m = [] 40 | for i in range(N): 41 | m.append(torch.max(x[i,:,:,:])) 42 | m = torch.stack(m, dim=0)[:,None,None,None] 43 | m = m+1e-10 44 | x = x/m 45 | return x,m 46 | 47 | def inverse_normlization(x, m): 48 | return x*m 49 | 50 | def im2double(img): 51 | if img.dtype=='uint8': 52 | img = img.astype(np.float32)/255. 53 | elif img.dtype=='uint16': 54 | img = img.astype(np.float32)/65535. 
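        # 65535 = 2**16 - 1, so uint16 inputs, like uint8 ones, are mapped into [0, 1].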
55 | else: 56 | img = img.astype(np.float32) 57 | return img 58 | 59 | def Im2Patch(img, win, stride=1): 60 | k = 0 61 | endc = img.shape[0] 62 | endw = img.shape[1] 63 | endh = img.shape[2] 64 | patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride] 65 | TotalPatNum = patch.shape[1] * patch.shape[2] 66 | Y = np.zeros([endc, win*win,TotalPatNum], np.float32) 67 | for i in range(win): 68 | for j in range(win): 69 | patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride] 70 | Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum) 71 | k = k + 1 72 | return Y.reshape([endc, win, win, TotalPatNum]) 73 | 74 | def imresize(img, size=None, scale_factor=None): 75 | # img (np.array) - [C,H,W] 76 | imgT = torch.from_numpy(img).unsqueeze(0) #[1,C,H,W] 77 | if size is None and scale_factor is not None: 78 | imgT = torch.nn.functional.interpolate(imgT, scale_factor=scale_factor) 79 | elif size is not None and scale_factor is None: 80 | imgT = torch.nn.functional.interpolate(imgT, size=size) 81 | else: 82 | print('Neither size nor scale_factor is given.') 83 | imgT = imgT.squeeze(0).numpy() 84 | return imgT 85 | 86 | 87 | def prepare_data(data_path, 88 | patch_size, 89 | aug_times=4, 90 | stride=25, 91 | synthetic=True, 92 | scale=2, 93 | file_name='train.h5' 94 | ): 95 | # patch_size : the window size of low-resolution images 96 | # scale : the spatial ratio between low-resolution and guide images 97 | # train 98 | print('process training data') 99 | files = glob(os.path.join(data_path, 'train', '*.mat')) 100 | h5f = h5py.File(file_name, 'w') 101 | h5gt = h5f.create_group('GT') 102 | h5guide = h5f.create_group('PAN') 103 | h5lr = h5f.create_group('MS') 104 | train_num = 0 105 | for i in range(len(files)): 106 | img = loadmat(files[i]) 107 | lr = img['I_MS'].astype('float32') # [Height, Width, Channels] 108 | guide = img['I_PAN'].astype('float32') # [Height, Width] 109 | # print([lr.shape,guide.shape]) 110 | lr = np.transpose(lr, [2,0,1]) # [Channels, Height, Width] 111 | guide = guide[None,:,:] # [1, Height, Width] 112 | 113 | if synthetic: 114 | # if synthetic is True: the spatial resolutions of lr and guide are the same 115 | lr_patches = Im2Patch(lr, win=scale*patch_size, stride=stride) #[C,H,W,N] 116 | guide_patches = Im2Patch(guide, win=scale*patch_size, stride=stride) 117 | else: 118 | scale = int(guide.shape[-1]/lr.shape[-1]) 119 | # print(scale) 120 | guide = imresize(guide, size=lr.shape[1:]) 121 | lr_patches = Im2Patch(lr, win=scale*patch_size, stride=stride) #[C,H,W,N] 122 | guide_patches = Im2Patch(guide, win=scale*patch_size, stride=stride) 123 | 124 | print("file: %s # samples: %d" % (files[i], lr_patches.shape[3]*aug_times)) 125 | for n in range(lr_patches.shape[3]): 126 | gt_data = lr_patches[:,:,:,n].copy() 127 | guide_data = guide_patches[:,:,:,n].copy() 128 | lr_data = imresize(gt_data, scale_factor=1/scale) 129 | 130 | h5gt.create_dataset(str(train_num), 131 | data=gt_data, dtype=gt_data.dtype,shape=gt_data.shape) 132 | h5guide.create_dataset(str(train_num), 133 | data=guide_data, dtype=guide_data.dtype,shape=guide_data.shape) 134 | h5lr.create_dataset(str(train_num), 135 | data=lr_data, dtype=lr_data.dtype,shape=lr_data.shape) 136 | train_num += 1 137 | for m in range(aug_times-1): 138 | gt_data_aug = np.rot90(gt_data, m+1, axes=(1,2)) 139 | guide_data_aug = np.rot90(guide_data, m+1, axes=(1,2)) 140 | lr_data_aug = np.rot90(lr_data, m+1, axes=(1,2)) 141 | 142 | h5gt.create_dataset(str(train_num)+"_aug_%d" % (m+1), 143 | data=gt_data_aug, 
dtype=gt_data_aug.dtype,shape=gt_data_aug.shape) 144 | h5guide.create_dataset(str(train_num)+"_aug_%d" % (m+1), 145 | data=guide_data_aug, dtype=guide_data_aug.dtype,shape=guide_data_aug.shape) 146 | h5lr.create_dataset(str(train_num)+"_aug_%d" % (m+1), 147 | data=lr_data_aug, dtype=lr_data_aug.dtype,shape=lr_data_aug.shape) 148 | train_num += 1 149 | h5f.close() 150 | print('training set, # samples %d\n' % train_num) 151 | 152 | class CaveH5Dataset(Data.Dataset): 153 | def __init__(self, h5file_path): 154 | self.h5file_path = h5file_path 155 | h5f = h5py.File(h5file_path, 'r') 156 | self.keys = list(h5f['Guide'].keys()) 157 | h5f.close() 158 | 159 | def __len__(self): 160 | return len(self.keys) 161 | 162 | def __getitem__(self, index): 163 | h5f = h5py.File(self.h5file_path, 'r') 164 | key = self.keys[index] 165 | guide = np.array(h5f['Guide'][key]) 166 | gt = np.array(h5f['GT'][key]) 167 | lr = np.array(h5f['LR'][key]) 168 | h5f.close() 169 | return torch.Tensor(lr),torch.Tensor(guide),torch.Tensor(gt) 170 | 171 | class CaveDataset(Data.Dataset): 172 | def __init__(self, root, scale): 173 | self.scale = scale 174 | self.root = root 175 | self.files = glob(root+'/*.mat') 176 | 177 | def __len__(self): 178 | return len(self.files) 179 | 180 | def __getitem__(self, index): 181 | temp = h5py.File(self.files[index]) 182 | guide = im2double(temp['Guide'][:]) 183 | gt = im2double(temp['LR'][:]) 184 | lr = imresize(gt, scale_factor=1/self.scale) 185 | del temp 186 | return lr, guide, gt 187 | 188 | class PSDataset(Data.Dataset): 189 | def __init__(self, root, scale): 190 | self.scale = scale 191 | self.root = root 192 | self.files = glob(root+'/*.mat') 193 | 194 | def __len__(self): 195 | return len(self.files) 196 | 197 | def __getitem__(self, index): 198 | temp = loadmat(self.files[index]) 199 | gt = np.transpose(temp['I_MS'].astype('float32'), [2,0,1]) 200 | pan = temp['I_PAN'].astype('float32')[None,:,:] 201 | ms = imresize(gt, scale_factor=1/self.scale) 202 | pan = imresize(pan, scale_factor=1/self.scale) 203 | del temp 204 | return ms, pan, gt 205 | 206 | class PSH5Dataset(Data.Dataset): 207 | def __init__(self, h5file_path): 208 | self.h5file_path = h5file_path 209 | h5f = h5py.File(h5file_path, 'r') 210 | self.keys = list(h5f['PAN'].keys()) 211 | h5f.close() 212 | 213 | def __len__(self): 214 | return len(self.keys) 215 | 216 | def __getitem__(self, index): 217 | h5f = h5py.File(self.h5file_path, 'r') 218 | key = self.keys[index] 219 | pan = np.array(h5f['PAN'][key]) 220 | gt = np.array(h5f['GT'][key]) 221 | ms = np.array(h5f['MS'][key]) 222 | h5f.close() 223 | return torch.Tensor(ms),torch.Tensor(pan),torch.Tensor(gt) 224 | 225 | 226 | class PSH5Datasetfu(Data.Dataset): 227 | def __init__(self, h5file_path): 228 | self.h5file_path = h5file_path 229 | h5f = h5py.File(self.h5file_path, 'r') 230 | gt = h5f['gt'][...] ## ground truth N*H*W*C 231 | pan = h5f['pan'][...] #### Pan image N*H*W 232 | ms = h5f['ms'][...] ### low resolution MS image 233 | self.N = gt.shape[0] 234 | 235 | self.gt = np.array(gt, dtype=np.float32) / 2047. ### normalization, WorldView L = 11 236 | self.pan = np.array(pan, dtype=np.float32) / 2047. 237 | self.ms = np.array(ms, dtype=np.float32) / 2047. 
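        # 2047 = 2**11 - 1: WorldView products are 11-bit, so GT/PAN/MS all land in [0, 1].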
238 | 239 | h5f.close() 240 | 241 | def __len__(self): 242 | return self.N 243 | 244 | def __getitem__(self, index): 245 | 246 | train_gt = self.gt[index, :, :, :] 247 | #train_gt = train_gt[np.newaxis,:,:,:] 248 | train_gt = np.transpose(train_gt, (2, 0, 1)) 249 | 250 | train_pan = self.pan[index, :, :] 251 | 252 | train_pan = train_pan[:, :, np.newaxis] # expand to N*H*W*1; new added! 253 | train_pan = np.transpose(train_pan, (2, 0, 1)) 254 | 255 | train_ms = self.ms[index, :, :, :] 256 | #train_ms = train_ms[np.newaxis, :, :, :] 257 | train_ms = np.transpose(train_ms, (2, 0, 1)) 258 | 259 | 260 | return torch.Tensor(train_ms), torch.Tensor(train_pan), torch.Tensor(train_gt) 261 | 262 | 263 | # ---------------------------------------------------------------------------- 264 | # Attention 265 | # ---------------------------------------------------------------------------- 266 | import torch 267 | import math 268 | import torch.nn as nn 269 | import torch.nn.functional as F 270 | 271 | class BasicConv(nn.Module): 272 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False): 273 | super(BasicConv, self).__init__() 274 | self.out_channels = out_planes 275 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 276 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 277 | self.relu = nn.ReLU() if relu else None 278 | 279 | def forward(self, x): 280 | x = self.conv(x) 281 | if self.bn is not None: 282 | x = self.bn(x) 283 | if self.relu is not None: 284 | x = self.relu(x) 285 | return x 286 | 287 | class Flatten(nn.Module): 288 | def forward(self, x): 289 | return x.view(x.size(0), -1) 290 | 291 | class ChannelGate(nn.Module): 292 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 293 | super(ChannelGate, self).__init__() 294 | self.gate_channels = gate_channels 295 | self.mlp = nn.Sequential( 296 | Flatten(), 297 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 298 | nn.ReLU(), 299 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 300 | ) 301 | self.pool_types = pool_types 302 | def forward(self, x): 303 | channel_att_sum = None 304 | for pool_type in self.pool_types: 305 | if pool_type=='avg': 306 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 307 | channel_att_raw = self.mlp( avg_pool ) 308 | elif pool_type=='max': 309 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 310 | channel_att_raw = self.mlp( max_pool ) 311 | elif pool_type=='lp': 312 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 313 | channel_att_raw = self.mlp( lp_pool ) 314 | elif pool_type=='lse': 315 | # LSE pool only 316 | lse_pool = logsumexp_2d(x) 317 | channel_att_raw = self.mlp( lse_pool ) 318 | 319 | if channel_att_sum is None: 320 | channel_att_sum = channel_att_raw 321 | else: 322 | channel_att_sum = channel_att_sum + channel_att_raw 323 | 324 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 325 | return x * scale 326 | 327 | def logsumexp_2d(tensor): 328 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 329 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 330 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 331 | return outputs 332 | 333 | class 
ChannelPool(nn.Module): 334 | def forward(self, x): 335 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 336 | 337 | class SpatialGate(nn.Module): 338 | def __init__(self): 339 | super(SpatialGate, self).__init__() 340 | kernel_size = 7 341 | self.compress = ChannelPool() 342 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 343 | def forward(self, x): 344 | x_compress = self.compress(x) 345 | x_out = self.spatial(x_compress) 346 | scale = torch.sigmoid(x_out) # broadcasting 347 | return x * scale 348 | 349 | class CBAM(nn.Module): 350 | def __init__(self, gate_channels, reduction_ratio=2, pool_types=['avg', 'max'], no_spatial=False, no_channel=True): 351 | super(CBAM, self).__init__() 352 | self.no_channel = no_channel 353 | if not no_channel: 354 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 355 | self.no_spatial=no_spatial 356 | if not no_spatial: 357 | self.SpatialGate = SpatialGate() 358 | def forward(self, x): 359 | if not self.no_channel: 360 | x = self.ChannelGate(x) 361 | if not self.no_spatial: 362 | x = self.SpatialGate(x) 363 | return x 364 | 365 | # ---------------------------------------------------------------------------- 366 | # Losses 367 | # ---------------------------------------------------------------------------- 368 | import torch 369 | import torch.nn as nn 370 | 371 | eps = torch.finfo(torch.float32).eps 372 | 373 | # RAP loss 374 | class RAP(nn.Module): 375 | def __init__(self, lap_weight=1, angle_weight=1): 376 | super(RAP, self).__init__() 377 | self.lap_weight = lap_weight 378 | self.angle_weight = angle_weight 379 | 380 | def forward(self, img, gt): 381 | return nn.L1Loss()(img, gt) + self.lap_weight * lap_loss(img,gt) + self.angle_weight * sam(img, gt) 382 | 383 | def lap_loss(img, gt): 384 | img = laplacian(img, 3) 385 | gt = laplacian(gt, 3) 386 | return nn.L1Loss()(img, gt) 387 | 388 | def rmse(img, gt): 389 | """RMSE for (N, C, H, W) image; torch.float32 [0.,1.].""" 390 | N,C,_,_ = img.shape 391 | img = torch.reshape(img, [N,C,-1]) 392 | gt = torch.reshape(gt, [N,C,-1]) 393 | mse = (img-gt).pow(2).sum(dim=-1) 394 | rmse = mse/(gt.pow(2).sum(dim=-1)+eps) 395 | rmse = rmse.mean(dim=-1).sqrt() 396 | return rmse.mean() 397 | 398 | def sam(img1, img2): 399 | """SAM for (N, C, H, W) image; torch.float32 [0.,1.].""" 400 | inner_product = (img1 * img2).sum(dim=1) 401 | img1_spectral_norm = torch.sqrt((img1**2).sum(dim=1)) 402 | img2_spectral_norm = torch.sqrt((img2**2).sum(dim=1)) 403 | # numerical stability 404 | cos_theta = (inner_product / (img1_spectral_norm * img2_spectral_norm + eps)).clamp(min=0, max=1) 405 | cos_theta = cos_theta.reshape(cos_theta.shape[0], -1) 406 | return torch.mean(torch.acos(cos_theta), dim=-1).mean() 407 | 408 | # ---------------------------------------------------------------------------- 409 | # Guided Filter 410 | # ---------------------------------------------------------------------------- 411 | ''' 412 | This code is written by Huikai Wu. 
Original code is available at 413 | https://github.com/wuhuikai/DeepGuidedFilter/tree/master/GuidedFilteringLayer/GuidedFilter_PyTorch 414 | 415 | Please cite: 416 | Fast End-to-End Trainable Guided Filter 417 | Huikai Wu, Shuai Zheng, Junge Zhang, Kaiqi Huang 418 | CVPR 2018 419 | ''' 420 | 421 | import torch 422 | from torch import nn 423 | from torch.nn import functional as F 424 | from torch.autograd import Variable 425 | 426 | def diff_x(input, r): 427 | assert input.dim() == 4 428 | 429 | left = input[:, :, r:2 * r + 1] 430 | middle = input[:, :, 2 * r + 1: ] - input[:, :, :-2 * r - 1] 431 | right = input[:, :, -1: ] - input[:, :, -2 * r - 1: -r - 1] 432 | 433 | output = torch.cat([left, middle, right], dim=2) 434 | 435 | return output 436 | 437 | def diff_y(input, r): 438 | assert input.dim() == 4 439 | 440 | left = input[:, :, :, r:2 * r + 1] 441 | middle = input[:, :, :, 2 * r + 1: ] - input[:, :, :, :-2 * r - 1] 442 | right = input[:, :, :, -1: ] - input[:, :, :, -2 * r - 1: -r - 1] 443 | 444 | output = torch.cat([left, middle, right], dim=3) 445 | 446 | return output 447 | 448 | class BoxFilter(nn.Module): 449 | def __init__(self, r): 450 | super(BoxFilter, self).__init__() 451 | 452 | self.r = r 453 | 454 | def forward(self, x): 455 | assert x.dim() == 4 456 | 457 | return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r) 458 | 459 | class FastGuidedFilter(nn.Module): 460 | def __init__(self, r, eps=1e-8): 461 | super(FastGuidedFilter, self).__init__() 462 | 463 | self.r = r 464 | self.eps = eps 465 | self.boxfilter = BoxFilter(r) 466 | 467 | 468 | def forward(self, lr_x, lr_y, hr_x): 469 | n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size() 470 | n_lry, c_lry, h_lry, w_lry = lr_y.size() 471 | n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size() 472 | 473 | assert n_lrx == n_lry and n_lry == n_hrx 474 | assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry) 475 | assert h_lrx == h_lry and w_lrx == w_lry 476 | assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1 477 | 478 | ## N 479 | N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))) 480 | 481 | ## mean_x 482 | mean_x = self.boxfilter(lr_x) / N 483 | ## mean_y 484 | mean_y = self.boxfilter(lr_y) / N 485 | ## cov_xy 486 | cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y 487 | ## var_x 488 | var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x 489 | 490 | ## A 491 | A = cov_xy / (var_x + self.eps) 492 | ## b 493 | b = mean_y - A * mean_x 494 | 495 | ## mean_A; mean_b 496 | mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True) 497 | mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True) 498 | 499 | return mean_A*hr_x+mean_b 500 | 501 | 502 | class GuidedFilter(nn.Module): 503 | def __init__(self, r, eps=1e-8): 504 | super(GuidedFilter, self).__init__() 505 | 506 | self.r = r 507 | self.eps = eps 508 | self.boxfilter = BoxFilter(r) 509 | 510 | 511 | def forward(self, x, y): 512 | n_x, c_x, h_x, w_x = x.size() 513 | n_y, c_y, h_y, w_y = y.size() 514 | 515 | assert n_x == n_y 516 | assert c_x == 1 or c_x == c_y 517 | assert h_x == h_y and w_x == w_y 518 | assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1 519 | 520 | # N 521 | N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0))) 522 | 523 | # mean_x 524 | mean_x = self.boxfilter(x) / N 525 | # mean_y 526 | mean_y = self.boxfilter(y) / N 527 | # cov_xy 528 | cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y 529 | # var_x 530 | var_x = self.boxfilter(x * x) / N - 
mean_x * mean_x 531 | 532 | # A 533 | A = cov_xy / (var_x + self.eps) 534 | # b 535 | b = mean_y - A * mean_x 536 | 537 | # mean_A; mean_b 538 | mean_A = self.boxfilter(A) / N 539 | mean_b = self.boxfilter(b) / N 540 | 541 | return mean_A * x + mean_b 542 | 543 | 544 | 545 | 546 | 547 | # ---------------------------------------------------------------------------- 548 | # Kornia 549 | # ---------------------------------------------------------------------------- 550 | 551 | import torch 552 | import torch.nn as nn 553 | import torch.nn.functional as F 554 | from torch.nn.functional import mse_loss 555 | from typing import Tuple, List 556 | 557 | def compute_padding(kernel_size: Tuple[int, int]) -> List[int]: 558 | """Computes padding tuple.""" 559 | # 4 ints: (padding_left, padding_right,padding_top,padding_bottom) 560 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 561 | assert len(kernel_size) == 2, kernel_size 562 | computed = [(k - 1) // 2 for k in kernel_size] 563 | return [computed[1], computed[1], computed[0], computed[0]] 564 | 565 | 566 | def filter2D(input: torch.Tensor, kernel: torch.Tensor, 567 | border_type: str = 'reflect', 568 | normalized: bool = False) -> torch.Tensor: 569 | r"""Function that convolves a tensor with a kernel. 570 | 571 | The function applies a given kernel to a tensor. The kernel is applied 572 | independently at each depth channel of the tensor. Before applying the 573 | kernel, the function applies padding according to the specified mode so 574 | that the output remains in the same shape. 575 | 576 | Args: 577 | input (torch.Tensor): the input tensor with shape of 578 | :math:`(B, C, H, W)`. 579 | kernel (torch.Tensor): the kernel to be convolved with the input 580 | tensor. The kernel shape must be :math:`(1, kH, kW)`. 581 | border_type (str): the padding mode to be applied before convolving. 582 | The expected modes are: ``'constant'``, ``'reflect'``, 583 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 584 | normalized (bool): If True, kernel will be L1 normalized. 585 | 586 | Return: 587 | torch.Tensor: the convolved tensor of same size and numbers of channels 588 | as the input. 589 | """ 590 | if not isinstance(input, torch.Tensor): 591 | raise TypeError("Input type is not a torch.Tensor. Got {}" 592 | .format(type(input))) 593 | 594 | if not isinstance(kernel, torch.Tensor): 595 | raise TypeError("Input kernel type is not a torch.Tensor. Got {}" 596 | .format(type(kernel))) 597 | 598 | if not isinstance(border_type, str): 599 | raise TypeError("Input border_type is not string. Got {}" 600 | .format(type(kernel))) 601 | 602 | if not len(input.shape) == 4: 603 | raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}" 604 | .format(input.shape)) 605 | 606 | if not len(kernel.shape) == 3: 607 | raise ValueError("Invalid kernel shape, we expect 1xHxW. Got: {}" 608 | .format(kernel.shape)) 609 | 610 | borders_list: List[str] = ['constant', 'reflect', 'replicate', 'circular'] 611 | if border_type not in borders_list: 612 | raise ValueError("Invalid border_type, we expect the following: {0}." 
613 | "Got: {1}".format(borders_list, border_type)) 614 | 615 | # prepare kernel 616 | b, c, h, w = input.shape 617 | tmp_kernel: torch.Tensor = kernel.to(input.device).to(input.dtype) 618 | tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1) 619 | if normalized: 620 | tmp_kernel = normalize_kernel2d(tmp_kernel) 621 | # pad the input tensor 622 | height, width = tmp_kernel.shape[-2:] 623 | padding_shape: List[int] = compute_padding((height, width)) 624 | input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type) 625 | 626 | # convolve the tensor with the kernel 627 | return F.conv2d(input_pad, tmp_kernel, padding=0, stride=1, groups=c) 628 | 629 | def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor: 630 | r"""Normalizes both derivative and smoothing kernel. 631 | """ 632 | if len(input.size()) < 2: 633 | raise TypeError("input should be at least 2D tensor. Got {}" 634 | .format(input.size())) 635 | norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1) 636 | return input / (norm.unsqueeze(-1).unsqueeze(-1)) 637 | 638 | def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor: 639 | r"""Utility function that returns a box filter.""" 640 | kx: float = float(kernel_size[0]) 641 | ky: float = float(kernel_size[1]) 642 | scale: torch.Tensor = torch.tensor(1.) / torch.tensor([kx * ky]) 643 | tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1]) 644 | return scale.to(tmp_kernel.dtype) * tmp_kernel 645 | 646 | class BoxBlur(nn.Module): 647 | r"""Blurs an image using the box filter. 648 | 649 | The function smooths an image using the kernel: 650 | 651 | .. math:: 652 | K = \frac{1}{\text{kernel_size}_x * \text{kernel_size}_y} 653 | \begin{bmatrix} 654 | 1 & 1 & 1 & \cdots & 1 & 1 \\ 655 | 1 & 1 & 1 & \cdots & 1 & 1 \\ 656 | \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 657 | 1 & 1 & 1 & \cdots & 1 & 1 \\ 658 | \end{bmatrix} 659 | 660 | Args: 661 | kernel_size (Tuple[int, int]): the blurring kernel size. 662 | border_type (str): the padding mode to be applied before convolving. 663 | The expected modes are: ``'constant'``, ``'reflect'``, 664 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 665 | normalized (bool): if True, L1 norm of the kernel is set to 1. 666 | 667 | Returns: 668 | torch.Tensor: the blurred input tensor. 
669 | 670 | Shape: 671 | - Input: :math:`(B, C, H, W)` 672 | - Output: :math:`(B, C, H, W)` 673 | 674 | Example: 675 | >>> input = torch.rand(2, 4, 5, 7) 676 | >>> blur = kornia.filters.BoxBlur((3, 3)) 677 | >>> output = blur(input) # 2x4x5x7 678 | """ 679 | 680 | def __init__(self, kernel_size: Tuple[int, int], 681 | border_type: str = 'reflect', 682 | normalized: bool = True) -> None: 683 | super(BoxBlur, self).__init__() 684 | self.kernel_size: Tuple[int, int] = kernel_size 685 | self.border_type: str = border_type 686 | self.kernel: torch.Tensor = get_box_kernel2d(kernel_size) 687 | self.normalized: bool = normalized 688 | if self.normalized: 689 | self.kernel = normalize_kernel2d(self.kernel) 690 | 691 | def __repr__(self) -> str: 692 | return self.__class__.__name__ +\ 693 | '(kernel_size=' + str(self.kernel_size) + ', ' +\ 694 | 'normalized=' + str(self.normalized) + ', ' + \ 695 | 'border_type=' + self.border_type + ')' 696 | 697 | def forward(self, input: torch.Tensor): # type: ignore 698 | return filter2D(input, self.kernel, self.border_type) 699 | 700 | # functiona api 701 | def box_blur(input: torch.Tensor, 702 | kernel_size: Tuple[int, int], 703 | border_type: str = 'reflect', 704 | normalized: bool = True) -> torch.Tensor: 705 | r"""Blurs an image using the box filter. 706 | 707 | See :class:`~kornia.filters.BoxBlur` for details. 708 | """ 709 | return BoxBlur(kernel_size, border_type, normalized)(input) 710 | 711 | class PSNRLoss(nn.Module): 712 | r"""Creates a criterion that calculates the PSNR between 2 images. Given an m x n image, 713 | .. math:: 714 | \text{MSE}(I,T) = \frac{1}{m\,n}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2 715 | 716 | Arguments: 717 | max_val (float): Maximum value of input 718 | 719 | Shape: 720 | - input: :math:`(*)` 721 | - approximation: :math:`(*)` same shape as input 722 | - output: :math:`()` a scalar 723 | 724 | Examples: 725 | >>> kornia.losses.psnr(torch.ones(1), 1.2*torch.ones(1), 2) 726 | tensor(20.0000) # 10 * log(4/((1.2-1)**2)) / log(10) 727 | 728 | reference: 729 | https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition 730 | """ 731 | 732 | def __init__(self, max_val: float) -> None: 733 | super(PSNRLoss, self).__init__() 734 | self.max_val: float = max_val 735 | 736 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore 737 | return psnr_loss(input, target, self.max_val) 738 | 739 | 740 | def psnr_loss(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor: 741 | r"""Function that computes PSNR 742 | 743 | See :class:`~kornia.losses.PSNR` for details. 
744 | """ 745 | if not torch.is_tensor(input) or not torch.is_tensor(target): 746 | raise TypeError(f"Expected 2 torch tensors but got {type(input)} and {type(target)}") 747 | 748 | if input.shape != target.shape: 749 | raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}") 750 | mse_val = mse_loss(input, target, reduction='mean') 751 | max_val_tensor: torch.Tensor = torch.tensor(max_val).to(input.device).to(input.dtype) 752 | return 10 * torch.log10(max_val_tensor * max_val_tensor / mse_val) 753 | 754 | 755 | def _compute_zero_padding(kernel_size: int) -> int: 756 | """Computes zero padding.""" 757 | return (kernel_size - 1) // 2 758 | 759 | def gaussian(window_size, sigma): 760 | x = torch.arange(window_size).float() - window_size // 2 761 | if window_size % 2 == 0: 762 | x = x + 0.5 763 | gauss = torch.exp((-x.pow(2.0) / float(2 * sigma ** 2))) 764 | return gauss / gauss.sum() 765 | 766 | def get_gaussian_kernel1d(kernel_size: int, 767 | sigma: float, 768 | force_even: bool = False) -> torch.Tensor: 769 | r"""Function that returns Gaussian filter coefficients. 770 | 771 | Args: 772 | kernel_size (int): filter size. It should be odd and positive. 773 | sigma (float): gaussian standard deviation. 774 | force_even (bool): overrides requirement for odd kernel size. 775 | 776 | Returns: 777 | Tensor: 1D tensor with gaussian filter coefficients. 778 | 779 | Shape: 780 | - Output: :math:`(\text{kernel_size})` 781 | 782 | Examples:: 783 | 784 | >>> kornia.image.get_gaussian_kernel(3, 2.5) 785 | tensor([0.3243, 0.3513, 0.3243]) 786 | 787 | >>> kornia.image.get_gaussian_kernel(5, 1.5) 788 | tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) 789 | """ 790 | if (not isinstance(kernel_size, int) or ( 791 | (kernel_size % 2 == 0) and not force_even) or ( 792 | kernel_size <= 0)): 793 | raise TypeError( 794 | "kernel_size must be an odd positive integer. " 795 | "Got {}".format(kernel_size) 796 | ) 797 | window_1d: torch.Tensor = gaussian(kernel_size, sigma) 798 | return window_1d 799 | 800 | def get_gaussian_kernel2d( 801 | kernel_size: Tuple[int, int], 802 | sigma: Tuple[float, float], 803 | force_even: bool = False) -> torch.Tensor: 804 | r"""Function that returns Gaussian filter matrix coefficients. 805 | 806 | Args: 807 | kernel_size (Tuple[int, int]): filter sizes in the x and y direction. 808 | Sizes should be odd and positive. 809 | sigma (Tuple[int, int]): gaussian standard deviation in the x and y 810 | direction. 811 | force_even (bool): overrides requirement for odd kernel size. 812 | 813 | Returns: 814 | Tensor: 2D tensor with gaussian filter matrix coefficients. 815 | 816 | Shape: 817 | - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` 818 | 819 | Examples:: 820 | 821 | >>> kornia.image.get_gaussian_kernel2d((3, 3), (1.5, 1.5)) 822 | tensor([[0.0947, 0.1183, 0.0947], 823 | [0.1183, 0.1478, 0.1183], 824 | [0.0947, 0.1183, 0.0947]]) 825 | 826 | >>> kornia.image.get_gaussian_kernel2d((3, 5), (1.5, 1.5)) 827 | tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370], 828 | [0.0462, 0.0899, 0.1123, 0.0899, 0.0462], 829 | [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) 830 | """ 831 | if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: 832 | raise TypeError( 833 | "kernel_size must be a tuple of length two. Got {}".format( 834 | kernel_size 835 | ) 836 | ) 837 | if not isinstance(sigma, tuple) or len(sigma) != 2: 838 | raise TypeError( 839 | "sigma must be a tuple of length two. 
Got {}".format(sigma) 840 | ) 841 | ksize_x, ksize_y = kernel_size 842 | sigma_x, sigma_y = sigma 843 | kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even) 844 | kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even) 845 | kernel_2d: torch.Tensor = torch.matmul( 846 | kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t() 847 | ) 848 | return kernel_2d 849 | 850 | class SSIM(nn.Module): 851 | r"""Creates a criterion that measures the Structural Similarity (SSIM) 852 | index between each element in the input `x` and target `y`. 853 | 854 | The index can be described as: 855 | 856 | .. math:: 857 | 858 | \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)} 859 | {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)} 860 | 861 | where: 862 | - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to 863 | stabilize the division with weak denominator. 864 | - :math:`L` is the dynamic range of the pixel-values (typically this is 865 | :math:`2^{\#\text{bits per pixel}}-1`). 866 | 867 | the loss, or the Structural dissimilarity (DSSIM) can be finally described 868 | as: 869 | 870 | .. math:: 871 | 872 | \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2} 873 | 874 | Arguments: 875 | window_size (int): the size of the kernel. 876 | max_val (float): the dynamic range of the images. Default: 1. 877 | reduction (str, optional): Specifies the reduction to apply to the 878 | output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 879 | 'mean': the sum of the output will be divided by the number of elements 880 | in the output, 'sum': the output will be summed. Default: 'none'. 881 | 882 | Returns: 883 | Tensor: the ssim index. 884 | 885 | Shape: 886 | - Input: :math:`(B, C, H, W)` 887 | - Target :math:`(B, C, H, W)` 888 | - Output: scale, if reduction is 'none', then :math:`(B, C, H, W)` 889 | 890 | Examples:: 891 | 892 | >>> input1 = torch.rand(1, 4, 5, 5) 893 | >>> input2 = torch.rand(1, 4, 5, 5) 894 | >>> ssim = kornia.losses.SSIM(5, reduction='none') 895 | >>> loss = ssim(input1, input2) # 1x4x5x5 896 | """ 897 | 898 | def __init__( 899 | self, 900 | window_size: int, 901 | reduction: str = "none", 902 | max_val: float = 1.0) -> None: 903 | super(SSIM, self).__init__() 904 | self.window_size: int = window_size 905 | self.max_val: float = max_val 906 | self.reduction: str = reduction 907 | 908 | self.window: torch.Tensor = get_gaussian_kernel2d( 909 | (window_size, window_size), (1.5, 1.5)) 910 | self.window = self.window.requires_grad_(False) # need to disable gradients 911 | 912 | self.padding: int = _compute_zero_padding(window_size) 913 | 914 | self.C1: float = (0.01 * self.max_val) ** 2 915 | self.C2: float = (0.03 * self.max_val) ** 2 916 | 917 | def forward( # type: ignore 918 | self, 919 | img1: torch.Tensor, 920 | img2: torch.Tensor) -> torch.Tensor: 921 | 922 | if not torch.is_tensor(img1): 923 | raise TypeError("Input img1 type is not a torch.Tensor. Got {}" 924 | .format(type(img1))) 925 | 926 | if not torch.is_tensor(img2): 927 | raise TypeError("Input img2 type is not a torch.Tensor. Got {}" 928 | .format(type(img2))) 929 | 930 | if not len(img1.shape) == 4: 931 | raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}" 932 | .format(img1.shape)) 933 | 934 | if not len(img2.shape) == 4: 935 | raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}" 936 | .format(img2.shape)) 937 | 938 | if not img1.shape == img2.shape: 939 | raise ValueError("img1 and img2 shapes must be the same. 
Got: {}" 940 | .format(img1.shape, img2.shape)) 941 | 942 | if not img1.device == img2.device: 943 | raise ValueError("img1 and img2 must be in the same device. Got: {}" 944 | .format(img1.device, img2.device)) 945 | 946 | if not img1.dtype == img2.dtype: 947 | raise ValueError("img1 and img2 must be in the same dtype. Got: {}" 948 | .format(img1.dtype, img2.dtype)) 949 | 950 | # prepare kernel 951 | b, c, h, w = img1.shape 952 | tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype) 953 | tmp_kernel = torch.unsqueeze(tmp_kernel, dim=0) 954 | 955 | # compute local mean per channel 956 | mu1: torch.Tensor = filter2D(img1, tmp_kernel) 957 | mu2: torch.Tensor = filter2D(img2, tmp_kernel) 958 | 959 | mu1_sq = mu1.pow(2) 960 | mu2_sq = mu2.pow(2) 961 | mu1_mu2 = mu1 * mu2 962 | 963 | # compute local sigma per channel 964 | sigma1_sq = filter2D(img1 * img1, tmp_kernel) - mu1_sq 965 | sigma2_sq = filter2D(img2 * img2, tmp_kernel) - mu2_sq 966 | sigma12 = filter2D(img1 * img2, tmp_kernel) - mu1_mu2 967 | 968 | ssim_map = ((2. * mu1_mu2 + self.C1) * (2. * sigma12 + self.C2)) / \ 969 | ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2)) 970 | 971 | loss = torch.clamp(ssim_map, min=0, max=1) 972 | 973 | if self.reduction == "mean": 974 | loss = torch.mean(loss) 975 | elif self.reduction == "sum": 976 | loss = torch.sum(loss) 977 | elif self.reduction == "none": 978 | pass 979 | return loss 980 | 981 | def ssim( 982 | img1: torch.Tensor, 983 | img2: torch.Tensor, 984 | window_size: int, 985 | reduction: str = "mean", 986 | max_val: float = 1.0) -> torch.Tensor: 987 | r"""Function that measures the Structural Similarity (SSIM) index between 988 | each element in the input `x` and target `y`. 989 | 990 | See :class:`~kornia.losses.SSIM` for details. 991 | """ 992 | return SSIM(window_size, reduction, max_val)(img1, img2) 993 | 994 | # from typing import Tuple 995 | 996 | def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor: 997 | r"""Function that returns Gaussian filter matrix coefficients. 998 | 999 | Args: 1000 | kernel_size (int): filter size should be odd. 1001 | 1002 | Returns: 1003 | Tensor: 2D tensor with laplacian filter matrix coefficients. 1004 | 1005 | Shape: 1006 | - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` 1007 | 1008 | Examples:: 1009 | 1010 | >>> kornia.image.get_laplacian_kernel2d(3) 1011 | tensor([[ 1., 1., 1.], 1012 | [ 1., -8., 1.], 1013 | [ 1., 1., 1.]]) 1014 | 1015 | >>> kornia.image.get_laplacian_kernel2d(5) 1016 | tensor([[ 1., 1., 1., 1., 1.], 1017 | [ 1., 1., 1., 1., 1.], 1018 | [ 1., 1., -24., 1., 1.], 1019 | [ 1., 1., 1., 1., 1.], 1020 | [ 1., 1., 1., 1., 1.]]) 1021 | 1022 | """ 1023 | if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or \ 1024 | kernel_size <= 0: 1025 | raise TypeError("ksize must be an odd positive integer. Got {}" 1026 | .format(kernel_size)) 1027 | 1028 | kernel = torch.ones((kernel_size, kernel_size)) 1029 | mid = kernel_size // 2 1030 | kernel[mid, mid] = 1 - kernel_size ** 2 1031 | kernel_2d: torch.Tensor = kernel 1032 | return kernel_2d 1033 | 1034 | class Laplacian(nn.Module): 1035 | r"""Creates an operator that returns a tensor using a Laplacian filter. 1036 | 1037 | The operator smooths the given tensor with a laplacian kernel by convolving 1038 | it to each channel. It suports batched operation. 1039 | 1040 | Arguments: 1041 | kernel_size (int): the size of the kernel. 1042 | border_type (str): the padding mode to be applied before convolving. 
1043 | The expected modes are: ``'constant'``, ``'reflect'``, 1044 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 1045 | normalized (bool): if True, L1 norm of the kernel is set to 1. 1046 | 1047 | Returns: 1048 | Tensor: the tensor. 1049 | 1050 | Shape: 1051 | - Input: :math:`(B, C, H, W)` 1052 | - Output: :math:`(B, C, H, W)` 1053 | 1054 | Examples:: 1055 | 1056 | >>> input = torch.rand(2, 4, 5, 5) 1057 | >>> laplace = kornia.filters.Laplacian(5) 1058 | >>> output = laplace(input) # 2x4x5x5 1059 | """ 1060 | 1061 | def __init__(self, 1062 | kernel_size: int, border_type: str = 'reflect', 1063 | normalized: bool = True) -> None: 1064 | super(Laplacian, self).__init__() 1065 | self.kernel_size: int = kernel_size 1066 | self.border_type: str = border_type 1067 | self.normalized: bool = normalized 1068 | self.kernel: torch.Tensor = torch.unsqueeze( 1069 | get_laplacian_kernel2d(kernel_size), dim=0) 1070 | if self.normalized: 1071 | self.kernel = normalize_kernel2d(self.kernel) 1072 | 1073 | def __repr__(self) -> str: 1074 | return self.__class__.__name__ +\ 1075 | '(kernel_size=' + str(self.kernel_size) + ', ' +\ 1076 | 'normalized=' + str(self.normalized) + ', ' + \ 1077 | 'border_type=' + self.border_type + ')' 1078 | 1079 | def forward(self, input: torch.Tensor): # type: ignore 1080 | return filter2D(input, self.kernel, self.border_type) 1081 | 1082 | def laplacian( 1083 | input: torch.Tensor, 1084 | kernel_size: int, 1085 | border_type: str = 'reflect', 1086 | normalized: bool = True) -> torch.Tensor: 1087 | r"""Function that returns a tensor using a Laplacian filter. 1088 | 1089 | See :class:`~kornia.filters.Laplacian` for details. 1090 | """ 1091 | return Laplacian(kernel_size, border_type, normalized)(input) --------------------------------------------------------------------------------
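
The loss and filter helpers defined in utils.py (psnr_loss, ssim, laplacian and the RAP loss built on them) all expect (B, C, H, W) tensors scaled to [0, 1]. A minimal usage sketch, assuming the project root is on the Python path and using random tensors in place of real pansharpening outputs:

import torch
from utils import RAP, laplacian, psnr_loss, ssim

pred = torch.rand(2, 4, 64, 64)  # stand-in for a pansharpened prediction
gt = torch.rand(2, 4, 64, 64)    # stand-in for the ground-truth HR MS image

print(psnr_loss(pred, gt, max_val=1.0))             # PSNR in dB (scalar tensor)
print(ssim(pred, gt, window_size=11))               # mean SSIM over the batch
print(laplacian(pred, kernel_size=3).shape)         # torch.Size([2, 4, 64, 64])
print(RAP(lap_weight=1, angle_weight=1)(pred, gt))  # L1 + Laplacian L1 + SAM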