├── README.md ├── data ├── IrVi_dataset.py ├── __init__.py ├── __pycache__ │ ├── IrVi_dataset.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ └── common.cpython-39.pyc └── common.py ├── dataset ├── test │ ├── ir │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ └── 9.png │ └── vi │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ └── 9.png ├── train │ ├── ir │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ └── 9.png │ └── vi │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ └── 9.png └── val │ ├── ir │ ├── 1.png │ ├── 10.png │ ├── 11.png │ ├── 12.png │ ├── 13.png │ ├── 14.png │ ├── 15.png │ ├── 16.png │ ├── 17.png │ ├── 18.png │ ├── 19.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ └── 9.png │ └── vi │ ├── 1.png │ ├── 10.png │ ├── 11.png │ ├── 12.png │ ├── 13.png │ ├── 14.png │ ├── 15.png │ ├── 16.png │ ├── 17.png │ ├── 18.png │ ├── 19.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ └── 9.png ├── model └── best_ckp.pth ├── networks ├── ATFuse.py ├── __init__.py ├── __pycache__ │ ├── ATFuse.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── loss.cpython-39.pyc │ └── util.cpython-39.pyc ├── loss.py └── util.py ├── options ├── __pycache__ │ └── options.cpython-39.pyc ├── options.py ├── test │ └── test_ATFuse.json └── train │ └── train_ATFuse.json ├── solvers ├── FuSolver.py ├── __init__.py ├── __pycache__ │ ├── FuSolver.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ └── base_solver.cpython-39.pyc └── base_solver.py ├── testViIr.py ├── train_ViIr.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── util.cpython-39.pyc ├── hist_adjust.py ├── hyper_plot.py ├── lr_scheduler.py ├── metrics.py ├── plot_subfigs_colorbar.py ├── pyCartooTexture.py ├── reamdme.md ├── test.png ├── unionNormImg.py └── util.py └── visualization ├── MsVisualizer.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc └── visualizer.cpython-39.pyc ├── attaVisualizer.py └── visualizer.py /README.md: -------------------------------------------------------------------------------- 1 | # ATFuse 2 | Code for: Rethinking Cross-Attention for Infrared and Visible Image Fusion. 
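Training and testing are driven by `train_ViIr.py` / `testViIr.py` together with the JSON configs under `options/train/` and `options/test/`; the pretrained checkpoint ships as `model/best_ckp.pth`. A typical invocation might look like the sketch below (the `-opt` flag name is an assumption -- check `options/options.py` for the actual argument):

```bash
# hypothetical flag name; see options/options.py for the real one
python train_ViIr.py -opt options/train/train_ATFuse.json
python testViIr.py -opt options/test/test_ATFuse.json
```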
3 | 4 | Training and Testing datasets: https://pan.baidu.com/s/1p-2Y5x6uYSwRKfPmVgid6g?pwd=p6rt 5 | -------------------------------------------------------------------------------- /data/IrVi_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.utils.data as data 5 | import cv2 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | from data import common 9 | 10 | 11 | class IrViDataset(data.Dataset): 12 | ''' 13 | Read LR and HR images in train and eval phases. 14 | The visible image stands in for LR, and the infrared image stands in for Pan. 15 | ''' 16 | 17 | def name(self): 18 | return self.dataset_name # return the name of the dataset in use 19 | 20 | def __init__(self, opt): 21 | super(IrViDataset, self).__init__() 22 | self.opt = opt 23 | self.train = (opt['phase'] == 'train') 24 | self.split = 'train' if self.train else 'test' 25 | self.scale = self.opt['scale'] # upscaling factor 26 | self.paths_Vi = None 27 | 28 | # read image list from image/binary files 29 | if self.opt["useContinueLearning"]: 30 | self.dataset_name = self.opt['dataroot_Vi'][int(self.opt["dataset_index"])].split("/")[1] 31 | # self.paths_Fu = common.get_image_paths(self.opt['data_type'], 32 | # self.opt['dataroot_Fu'][int(self.opt["dataset_index"])]) 33 | self.paths_Vi = common.get_image_paths(self.opt['data_type'], 34 | self.opt['dataroot_Vi'][int(self.opt["dataset_index"])]) 35 | self.paths_Ir = common.get_image_paths(self.opt['data_type'], 36 | self.opt['dataroot_Ir'][int(self.opt["dataset_index"])]) 37 | else: 38 | # self.paths_Fu = common.get_image_paths(self.opt['data_type'], self.opt['dataroot_Fu']) 39 | self.paths_Vi = common.get_image_paths(self.opt['data_type'], self.opt['dataroot_Vi']) 40 | self.paths_Ir = common.get_image_paths(self.opt['data_type'], self.opt['dataroot_Ir']) 41 | self.dataset_name = self.opt['dataroot_Vi'].split("/")[1] 42 | 43 | # assert self.paths_Fu, '[Error] Fu paths are empty.' 44 | if self.paths_Vi and self.paths_Ir: 45 | assert len(self.paths_Vi) == len(self.paths_Ir), \ 46 | '[Error] Vi: [%d] and Ir: [%d] have different number of images.' 
% ( 47 | len(self.paths_Vi), len(self.paths_Ir)) 48 | 49 | def __getitem__(self, idx): 50 | if self.train: 51 | vi, ir, vi_path, ir_path = self._load_file(idx) 52 | vi = vi[:, :, 0:1] 53 | 54 | ir = ir[:, :, 0:1] 55 | 56 | vi, ir = self.get_patch1(vi, ir) 57 | # self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]) 58 | # if self.transform: 59 | # vi_tensor = self.transform(vi) 60 | # ir_tensor = self.transform(ir) 61 | 62 | vi_tensor, ir_tensor = common.np2Tensor([vi, ir], self.opt['rgb_range']) 63 | return {'Vi': vi_tensor, 'Ir': ir_tensor, 'vi_path': vi_path, 64 | 'ir_path': ir_path} 65 | else: 66 | EPS = 1e-8 67 | 68 | vi, ir, vi_path, ir_path = self._load_file(idx) 69 | 70 | # vi_t = torch.tensor(vi) 71 | # ir_t = torch.tensor(ir) 72 | 73 | # vi_t = vi_t.unsqueeze(dim=3) 74 | # ir_t = ir_t.unsqueeze(dim=3) 75 | # imgs = torch.concat([vi_t,ir_t], dim=3) 76 | # imgs = np.transpose(imgs, (3,2,0,1)) 77 | # 78 | # img_cr = imgs[:, 1:2, :, :] 79 | # img_cb = imgs[:, 2:3, :, :] 80 | # w_cr = (torch.abs(img_cr) + EPS) / torch.sum(torch.abs(img_cr) + EPS, dim=0) 81 | # w_cb = (torch.abs(img_cb) + EPS) / torch.sum(torch.abs(img_cb) + EPS, dim=0) 82 | # fused_img_cr = torch.sum(w_cr * img_cr, dim=0, keepdim=True) 83 | # fused_img_cb = torch.sum(w_cb * img_cb, dim=0, keepdim=True) 84 | # fused_img_cr = fused_img_cr.squeeze(0) 85 | # fused_img_cb = fused_img_cb.squeeze(0) 86 | 87 | 88 | vi = vi[:, :, 0:1] 89 | ir = ir[:, :, 0:1] 90 | 91 | # self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]) 92 | # vi_tensor = self.transform(vi) 93 | # ir_tensor = self.transform(ir) 94 | 95 | vi_tensor, ir_tensor = common.np2Tensor([vi, ir], self.opt['rgb_range']) # 'fused_img_cr': fused_img_cr, 'fused_img_cb': fused_img_cb, 96 | return {'Vi': vi_tensor, 'Ir': ir_tensor, 'vi_path': vi_path, 97 | 'ir_path': ir_path} 98 | 99 | def get_patch1(self, over, under): 100 | h, w = over.shape[:2] 101 | stride = 128 102 | 103 | x = random.randint(0, w - stride) 104 | y = random.randint(0, h - stride) 105 | 106 | over = over[y:y + stride, x:x + stride, :] 107 | under = under[y:y + stride, x:x + stride, :] 108 | 109 | return over, under 110 | # if self.train: 111 | # vi, fu, ir = self._get_patch(vi, fu, ir) 112 | # if self.opt["shift_pace"]: 113 | # if random.random() > 0.5: 114 | # pan_shift = self.dealign_ms_pan(ir.copy()) 115 | # else: 116 | # pan_shift = ir.copy() 117 | # else: 118 | # pan_shift = ir.copy() 119 | # else: 120 | # if self.opt["shift_pace"]: 121 | # pan_shift = self.dealign_ms_pan(pan.copy()) 122 | # else: 123 | # pan_shift = pan.copy() 124 | 125 | 126 | def __len__(self): 127 | if self.train: 128 | return 20 * len(self.paths_Vi) 129 | # return 2 130 | else: 131 | return len(self.paths_Vi) 132 | # return 1 133 | # return len(self.paths_Vi) * 10 134 | 135 | 136 | def _get_index(self, idx): 137 | if self.train: 138 | return idx % len(self.paths_Vi) 139 | else: 140 | return idx 141 | 142 | def _load_file(self, idx): 143 | idx = self._get_index(idx) 144 | vi_path = self.paths_Vi[idx] 145 | # fu_path = self.paths_Fu[idx] 146 | ir_path = self.paths_Ir[idx] 147 | vi = common.read_img(vi_path, self.opt['data_type']) 148 | # fu = common.read_img(fu_path, self.opt['data_type']) 149 | ir = common.read_img(ir_path, self.opt['data_type']) 150 | return vi, ir, vi_path, ir_path 151 | 152 | def _get_patch(self, vi, fu, ir): 153 | 154 | vi_size = self.opt['vi_size'] 155 | # random crop and augment 156 | vi, fu = 
common.get_patch(vi, fu, 157 | vi_size, self.scale) 158 | vi, fu, ir = common.augment([vi, fu, ir]) 159 | vi = common.add_noise(vi, self.opt['noise']) 160 | 161 | return vi, fu, ir 162 | 163 | def dealign_ms_pan(self, pan): 164 | """ 165 | Artificially create unregistered pan and MS images by cyclically shifting the pan image. 166 | :return: 167 | """ 168 | shift_pace = random.randint(1, 250) 169 | h, w, c = pan.shape 170 | pan = np.vstack((pan[(h - shift_pace):, :], pan[:(h - shift_pace), :])) 171 | pan = np.hstack((pan[:, (w - shift_pace):], pan[:, :(w - shift_pace)])) 172 | return pan 173 | 174 | def random_shuffle_patch(self, pan): 175 | patch_width = random.randint(10, 50) 176 | patch_height = random.randint(10, 50) 177 | row_num = pan.shape[0] // patch_height 178 | col_num = pan.shape[1] // patch_width 179 | 180 | li = [] 181 | for row in range(row_num): 182 | for col in range(col_num): 183 | # crop one patch_height x patch_width tile 184 | li.append(pan[row * patch_height: row * patch_height + patch_height, col * patch_width:col * patch_width + patch_width, :]) 185 | np.random.shuffle(li) 186 | 187 | li2 = [] 188 | for row in range(row_num): 189 | li2.append(li[row * col_num:row * col_num + col_num]) 190 | 191 | li3 = [] 192 | for item in li2: 193 | li3.append(np.concatenate(item, axis=1)) 194 | li3 = np.concatenate(li3, axis=0) 195 | return li3 196 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | 3 | 4 | def create_dataloader(dataset, dataset_opt): 5 | phase = dataset_opt['phase'] 6 | if phase == 'train': 7 | batch_size = dataset_opt['batch_size'] 8 | shuffle = True 9 | num_workers = dataset_opt['n_workers'] # number of subprocesses used by the DataLoader 10 | else: 11 | batch_size = 1 12 | shuffle = False 13 | num_workers = 1 14 | return torch.utils.data.DataLoader( 15 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 16 | 17 | 18 | def create_dataset(dataset_opt): 19 | mode = dataset_opt['mode'].upper() 20 | if mode == 'LRHR': 21 | from data.LRHR_dataset import LRHRDataset as D 22 | elif mode == 'IRVI': 23 | from data.IrVi_dataset import IrViDataset as D 24 | else: 25 | raise NotImplementedError("Dataset [%s] is not recognized." % mode) 26 | dataset = D(dataset_opt) 27 | print('===> [%s] Dataset is created.' % (mode)) 28 | return dataset 29 | 30 | 31 | # def create_dataset(dataset_opt): 32 | # mode = dataset_opt['mode'].upper() 33 | # if mode == 'IrVi': 34 | # from data.IrVi_dataset import IrViDataset as D 35 | # else: 36 | # raise NotImplementedError("Dataset [%s] is not recognized." % mode) 37 | # dataset = D(dataset_opt) 38 | # print('===> [%s] Dataset is created.' 
% (mode)) 39 | # return dataset -------------------------------------------------------------------------------- /data/__pycache__/IrVi_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/data/__pycache__/IrVi_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/data/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /data/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | import scipy.misc as misc 7 | import imageio 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 12 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] 13 | BINARY_EXTENSIONS = ['.npy'] 14 | BENCHMARK = ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'DIV2K', 'DF2K'] 15 | 16 | 17 | #################### 18 | # Files & IO 19 | #################### 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def is_binary_file(filename): 25 | return any(filename.endswith(extension) for extension in BINARY_EXTENSIONS) 26 | 27 | 28 | def _get_paths_from_images(path): 29 | assert os.path.isdir(path), '[Error] [%s] is not a valid directory' % path 30 | images = [] 31 | for dirpath, _, fnames in sorted(os.walk(path)): 32 | for fname in sorted(fnames): 33 | if is_image_file(fname): 34 | img_path = os.path.join(dirpath, fname) 35 | images.append(img_path) 36 | assert images, '[%s] has no valid image file' % path 37 | return images 38 | 39 | 40 | def _get_paths_from_binary(path): 41 | assert os.path.isdir(path), '[Error] [%s] is not a valid directory' % path 42 | files = [] 43 | for dirpath, _, fnames in sorted(os.walk(path)): 44 | for fname in sorted(fnames): 45 | if is_binary_file(fname): 46 | binary_path = os.path.join(dirpath, fname) 47 | files.append(binary_path) 48 | assert files, '[%s] has no valid binary file' % path 49 | return files 50 | 51 | 52 | def get_image_paths(data_type, dataroot): 53 | paths = None 54 | if dataroot is not None: 55 | if data_type == 'img': 56 | paths = sorted(_get_paths_from_images(dataroot)) 57 | elif data_type == 'png': 58 | paths = sorted(_get_paths_from_images(dataroot)) 59 | elif data_type == 'rgb': 60 | paths = sorted(_get_paths_from_images(dataroot)) 61 | elif data_type == 'tif': 62 | paths = sorted(_get_paths_from_images(dataroot)) 63 | elif data_type == 'jpg': 64 | paths = sorted(_get_paths_from_images(dataroot)) 65 | elif data_type == 'npy': 66 | if dataroot.find('_npy') < 0: 67 | old_dir = dataroot 68 | dataroot = dataroot + '_npy' 69 | if not os.path.exists(dataroot): 70 | print('===> Creating binary files in [%s]' % dataroot) 71 | 
os.makedirs(dataroot) 72 | img_paths = sorted(_get_paths_from_images(old_dir)) 73 | path_bar = tqdm(img_paths) 74 | for v in path_bar: 75 | img = imageio.imread(v, pilmode='RGB') 76 | ext = os.path.splitext(os.path.basename(v))[-1] 77 | name_sep = os.path.basename(v.replace(ext, '.npy')) 78 | np.save(os.path.join(dataroot, name_sep), img) 79 | else: 80 | print('===> Binary files already exist in [%s]. Skipping binary file generation.' % dataroot) 81 | 82 | paths = sorted(_get_paths_from_binary(dataroot)) 83 | 84 | else: 85 | raise NotImplementedError("[Error] Data_type [%s] is not recognized." % data_type) 86 | return paths 87 | 88 | 89 | def find_benchmark(dataroot): 90 | bm_list = [dataroot.find(bm)>=0 for bm in BENCHMARK] 91 | if not sum(bm_list) == 0: 92 | bm_idx = bm_list.index(True) 93 | bm_name = BENCHMARK[bm_idx] 94 | else: 95 | bm_name = 'MyImage' 96 | return bm_name 97 | 98 | 99 | def read_img(path, data_type): 100 | # read image with imageio/cv2 or from .npy 101 | # return: Numpy float32, HWC, RGB, [0,255] 102 | if data_type == 'img': 103 | img = imageio.imread(path, pilmode='RGB') 104 | elif data_type.find('npy') >= 0: 105 | img = np.load(path) 106 | elif data_type == 'png': 107 | # img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 108 | img = cv2.imread(path) 109 | img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) 110 | elif data_type == 'tif': 111 | img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 112 | elif data_type == 'rgb': 113 | img = cv2.imread(path, cv2.IMREAD_COLOR) 114 | elif data_type == 'jpg': 115 | img = cv2.imread(path) 116 | img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) 117 | else: 118 | raise NotImplementedError 119 | 120 | if img.ndim == 2: 121 | img = np.expand_dims(img, axis=2) 122 | return img 123 | 124 | 125 | #################### 126 | # image processing 127 | # process on numpy image 128 | #################### 129 | def np2Tensor(l, rgb_range): 130 | def _np2Tensor(img): 131 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) # convert the HWC array to a contiguous CHW array in memory, which speeds up subsequent operations 132 | tensor = torch.from_numpy(np_transpose).float() 133 | tensor.mul_(rgb_range / 2047.) 
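# NOTE: 2047 = 2**11 - 1 (an 11-bit dynamic range), while read_img above returns 8-bit data in [0, 255]; the tensor values therefore scale as pixel * rgb_range / 2047 (at most ~0.125 for rgb_range = 1), so the effective output range depends on the configured 'rgb_range' option.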
134 | return tensor 135 | 136 | return [_np2Tensor(_l) for _l in l] 137 | 138 | 139 | def get_patch(img_in, img_tar, patch_size, scale): 140 | ih, iw = img_in.shape[:2] 141 | oh, ow = img_tar.shape[:2] 142 | 143 | ip = patch_size 144 | 145 | if ih == oh: 146 | tp = ip 147 | ix = random.randrange(0, iw - ip + 1) 148 | iy = random.randrange(0, ih - ip + 1) 149 | tx, ty = ix, iy 150 | else: 151 | tp = ip * scale 152 | ix = random.randrange(0, iw - ip + 1) 153 | iy = random.randrange(0, ih - ip + 1) 154 | tx, ty = scale * ix, scale * iy 155 | 156 | img_in = img_in[iy:iy + ip, ix:ix + ip, :] 157 | img_tar = img_tar[ty:ty + tp, tx:tx + tp, :] 158 | 159 | return img_in, img_tar 160 | 161 | 162 | def add_noise(x, noise='.'): 163 | if noise != '.': 164 | noise_type = noise[0] 165 | 166 | noise_value = int(noise[1:]) 167 | 168 | if noise_type == 'G': 169 | noises = np.random.normal(scale=noise_value, size=x.shape) 170 | noises = noises.round() 171 | elif noise_type == 'S': 172 | noises = np.random.poisson(x * noise_value) / noise_value 173 | noises = noises - noises.mean(axis=0).mean(axis=0) 174 | 175 | x_noise = x.astype(np.int16) + noises.astype(np.int16) 176 | x_noise = x_noise.clip(0, 255).astype(np.uint8) 177 | return x_noise 178 | else: 179 | return x 180 | 181 | 182 | def augment(img_list, hflip=True, rot=True): 183 | # horizontal flip OR rotate 184 | hflip = hflip and random.random() < 0.5 185 | vflip = rot and random.random() < 0.5 186 | rot90 = rot and random.random() < 0.5 187 | 188 | def _augment(img): 189 | if hflip: img = img[:, ::-1, :] 190 | if vflip: img = img[::-1, :, :] 191 | if rot90: img = img.transpose(1, 0, 2) 192 | return img 193 | 194 | return [_augment(img) for img in img_list] 195 | 196 | 197 | def modcrop(img_in, scale): 198 | img = np.copy(img_in) 199 | if img.ndim == 2: 200 | H, W = img.shape 201 | H_r, W_r = H % scale, W % scale 202 | img = img[:H - H_r, :W - W_r] 203 | elif img.ndim == 3: 204 | H, W, C = img.shape 205 | H_r, W_r = H % scale, W % scale 206 | img = img[:H - H_r, :W - W_r, :] 207 | else: 208 | raise ValueError('Wrong img ndim: [%d].' 
% img.ndim) 209 | return img 210 | -------------------------------------------------------------------------------- /dataset/test/ir/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/1.png -------------------------------------------------------------------------------- /dataset/test/ir/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/10.png -------------------------------------------------------------------------------- /dataset/test/ir/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/11.png -------------------------------------------------------------------------------- /dataset/test/ir/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/12.png -------------------------------------------------------------------------------- /dataset/test/ir/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/13.png -------------------------------------------------------------------------------- /dataset/test/ir/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/14.png -------------------------------------------------------------------------------- /dataset/test/ir/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/15.png -------------------------------------------------------------------------------- /dataset/test/ir/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/16.png -------------------------------------------------------------------------------- /dataset/test/ir/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/17.png -------------------------------------------------------------------------------- /dataset/test/ir/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/18.png -------------------------------------------------------------------------------- /dataset/test/ir/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/19.png -------------------------------------------------------------------------------- /dataset/test/ir/2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/2.png -------------------------------------------------------------------------------- /dataset/test/ir/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/20.png -------------------------------------------------------------------------------- /dataset/test/ir/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/3.png -------------------------------------------------------------------------------- /dataset/test/ir/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/4.png -------------------------------------------------------------------------------- /dataset/test/ir/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/5.png -------------------------------------------------------------------------------- /dataset/test/ir/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/6.png -------------------------------------------------------------------------------- /dataset/test/ir/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/7.png -------------------------------------------------------------------------------- /dataset/test/ir/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/8.png -------------------------------------------------------------------------------- /dataset/test/ir/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/ir/9.png -------------------------------------------------------------------------------- /dataset/test/vi/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/1.png -------------------------------------------------------------------------------- /dataset/test/vi/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/10.png -------------------------------------------------------------------------------- /dataset/test/vi/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/11.png 
-------------------------------------------------------------------------------- /dataset/test/vi/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/12.png -------------------------------------------------------------------------------- /dataset/test/vi/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/13.png -------------------------------------------------------------------------------- /dataset/test/vi/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/14.png -------------------------------------------------------------------------------- /dataset/test/vi/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/15.png -------------------------------------------------------------------------------- /dataset/test/vi/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/16.png -------------------------------------------------------------------------------- /dataset/test/vi/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/17.png -------------------------------------------------------------------------------- /dataset/test/vi/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/18.png -------------------------------------------------------------------------------- /dataset/test/vi/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/19.png -------------------------------------------------------------------------------- /dataset/test/vi/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/2.png -------------------------------------------------------------------------------- /dataset/test/vi/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/20.png -------------------------------------------------------------------------------- /dataset/test/vi/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/3.png -------------------------------------------------------------------------------- /dataset/test/vi/4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/4.png -------------------------------------------------------------------------------- /dataset/test/vi/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/5.png -------------------------------------------------------------------------------- /dataset/test/vi/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/6.png -------------------------------------------------------------------------------- /dataset/test/vi/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/7.png -------------------------------------------------------------------------------- /dataset/test/vi/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/8.png -------------------------------------------------------------------------------- /dataset/test/vi/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/test/vi/9.png -------------------------------------------------------------------------------- /dataset/train/ir/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/1.png -------------------------------------------------------------------------------- /dataset/train/ir/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/10.png -------------------------------------------------------------------------------- /dataset/train/ir/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/11.png -------------------------------------------------------------------------------- /dataset/train/ir/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/12.png -------------------------------------------------------------------------------- /dataset/train/ir/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/13.png -------------------------------------------------------------------------------- /dataset/train/ir/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/14.png 
-------------------------------------------------------------------------------- /dataset/train/ir/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/15.png -------------------------------------------------------------------------------- /dataset/train/ir/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/16.png -------------------------------------------------------------------------------- /dataset/train/ir/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/17.png -------------------------------------------------------------------------------- /dataset/train/ir/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/18.png -------------------------------------------------------------------------------- /dataset/train/ir/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/19.png -------------------------------------------------------------------------------- /dataset/train/ir/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/2.png -------------------------------------------------------------------------------- /dataset/train/ir/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/20.png -------------------------------------------------------------------------------- /dataset/train/ir/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/3.png -------------------------------------------------------------------------------- /dataset/train/ir/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/4.png -------------------------------------------------------------------------------- /dataset/train/ir/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/5.png -------------------------------------------------------------------------------- /dataset/train/ir/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/6.png -------------------------------------------------------------------------------- /dataset/train/ir/7.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/7.png -------------------------------------------------------------------------------- /dataset/train/ir/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/8.png -------------------------------------------------------------------------------- /dataset/train/ir/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/ir/9.png -------------------------------------------------------------------------------- /dataset/train/vi/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/1.png -------------------------------------------------------------------------------- /dataset/train/vi/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/10.png -------------------------------------------------------------------------------- /dataset/train/vi/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/11.png -------------------------------------------------------------------------------- /dataset/train/vi/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/12.png -------------------------------------------------------------------------------- /dataset/train/vi/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/13.png -------------------------------------------------------------------------------- /dataset/train/vi/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/14.png -------------------------------------------------------------------------------- /dataset/train/vi/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/15.png -------------------------------------------------------------------------------- /dataset/train/vi/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/16.png -------------------------------------------------------------------------------- /dataset/train/vi/17.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/17.png -------------------------------------------------------------------------------- /dataset/train/vi/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/18.png -------------------------------------------------------------------------------- /dataset/train/vi/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/19.png -------------------------------------------------------------------------------- /dataset/train/vi/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/2.png -------------------------------------------------------------------------------- /dataset/train/vi/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/20.png -------------------------------------------------------------------------------- /dataset/train/vi/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/3.png -------------------------------------------------------------------------------- /dataset/train/vi/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/4.png -------------------------------------------------------------------------------- /dataset/train/vi/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/5.png -------------------------------------------------------------------------------- /dataset/train/vi/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/6.png -------------------------------------------------------------------------------- /dataset/train/vi/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/7.png -------------------------------------------------------------------------------- /dataset/train/vi/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/8.png -------------------------------------------------------------------------------- /dataset/train/vi/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/train/vi/9.png 
-------------------------------------------------------------------------------- /dataset/val/ir/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/1.png -------------------------------------------------------------------------------- /dataset/val/ir/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/10.png -------------------------------------------------------------------------------- /dataset/val/ir/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/11.png -------------------------------------------------------------------------------- /dataset/val/ir/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/12.png -------------------------------------------------------------------------------- /dataset/val/ir/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/13.png -------------------------------------------------------------------------------- /dataset/val/ir/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/14.png -------------------------------------------------------------------------------- /dataset/val/ir/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/15.png -------------------------------------------------------------------------------- /dataset/val/ir/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/16.png -------------------------------------------------------------------------------- /dataset/val/ir/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/17.png -------------------------------------------------------------------------------- /dataset/val/ir/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/18.png -------------------------------------------------------------------------------- /dataset/val/ir/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/19.png -------------------------------------------------------------------------------- /dataset/val/ir/2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/2.png -------------------------------------------------------------------------------- /dataset/val/ir/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/20.png -------------------------------------------------------------------------------- /dataset/val/ir/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/3.png -------------------------------------------------------------------------------- /dataset/val/ir/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/4.png -------------------------------------------------------------------------------- /dataset/val/ir/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/5.png -------------------------------------------------------------------------------- /dataset/val/ir/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/6.png -------------------------------------------------------------------------------- /dataset/val/ir/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/7.png -------------------------------------------------------------------------------- /dataset/val/ir/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/8.png -------------------------------------------------------------------------------- /dataset/val/ir/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/ir/9.png -------------------------------------------------------------------------------- /dataset/val/vi/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/1.png -------------------------------------------------------------------------------- /dataset/val/vi/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/10.png -------------------------------------------------------------------------------- /dataset/val/vi/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/11.png -------------------------------------------------------------------------------- /dataset/val/vi/12.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/12.png -------------------------------------------------------------------------------- /dataset/val/vi/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/13.png -------------------------------------------------------------------------------- /dataset/val/vi/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/14.png -------------------------------------------------------------------------------- /dataset/val/vi/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/15.png -------------------------------------------------------------------------------- /dataset/val/vi/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/16.png -------------------------------------------------------------------------------- /dataset/val/vi/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/17.png -------------------------------------------------------------------------------- /dataset/val/vi/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/18.png -------------------------------------------------------------------------------- /dataset/val/vi/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/19.png -------------------------------------------------------------------------------- /dataset/val/vi/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/2.png -------------------------------------------------------------------------------- /dataset/val/vi/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/20.png -------------------------------------------------------------------------------- /dataset/val/vi/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/3.png -------------------------------------------------------------------------------- /dataset/val/vi/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/4.png 
-------------------------------------------------------------------------------- /dataset/val/vi/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/5.png -------------------------------------------------------------------------------- /dataset/val/vi/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/6.png -------------------------------------------------------------------------------- /dataset/val/vi/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/7.png -------------------------------------------------------------------------------- /dataset/val/vi/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/8.png -------------------------------------------------------------------------------- /dataset/val/vi/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/dataset/val/vi/9.png -------------------------------------------------------------------------------- /model/best_ckp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/model/best_ckp.pth -------------------------------------------------------------------------------- /networks/ATFuse.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import einsum, nn 5 | import numpy as np 6 | from functools import partial 7 | import torch.nn.functional as F 8 | from torch.nn import Softmin 9 | 10 | 11 | class Conv2d_BN(nn.Module): 12 | """Convolution with BN module.""" 13 | 14 | def __init__( 15 | self, 16 | in_ch, 17 | out_ch, 18 | kernel_size=1, 19 | stride=1, 20 | pad=0, 21 | dilation=1, 22 | groups=1, 23 | bn_weight_init=1, 24 | norm_layer=nn.BatchNorm2d, 25 | act_layer=None, 26 | ): 27 | super().__init__() 28 | 29 | self.conv = torch.nn.Conv2d(in_ch, 30 | out_ch, 31 | kernel_size, 32 | stride, 33 | pad, 34 | dilation, 35 | groups, 36 | bias=False) 37 | self.bn = norm_layer(out_ch) 38 | torch.nn.init.constant_(self.bn.weight, bn_weight_init) 39 | torch.nn.init.constant_(self.bn.bias, 0) 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | # Note that there is no bias due to BN 43 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out)) 45 | 46 | self.act_layer = act_layer() if act_layer is not None else nn.Identity( 47 | ) 48 | 49 | def forward(self, x): 50 | """forward function""" 51 | x = self.conv(x) 52 | x = self.bn(x) 53 | x = self.act_layer(x) 54 | 55 | return x 56 | 57 | 58 | class DWConv2d_BN(nn.Module): 59 | """Depthwise Separable Convolution with BN module.""" 60 | 61 | def __init__( 62 | self, 63 | in_ch, 64 | out_ch, 65 | kernel_size=1, 66 | stride=1, 67 | norm_layer=nn.BatchNorm2d, 68 | act_layer=nn.Hardswish, 69 | 
69 |         bn_weight_init=1,
70 |     ):
71 |         super().__init__()
72 | 
73 |         # dw
74 |         self.dwconv = nn.Conv2d(
75 |             in_ch,
76 |             out_ch,
77 |             kernel_size,
78 |             stride,
79 |             (kernel_size - 1) // 2,
80 |             groups=out_ch,
81 |             bias=False,
82 |         )
83 |         # pw-linear
84 |         self.pwconv = nn.Conv2d(out_ch, out_ch, 1, 1, 0, bias=False)
85 |         self.bn = norm_layer(out_ch)
86 |         self.act = act_layer() if act_layer is not None else nn.Identity()
87 | 
88 |         # initialize parameters
89 |         for m in self.modules():
90 |             if isinstance(m, nn.Conv2d):
91 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
92 |                 m.weight.data.normal_(0, math.sqrt(2.0 / n))
93 |                 if m.bias is not None:
94 |                     m.bias.data.zero_()
95 |             elif isinstance(m, nn.BatchNorm2d):
96 |                 m.weight.data.fill_(bn_weight_init)
97 |                 m.bias.data.zero_()
98 | 
99 |     def forward(self, x):
100 |         """
101 |         forward function
102 |         """
103 |         x = self.dwconv(x)
104 |         x = self.pwconv(x)
105 |         x = self.bn(x)
106 |         x = self.act(x)
107 | 
108 |         return x
109 | 
110 | 
111 | class DWCPatchEmbed(nn.Module):
112 |     """Depthwise Convolutional Patch Embedding layer Image to Patch
113 |     Embedding."""
114 | 
115 |     def __init__(self,
116 |                  in_chans=3,
117 |                  embed_dim=768,
118 |                  patch_size=16,
119 |                  stride=1,
120 |                  act_layer=nn.Hardswish):
121 |         super().__init__()
122 | 
123 |         self.patch_conv = DWConv2d_BN(
124 |             in_chans,
125 |             embed_dim,
126 |             kernel_size=patch_size,
127 |             stride=stride,
128 |             act_layer=act_layer,
129 |         )
130 | 
131 |     def forward(self, x):
132 |         """forward function"""
133 |         x = self.patch_conv(x)
134 | 
135 |         return x
136 | 
137 | 
138 | class Patch_Embed_stage(nn.Module):
139 |     """Depthwise Convolutional Patch Embedding stage comprised of
140 |     `DWCPatchEmbed` layers."""
141 | 
142 |     def __init__(self, embed_dim, num_path=4, isPool=False):
143 |         super(Patch_Embed_stage, self).__init__()
144 | 
145 |         self.patch_embeds = nn.ModuleList([
146 |             DWCPatchEmbed(
147 |                 in_chans=embed_dim,
148 |                 embed_dim=embed_dim,
149 |                 patch_size=3,
150 |                 stride=2 if isPool and idx == 0 else 1,
151 |             ) for idx in range(num_path)
152 |         ])
153 | 
154 |     def forward(self, inputs):
155 |         """forward function"""
156 |         att_inputs = []
157 |         for x, pe in zip(inputs, self.patch_embeds):
158 |             x = pe(x)
159 |             att_inputs.append(x)
160 | 
161 |         return att_inputs
162 | 
163 | 
164 | class FactorAtt_ConvRelPosEnc(nn.Module):
165 |     """Factorized attention with convolutional relative position encoding
166 |     class."""
167 | 
168 |     def __init__(
169 |         self,
170 |         dim,
171 |         num_heads=8,
172 |         qkv_bias=False,
173 |         qk_scale=None,
174 |         attn_drop=0.0,
175 |         proj_drop=0.0,
176 |         shared_crpe=None,
177 |     ):
178 |         super().__init__()
179 |         self.num_heads = num_heads
180 |         head_dim = dim // num_heads
181 |         self.scale = qk_scale or head_dim ** -0.5
182 | 
183 |         self.q = nn.Linear(dim, dim, bias=qkv_bias)
184 |         self.k = nn.Linear(dim, dim, bias=qkv_bias)
185 |         self.v = nn.Linear(dim, dim, bias=qkv_bias)
186 |         self.attn_drop = nn.Dropout(attn_drop)
187 |         self.proj = nn.Linear(dim, dim)
188 |         self.proj_drop = nn.Dropout(proj_drop)
189 | 
190 |         # Shared convolutional relative position encoding.
191 |         self.crpe = shared_crpe
192 | 
193 |     def forward(self, q, k, v, minus=True):
194 |         B, N, C = q.shape
195 | 
196 |         # Generate Q, K, V.
197 |         q = self.q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
198 |         k = self.k(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
199 |         v = self.v(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
200 | 
201 |         # Factorized attention.
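        # An illustrative aside (a minimal sketch of the computation below, using
        # only the shapes already established above): with q, k, v each of shape
        # (B, heads, N, d), the block forms
        #     k_soft = k.softmax(dim=2)               # normalize over the N tokens
        #     ctx    = k_soft.transpose(-2, -1) @ v   # (B, heads, d, d)
        #     out    = q @ ctx                        # (B, heads, N, d)
        # which matches the two einsum calls below and costs O(N * d^2) per head
        # instead of the O(N^2 * d) of standard softmax(QK^T)V attention.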
202 |         use_efficient = minus
203 |         # Both branches of the original if/else computed the identical factorized
204 |         # attention, so the shared computation is written once here; the `minus`
205 |         # flag only changes how the result is merged with V further below.
206 |         k_softmax = k.softmax(dim=2)
207 | 
208 |         k_softmax_T_dot_v = einsum("b h n k, b h n v -> b h k v", k_softmax, v)
209 |         factor_att = einsum("b h n k, b h k v -> b h n v", q, k_softmax_T_dot_v)
210 |         # Earlier experimental variants, kept for reference:
211 |         # minus = Softmin(dim=2)
212 |         # k_softmax = minus(k)
213 |         # k_softmax_T_dot_v = einsum("b h n k, b h n v -> b h k v", k_softmax, v)
214 |         # factor_att = einsum("b h n k, b h k v -> b h n v", q, k_softmax_T_dot_v)
215 |         # and a dot-product variant:
216 |         # q_dot_k = einsum("b h n k, b h n v -> b h k v", q, k)
217 |         # q_dot_k_softmax = q_dot_k.softmax(dim=2)
218 |         # factor_att = einsum("b h n v, b h n v -> b h n v", q_dot_k_softmax, v)
219 | 
220 | 
221 | 
222 | 
223 |         # Convolutional relative position encoding (unused in this variant).
224 |         # if self.crpe:
225 |         #     crpe = self.crpe(q, v, size=size)
226 |         # else:
227 |         #     crpe = 0
228 | 
229 |         # Merge and reshape.
230 |         if use_efficient:  # minus=True: plain factorized attention
231 |             x = factor_att  # + crpe (the ViIr2 variant used a weight of 0.5 here)
232 |         else:  # minus=False: difference between V and the attention output
233 |             x = v - factor_att
234 |         x = x.transpose(1, 2).reshape(B, N, C)
235 | 
236 |         # Output projection.
237 | 
238 |         x = self.proj(x)
239 |         x = self.proj_drop(x)
240 | 
241 |         return x
242 | 
243 | 
244 | class Mlp(nn.Module):
245 |     """Feed-forward network (FFN, a.k.a.
246 | 
247 |     MLP) class.
248 |     """
249 | 
250 |     def __init__(
251 |         self,
252 |         in_features,
253 |         hidden_features=None,
254 |         out_features=None,
255 |         act_layer=nn.GELU,
256 |         drop=0.0,
257 |     ):
258 |         super().__init__()
259 |         out_features = out_features or in_features
260 |         hidden_features = hidden_features or in_features
261 |         self.fc1 = nn.Linear(in_features, hidden_features)
262 |         self.act = act_layer()
263 |         self.fc2 = nn.Linear(hidden_features, out_features)
264 |         self.drop = nn.Dropout(drop)
265 | 
266 |     def forward(self, x):
267 |         """forward function"""
268 |         x = self.fc1(x)
269 |         x = self.act(x)
270 |         x = self.drop(x)
271 |         x = self.fc2(x)
272 |         x = self.drop(x)
273 |         return x
274 | 
275 | 
276 | class MHCABlock(nn.Module):
277 |     """Multi-Head Convolutional self-Attention block."""
278 | 
279 |     def __init__(
280 |         self,
281 |         dim,
282 |         num_heads,
283 |         mlp_ratio=3,
284 |         qkv_bias=True,
285 |         qk_scale=None,
286 |         norm_layer=partial(nn.LayerNorm, eps=1e-6),
287 |         shared_cpe=None,
288 |         shared_crpe=None,
289 |     ):
290 |         super().__init__()
291 | 
292 |         self.cpe = shared_cpe
293 |         self.crpe = shared_crpe
294 |         self.fuse = nn.Linear(dim * 2, dim)
295 |         self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
296 |             dim,
297 |             num_heads=num_heads,
298 |             qkv_bias=qkv_bias,
299 |             qk_scale=qk_scale,
300 |             shared_crpe=shared_crpe,
301 |         )
302 |         self.mlp = Mlp(in_features=dim, hidden_features=dim * mlp_ratio)
303 | 
304 |         self.norm2 = norm_layer(dim)
305 | 
306 |     def forward(self, q, k, v, minus=True):
307 |         """forward function"""
308 |         # flatten spatial maps to token sequences: (B, C, H, W) -> (B, H*W, C)
309 |         b, c, h, w = q.size(0), q.size(1), q.size(2), q.size(3)
310 |         q = q.flatten(2).transpose(1, 2)
311 |         k = k.flatten(2).transpose(1, 2)
312 |         v = v.flatten(2).transpose(1, 2)
313 |         x = q + self.factoratt_crpe(q, k, v, minus)
314 |         cur = self.norm2(x)
315 |         x = x + self.mlp(cur)
316 |         x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
317 |         return x
318 | 
319 | 
320 | class UpScale(nn.Module):
321 |     def __init__(self, is_feature_sum, embed_dim, ):
322 |         super(UpScale, self).__init__()
323 |         self.is_feature_sum = is_feature_sum
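        # Shape sketch for this decoder block (with embed_dim = C): the two input
        # maps are fused either by summation (C -> C) or channel concatenation
        # (2C -> C); conv12 then expands to 2C channels and PixelShuffle(2) trades
        # a factor of 4 in channels for a factor of 2 in each spatial dimension,
        # (B, 2C, H, W) -> (B, C/2, 2H, 2W), which conv11_tail finally refines.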
324 |         if is_feature_sum:
325 |             self.conv11_headSum = nn.Conv2d(embed_dim, embed_dim, kernel_size=3,
326 |                                             stride=1, padding=1, bias=True)
327 |         else:
328 |             self.conv11_head = nn.Conv2d(embed_dim * 2, embed_dim, kernel_size=3,
329 |                                          stride=1, padding=1, bias=True)
330 |         self.conv12 = nn.Conv2d(embed_dim, embed_dim * 2, kernel_size=3,
331 |                                 stride=1, padding=1, bias=True)
332 |         self.ps12 = nn.PixelShuffle(2)
333 |         self.conv11_tail = nn.Conv2d(embed_dim // 2, embed_dim // 2, kernel_size=3,
334 |                                      stride=1, padding=1, bias=True)
335 | 
336 |     def forward(self, x, x_res):
337 |         x11 = x  # B, C, H, W
338 |         if self.is_feature_sum:
339 |             x = x + x_res  # B, C, H, W
340 |             x = self.conv11_headSum(x)  # B, C, H, W
341 |         else:
342 |             x = torch.cat([x, x_res], dim=1)  # B, 2*C, H, W
343 |             x = self.conv11_head(x)  # B, C, H, W
344 |         x = x + x11  # B, C, H, W
345 |         x22 = self.conv12(x)  # B, 2*C, H, W
346 |         x = F.relu(self.ps12(x22))  # B, C/2, 2*H, 2*W
347 |         x = self.conv11_tail(x)  # consider whether to apply ReLU here; B, C/2, 2*H, 2*W
348 |         return x
349 | 
350 | 
351 | def conv3x3(in_channels, out_channels, stride=1):
352 |     return nn.Conv2d(in_channels, out_channels, kernel_size=3,
353 |                      stride=stride, padding=1, bias=True)
354 | 
355 | 
356 | class ResBlock(nn.Module):
357 |     def __init__(self, in_channels, out_channels, stride=1, res_scale=1):
358 |         super(ResBlock, self).__init__()
359 |         self.res_scale = res_scale
360 |         self.conv1 = conv3x3(in_channels, out_channels, stride)
361 |         self.relu = nn.ReLU(inplace=True)
362 |         self.conv2 = conv3x3(out_channels, out_channels)
363 | 
364 |     def forward(self, x):
365 |         x1 = x
366 |         out = self.conv1(x)
367 |         out = self.relu(out)
368 |         out = self.conv2(out)
369 |         out = out * self.res_scale + self.conv1(x1)
370 |         return out
371 | 
372 | 
373 | class ATF(nn.Module):
374 |     def __init__(self, opt):
375 |         super(ATF, self).__init__()
376 |         self.num_stage = opt["networks"]["num_stage"]
377 |         self.factor = opt["scale"]
378 |         self.use_aggregate = opt["networks"]["use_aggregate"]
379 |         in_chans = opt["networks"]["in_channels"]
380 |         out_chans = opt["networks"]["out_channels"]
381 |         embed_dims = opt["networks"]["embed_dims"]
382 |         num_paths = opt["networks"]["num_paths"]
383 |         num_heads = opt["networks"]["num_heads"]
384 |         mlp_ratio = opt["networks"]["mlp_ratio"]
385 |         self.num_paths = num_paths
386 | 
387 |         self.stem1 = nn.ModuleList([
388 |             Conv2d_BN(
389 |                 in_chans,
390 |                 embed_dims[0],
391 |                 kernel_size=3,
392 |                 stride=2,
393 |                 pad=1,
394 |                 act_layer=nn.Hardswish,
395 |             ) for _ in range(num_paths[0])
396 |         ])  # B, C, H/2, W/2; processes all channels/paths
397 | 
398 |         self.patch_embed_stages1 = Patch_Embed_stage(
399 |             embed_dims[0],
400 |             num_path=num_paths[0],
401 |             isPool=False
402 |         )  # B, C, H/2, W/2; processes all channels/paths
403 | 
404 |         self.mhca_stage = MHCABlock(
405 |             embed_dims[0],
406 |             num_heads=num_heads,
407 |             mlp_ratio=mlp_ratio,
408 |             qk_scale=None,
409 |             shared_cpe=None,
410 |             shared_crpe=None,
411 |         )
412 |         # B, C, H/2, W/2; processes all channels/paths
413 | 
414 |         self.ir1_attn1 = MHCABlock(embed_dims[0],
415 |                                    num_heads=num_heads,
416 |                                    mlp_ratio=mlp_ratio,
417 |                                    qk_scale=None,
418 |                                    shared_cpe=None,
419 |                                    shared_crpe=None, )
420 | 
421 |         self.ir2_attn1 = MHCABlock(embed_dims[0],
422 |                                    num_heads=num_heads,
423 |                                    mlp_ratio=mlp_ratio,
424 |                                    qk_scale=None,
425 |                                    shared_cpe=None,
426 |                                    shared_crpe=None, )
427 | 
428 |         self.vi1_attn1 = MHCABlock(embed_dims[0],
429 |                                    num_heads=num_heads,
430 |                                    mlp_ratio=mlp_ratio,
431 |                                    qk_scale=None,
432 |                                    shared_cpe=None,
433 |                                    shared_crpe=None, )
434 | 
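        # Note: ir1_attn1 / ir2_attn1 / vi1_attn1 above and vi2_attn1 /
        # ir1_attn1_2 / ir2_attn1_2 / mhca_stage2_2 below share the configuration
        # of mhca_stage, but the forward pass in this file only invokes
        # mhca_stage, mhca_stage1_2 and mhca_stage2; the remaining blocks appear
        # to be kept for experimental variants.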
435 |         self.vi2_attn1 = MHCABlock(embed_dims[0],
436 |                                    num_heads=num_heads,
437 |                                    mlp_ratio=mlp_ratio,
438 |                                    qk_scale=None,
439 |                                    shared_cpe=None,
440 |                                    shared_crpe=None, )
441 | 
442 |         self.mhca_stage1_2 = MHCABlock(
443 |             embed_dims[0],
444 |             num_heads=num_heads,
445 |             mlp_ratio=mlp_ratio,
446 |             qk_scale=None,
447 |             shared_cpe=None,
448 |             shared_crpe=None,
449 |         )  # B, C, H/2, W/2; processes all channels/paths
450 | 
451 |         self.ir1_attn1_2 = MHCABlock(embed_dims[0],
452 |                                      num_heads=num_heads,
453 |                                      mlp_ratio=mlp_ratio,
454 |                                      qk_scale=None,
455 |                                      shared_cpe=None,
456 |                                      shared_crpe=None, )
457 | 
458 |         self.ir2_attn1_2 = MHCABlock(embed_dims[0],
459 |                                      num_heads=num_heads,
460 |                                      mlp_ratio=mlp_ratio,
461 |                                      qk_scale=None,
462 |                                      shared_cpe=None,
463 |                                      shared_crpe=None, )
464 | 
465 |         self.stem2 = nn.ModuleList([
466 |             Conv2d_BN(
467 |                 embed_dims[0],
468 |                 embed_dims[1],
469 |                 kernel_size=3,
470 |                 stride=2,
471 |                 pad=1,
472 |                 act_layer=nn.Hardswish,
473 |             ) for _ in range(num_paths[0] + 1)
474 |         ])  # B, C, H/4, W/4; processes all channels/paths
475 | 
476 |         self.patch_embed_stages2 = Patch_Embed_stage(
477 |             embed_dims[1],
478 |             num_path=num_paths[0] + 1,
479 |             isPool=False
480 |         )  # B, C, H/4, W/4; processes all channels/paths
481 | 
482 |         self.mhca_stage2 = MHCABlock(
483 |             embed_dims[0],
484 |             num_heads=num_heads,
485 |             mlp_ratio=mlp_ratio,
486 |             qk_scale=None,
487 |             shared_cpe=None,
488 |             shared_crpe=None,
489 |         )  # B, C, H/4, W/4; processes all channels/paths
490 | 
491 |         self.mhca_stage2_2 = MHCABlock(
492 |             embed_dims[0],
493 |             num_heads=num_heads,
494 |             mlp_ratio=mlp_ratio,
495 |             qk_scale=None,
496 |             shared_cpe=None,
497 |             shared_crpe=None,
498 |         )  # B, C, H/4, W/4; processes all channels/paths
499 | 
500 |         if self.use_aggregate:
501 |             self.up_scale1_aggregate = UpScale(opt["networks"]["feature_sum"], embed_dim=embed_dims[1] * num_paths[1])
502 |             self.resBlock = ResBlock(in_channels=embed_dims[0] * num_paths[1],
503 |                                      out_channels=embed_dims[0] * num_paths[1])
504 |             # self.resBlock2 = ResBlock(in_channels=embed_dims[0] * num_paths[1],
505 |             #                           out_channels=embed_dims[0] * num_paths[1])  # add for ewc
506 |             self.up_scale2_aggregate = UpScale(opt["networks"]["feature_sum"], embed_dim=embed_dims[0] * num_paths[1])
507 |             # self.head = ResBlock(embed_dims[0] * 2, embed_dims[0])
508 |             self.head = ResBlock(embed_dims[0] // 2, embed_dims[0] // 4)
509 |             self.head_final = ResBlock(embed_dims[0] // 4, out_chans)
510 |         else:
511 |             # upscale part
512 |             self.up_scale1 = nn.ModuleList([
513 |                 UpScale(opt["networks"]["feature_sum"], embed_dims[1]) for _ in range(num_paths[1])
514 |             ])
515 | 
516 |             self.up_scale2 = nn.ModuleList([
517 |                 UpScale(opt["networks"]["feature_sum"], embed_dims[0]) for _ in range(num_paths[1])
518 |             ])
519 | 
520 |             # head part
521 |             self.head_aggregate = True
522 |             if self.head_aggregate:
523 |                 self.head = ResBlock(embed_dims[0] * 2, embed_dims[0])
524 |                 self.head_final = ResBlock(embed_dims[0], 4)
525 |             else:
526 |                 self.head = nn.ModuleList([
527 |                     ResBlock(embed_dims[0] // 2, out_chans) for _ in range(num_paths[1])
528 |                 ])
529 | 
530 |     def inject_fusion(self, ms, pan):
531 |         pass
532 | 
533 |     def forward(self, vi, ir):
534 |         horizontal = 0
535 |         perpendicular = 0
536 |         outPre = vi
537 |         with torch.no_grad():
538 |             a = vi.shape[2] % 4
539 |             b = vi.shape[3] % 4
540 |             left = nn.ReplicationPad2d((1, 0, 0, 0))
541 |             upper = nn.ReplicationPad2d((0, 0, 1, 0))
542 | 
543 |             while a % 4 != 0:
544 |                 vi = upper(vi)
545 |                 ir = upper(ir)
546 |                 horizontal += 1
547 |                 a += 1
548 |             while b % 4 != 0:
549 |                 vi = left(vi)
550 |                 ir = left(ir)
551 |                 perpendicular += 1
552 |                 b += 1
553 | 
554 | 
555 |         inputs = [ir, vi]  # pan(V), pan_up(K), ms(Q)
556 | 
557 |         att_outputs = []
558 |         for x, model in zip(inputs, self.stem1):
559 |             att_outputs.append(model(x))
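        # Data-flow sketch for the aggregate path (shapes assuming the provided
        # configs, i.e. in_channels=1 and embed_dims[0] = C = 64):
        #     stem1:          (B, 1, H, W) -> (B, C, H/2, W/2)    per modality
        #     patch_embed:    depthwise 3x3 refinement, shape unchanged
        #     mhca_stage:     vi queries ir with minus=False  -> V - attention
        #     mhca_stage1_2:  result queries vi with minus=True -> attention
        #     mhca_stage2:    result queries ir with minus=True -> attention
        #     UpScale + heads: back to (B, out_channels, H, W)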
560 | 
561 |         att_outputs = self.patch_embed_stages1(att_outputs)
562 | 
563 |         ir, vi = att_outputs[0], att_outputs[1]
564 | 
565 |         att_outputs2 = []  # these results are reused later
566 | 
567 |         att_outputs2.append(self.mhca_stage(att_outputs[1], ir, ir, minus=False))
568 | 
569 |         att_outputs2_1 = []  # these results are reused later
570 | 
571 |         att_outputs2_1.append(self.mhca_stage1_2(att_outputs2[0], vi, vi, minus=True))
572 | 
573 |         att_outputs3 = []  # these results are reused later
574 | 
575 |         att_outputs3.append(self.mhca_stage2(att_outputs2_1[0], ir, ir, minus=True))
576 | 
577 |         outv1 = ir
578 |         outk1 = vi
579 | 
580 |         if self.use_aggregate:
581 |             x11 = att_outputs3[0]  # (4,512,64,64)
582 |             x11_skip = att_outputs2[0]  # (4,512,64,64) (4,128,64,64)
583 | 
584 |             x22 = self.up_scale1_aggregate(x11, x11_skip)  # (4,256,128,128) (4,64,128,128)
585 | 
586 |             x = self.head(x22)
587 | 
588 |             x = self.head_final(x)
589 | 
590 |             if horizontal != 0:
591 |                 x = x[:, :, horizontal:, :]
592 |             if perpendicular != 0:
593 |                 x = x[:, :, :, perpendicular:]
594 | 
595 |             x = x.clamp(min=0, max=255)
596 |             output = {"pred": x,
597 |                       "k": outk1,
598 |                       "v": outv1,
599 |                       "outPre": outPre
600 |                       }
601 |         else:
602 |             x = []
603 | 
604 |             att_outputs3 = []
605 |             for i in range(self.num_paths[1]):
606 |                 att_outputs3.append(self.up_scale2[i](x[i], att_outputs2[i]))
607 | 
608 |             if self.head_aggregate:
609 |                 att_outputs3 = torch.cat(att_outputs3, dim=1)
610 |                 x = self.head(att_outputs3)
611 |                 x = self.head_final(x)
612 |                 output = {"pred": x, }
613 |             else:
614 |                 x = []
615 |                 for i in range(self.num_paths[1]):
616 |                     x.append(self.head[i](att_outputs3[i]))
617 |                 output = {"pred": torch.cat(x, dim=1)}
618 | 
619 |         return output
620 | 
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | import functools
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | 
7 | 
8 | ####################
9 | # initialize
10 | ####################
11 | 
12 | def weights_init_normal(m, std=0.02):
13 |     classname = m.__class__.__name__
14 |     if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
15 |         if classname != "MeanShift":
16 |             print('initializing [%s] ...' % classname)
17 |             init.normal_(m.weight.data, 0.0, std)
18 |             if m.bias is not None:
19 |                 m.bias.data.zero_()
20 |     elif isinstance(m, (nn.Linear)):
21 |         init.normal_(m.weight.data, 0.0, std)
22 |         if m.bias is not None:
23 |             m.bias.data.zero_()
24 |     elif isinstance(m, (nn.BatchNorm2d)):
25 |         init.normal_(m.weight.data, 1.0, std)
26 |         init.constant_(m.bias.data, 0.0)
27 | 
28 | 
29 | def weights_init_kaiming(m, scale=1):
30 |     classname = m.__class__.__name__
31 |     if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
32 |         if classname != "MeanShift":
33 |             print('initializing [%s] ...' % classname)
34 |             init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
35 |             m.weight.data *= scale
36 |             if m.bias is not None:
37 |                 m.bias.data.zero_()
38 |     elif isinstance(m, (nn.Linear)):
39 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
40 |         m.weight.data *= scale
41 |         if m.bias is not None:
42 |             m.bias.data.zero_()
43 |     elif isinstance(m, (nn.BatchNorm2d)):
44 |         init.constant_(m.weight.data, 1.0)
45 |         m.weight.data *= scale
46 |         init.constant_(m.bias.data, 0.0)
47 | 
48 | 
49 | def weights_init_orthogonal(m):
50 |     classname = m.__class__.__name__
51 |     if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
52 |         if classname != "MeanShift":
53 |             print('initializing [%s] ...'
% classname) 54 | init.orthogonal_(m.weight.data, gain=1) 55 | if m.bias is not None: 56 | m.bias.data.zero_() 57 | elif isinstance(m, (nn.Linear)): 58 | init.orthogonal_(m.weight.data, gain=1) 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | elif isinstance(m, (nn.BatchNorm2d)): 62 | init.normal_(m.weight.data, 1.0, 0.02) 63 | init.constant_(m.bias.data, 0.0) 64 | 65 | 66 | def init_weights(net, init_type='kaiming', scale=1, std=0.02): 67 | # scale for 'kaiming', std for 'normal'. 68 | print('initialization method [%s]' % init_type) 69 | if init_type == 'normal': 70 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 71 | net.apply(weights_init_normal_) 72 | elif init_type == 'kaiming': 73 | weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale) 74 | net.apply(weights_init_kaiming_) 75 | elif init_type == 'orthogonal': 76 | net.apply(weights_init_orthogonal) 77 | else: 78 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 79 | 80 | 81 | #################### 82 | # define network 83 | #################### 84 | 85 | def create_model(opt): 86 | if opt['mode'] == 'sr': 87 | net = define_net(opt) 88 | return net 89 | elif opt['mode'] == 'fu': 90 | net = define_net(opt) 91 | return net 92 | else: 93 | raise NotImplementedError("The mode [%s] of networks is not recognized." % opt['mode']) 94 | 95 | # def create_model(opt): 96 | # if opt['mode'] == 'fu': 97 | # net = define_net(opt) 98 | # return net 99 | # else: 100 | # raise NotImplementedError("The mode [%s] of networks is not recognized." % opt['mode']) 101 | 102 | 103 | # choose one network 104 | def define_net(opt): 105 | which_model = opt["networks"]['which_model'].upper() 106 | print('===> Building network [%s]...' % which_model) 107 | 108 | if which_model.find("ATFUSE") >= 0: 109 | from .ATFuse import ATF 110 | net = ATF(opt) 111 | else: 112 | raise NotImplementedError("Network [%s] is not recognized." 
% which_model) 113 | 114 | if torch.cuda.is_available(): 115 | net = nn.DataParallel(net).cuda() 116 | 117 | return net 118 | -------------------------------------------------------------------------------- /networks/__pycache__/ATFuse.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/networks/__pycache__/ATFuse.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/networks/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/networks/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/networks/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /networks/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from torchvision.models.vgg import vgg16 7 | import numpy as np 8 | import torchvision.transforms.functional as TF 9 | 10 | 11 | 12 | 13 | class L_color(nn.Module): 14 | 15 | def __init__(self): 16 | super(L_color, self).__init__() 17 | 18 | def forward(self, x ): 19 | 20 | b,c,h,w = x.shape 21 | 22 | mean_rgb = torch.mean(x,[2,3],keepdim=True) 23 | mr,mg, mb = torch.split(mean_rgb, 1, dim=1) 24 | Drg = torch.pow(mr-mg,2) 25 | Drb = torch.pow(mr-mb,2) 26 | Dgb = torch.pow(mb-mg,2) 27 | k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5) 28 | return k 29 | 30 | 31 | class L_Grad(nn.Module): 32 | def __init__(self): 33 | super(L_Grad, self).__init__() 34 | self.sobelconv=Sobelxy() 35 | 36 | def forward(self, image_A, image_B, image_fused, thresholds): 37 | image_A_Y = image_A[:, :1, :, :] 38 | image_B_Y = image_B[:, :1, :, :] 39 | image_fused_Y = image_fused[:, :1, :, :] 40 | gradient_A = self.sobelconv(image_A_Y) 41 | # gradient_A = TF.gaussian_blur(gradient_A, 3, [1, 1]) 42 | gradient_B = self.sobelconv(image_B_Y) 43 | # gradient_B = TF.gaussian_blur(gradient_B, 3, [1, 1]) 44 | gradient_fused = self.sobelconv(image_fused_Y) 45 | # gradient_joint = torch.max(gradient_A, gradient_B) 46 | grant_joint = torch.concat([gradient_A, gradient_B], dim=1) 47 | grant_joint_max, index = grant_joint.max(dim=1) 48 | 49 | a, b, c, d = gradient_A.size(0), gradient_A.size(1), gradient_A.size(2), gradient_A.size(3) 50 | grant_joint_max = grant_joint_max.reshape(a, b, c, d) 51 | 52 | gradient_A_Mask = threshold_tensor(gradient_A, dim=2, k=thresholds) 53 | aaa = gradient_A_Mask.argmax(dim=1).shape 54 | gradient_B_Mask = threshold_tensor(gradient_B, dim=2, k=thresholds) 55 | bbb = gradient_B_Mask.argmax(dim=1).shape 56 | 57 | 58 | 59 
|         Loss_gradient = F.l1_loss(gradient_fused, grant_joint_max)
60 |         return Loss_gradient, gradient_A_Mask, gradient_B_Mask
61 | 
62 | 
63 | 
64 | def gradWeightBlockIntenLoss(image_A_Y, image_B_Y, image_fused_Y, gradient_A, gradient_B, L_Inten_loss, percent, mask_pre=None):
65 |     """
66 |     percent: percentage; selects the pixels above the given percentile
67 |     L_Inten_loss: function that computes the intensity (pixel) loss
68 |     gradient_A: gradient of image A
69 |     mask_pre: mask from the previous call; e.g. if the first call takes the top 20% and the second the top 60%, subtracting gives the 40% in between
70 |     """
71 |     thresholds = torch.round(torch.tensor(percent * image_A_Y.shape[2] * image_A_Y.shape[3])).int()
72 |     clone_grand_A = gradient_A.clone().detach()
73 |     gradient_A_Mask = threshold_tensor(clone_grand_A, dim=2, k=thresholds)
74 | 
75 | 
76 |     clone_grand_B = gradient_B.clone().detach()
77 |     gradient_B_Mask = threshold_tensor(clone_grand_B, dim=2, k=thresholds)
78 | 
79 |     if mask_pre is None:
80 |         grand_Mask = gradient_A_Mask + gradient_B_Mask
81 |         grand_Mask = grand_Mask.clamp(min=0, max=1)
82 | 
83 |     else:
84 |         grand_Mask = gradient_A_Mask + gradient_B_Mask
85 |         grand_Mask = grand_Mask.clamp(min=0, max=1)
86 | 
87 |         grand_Mask -= mask_pre
88 |     grand_IntenLoss = L_Inten_loss(image_A_Y * grand_Mask, image_B_Y * grand_Mask, image_fused_Y * grand_Mask)
89 |     return grand_IntenLoss, grand_Mask
90 | 
91 | 
92 | def testNum(grand_Mask):
93 |     grand_Mask_1Wei = torch.flatten(grand_Mask)
94 |     num = 0
95 |     for i in range(grand_Mask_1Wei.shape[0]):
96 |         if grand_Mask_1Wei[i] == 1:
97 |             num += 1
98 |     return num
99 | 
100 | class L_Grad_Inte(nn.Module):
101 |     """
102 |     Computes the intensity loss block-wise by gradient rank, together with the gradient loss.
103 |     """
104 |     def __init__(self):
105 |         super(L_Grad_Inte, self).__init__()
106 |         self.sobelconv = Sobelxy()
107 |         self.L_Inten_aver = L_IntensityAver()
108 |         self.L_Inten_Max = L_Intensity()
109 |         self.L_Inten_Once = L_IntensityOnce()
110 |     def forward(self, image_A, image_B, image_fused):
111 |         image_A_Y = image_A[:, :1, :, :]
112 |         image_B_Y = image_B[:, :1, :, :]
113 |         image_fused_Y = image_fused[:, :1, :, :]
114 |         gradient_A = self.sobelconv(image_A_Y)
115 |         gradient_B = self.sobelconv(image_B_Y)
116 |         gradient_fused = self.sobelconv(image_fused_Y)
117 |         grant_joint = torch.concat([gradient_A, gradient_B], dim=1)
118 |         grant_joint_max, index = grant_joint.max(dim=1)
119 | 
120 |         a, b, c, d = gradient_A.size(0), gradient_A.size(1), gradient_A.size(2), gradient_A.size(3)
121 |         grant_joint_max = grant_joint_max.reshape(a, b, c, d)
122 | 
123 |         # weight the gradients by pixel intensity so the image can be split into ranked blocks for the intensity loss
124 |         gradient_A_Att = image_A_Y * gradient_A
125 |         gradient_B_Att = image_B_Y * gradient_B
126 | 
127 | 
128 |         # pixels in the top 20% by gradient use the max intensity loss
129 |         grand_IntenLoss_one, grand_Mask_one = gradWeightBlockIntenLoss(image_A_Y, image_B_Y, image_fused_Y, gradient_A_Att, gradient_B_Att, self.L_Inten_Max, 0.8, mask_pre=None)
130 |         # # pixels in the 20%-70% band use the average intensity loss
131 |         # grand_IntenLoss_two, grand_Mask_two = gradWeightBlockIntenLoss(image_A_Y, image_B_Y, image_fused_Y, gradient_A_Att, gradient_B_Att, self.L_Inten_aver, 0.3, mask_pre = grand_Mask_one)
132 |         # the remaining 30% of pixels use the averaged intensity loss
133 |         grand_Mask_three = 1 - grand_Mask_one
134 |         grand_IntenLoss_three = self.L_Inten_aver(image_A_Y * grand_Mask_three, image_B_Y * grand_Mask_three, image_fused_Y * grand_Mask_three)
135 | 
136 |         grand_IntenLoss = grand_IntenLoss_one + grand_IntenLoss_three
137 | 
138 |         Loss_gradient = F.l1_loss(gradient_fused, grant_joint_max)
139 |         return Loss_gradient, grand_IntenLoss
140 | 
141 | 
142 | 
143 | 
144 | 
145 | 
146 | 
147 | 
148 | def threshold_tensor(input_tensor, dim, k):
149 |     """
150 |     Along dimension `dim`, take the k-th smallest element of the input tensor
151 |     (via torch.kthvalue) as a threshold; elements >= the threshold become 1, the rest 0.
152 |     Args:
153 |     - input_tensor: the input tensor
154 |     - dim: the dimension along which the k-th value is taken
155 |     - k: the order statistic to take
156 | 
157 |     Returns:
158 |     - a tensor with the same shape as the input
159 |     """
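    # Worked micro-example (illustrative only): for a flattened row
    #     [3., 1., 4., 1., 5.]  with k = 4,
    # torch.kthvalue returns the 4th-smallest entry (4.), so thresholding at >= 4.
    # yields [0., 0., 1., 0., 1.]; roughly the strongest numel - k + 1 activations
    # survive, which is how the callers above keep the highest-gradient pixels by
    # passing k = round(0.8 * H * W).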
160 |     # kth_value, _ = torch.kthvalue(input_tensor, k, dim=dim, keepdim=True)  # take the k-th order statistic along dim
161 |     B, N, C, D = input_tensor.shape
162 |     input_tensor = input_tensor.reshape(B, N, C * D)
163 |     for i in range(B):
164 |         kth_value, _ = torch.kthvalue(input_tensor[i:i+1, :, :], k, dim=dim, keepdim=True)
165 |         kth_value = torch.flatten(kth_value)
166 |         input_tensor[i:i+1, :, :] = torch.where(input_tensor[i:i+1, :, :] >= kth_value[0], torch.tensor(1.0).cuda(), torch.tensor(0.0).cuda())
167 |     input_tensor = input_tensor.reshape(B, N, C, D)
168 |     return input_tensor
169 | 
170 | 
171 | class Sobelxy(nn.Module):
172 |     def __init__(self):
173 |         super(Sobelxy, self).__init__()
174 |         kernelx = [[-1, 0, 1],
175 |                    [-2, 0, 2],
176 |                    [-1, 0, 1]]
177 |         kernely = [[1, 2, 1],
178 |                    [0, 0, 0],
179 |                    [-1, -2, -1]]
180 |         kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
181 |         kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
182 |         self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
183 |         self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
184 |     def forward(self, x):
185 |         sobelx = F.conv2d(x, self.weightx, padding=1)
186 |         sobely = F.conv2d(x, self.weighty, padding=1)
187 |         return torch.abs(sobelx) + torch.abs(sobely)
188 | 
189 | class L_Intensity(nn.Module):
190 |     def __init__(self):
191 |         super(L_Intensity, self).__init__()
192 | 
193 |     def forward(self, image_A, image_B, image_fused):
194 |         intensity_joint = torch.max(image_A, image_B)
195 |         Loss_intensity = F.l1_loss(intensity_joint, image_fused)
196 |         return Loss_intensity
197 | 
198 | 
199 | class L_IntensityAver(nn.Module):
200 |     def __init__(self):
201 |         super(L_IntensityAver, self).__init__()
202 | 
203 |     def forward(self, image_A, image_B, image_fused):
204 |         Loss_intensity_A = F.l1_loss(image_A, image_fused)
205 |         Loss_intensity_B = F.l1_loss(image_B, image_fused)
206 |         Loss_intensity = 0.5 * Loss_intensity_A + 0.5 * Loss_intensity_B
207 |         return Loss_intensity
208 | 
209 | 
210 | class L_IntensityOnce(nn.Module):
211 |     def __init__(self):
212 |         super(L_IntensityOnce, self).__init__()
213 | 
214 |     def forward(self, image_A, image_fused):
215 | 
216 |         Loss_intensity = F.l1_loss(image_A, image_fused)
217 |         return Loss_intensity
218 | 
219 | 
220 | class L_Intensity_GrandFu(nn.Module):
221 |     def __init__(self):
222 |         super(L_Intensity_GrandFu, self).__init__()
223 | 
224 |     def forward(self, image_A, image_B, image_fused, gradient_A_Mask, gradient_B_Mask):
225 | 
226 |         Fu_image_maskA_A = image_A * gradient_A_Mask
227 |         Loss_intensity_maskA = F.l1_loss(image_fused * gradient_A_Mask, Fu_image_maskA_A)
228 | 
229 | 
230 |         Fu_image_maskB_B = image_B * gradient_B_Mask
231 |         Loss_intensity_maskB = F.l1_loss(image_fused * gradient_B_Mask, Fu_image_maskB_B)
232 | 
233 |         return Loss_intensity_maskA + Loss_intensity_maskB
234 | 
235 | 
236 | 
237 | class fusion_loss_med(nn.Module):
238 |     def __init__(self):
239 |         super(fusion_loss_med, self).__init__()
240 |         self.L_GradInte = L_Grad_Inte()
241 | 
242 |         # print(1)
243 |     def forward(self, image_fused, image_A, image_B):
244 | 
245 |         image_fused = image_fused["pred"]
246 | 
247 |         loss_gradient, grand_IntenLoss = self.L_GradInte(image_A, image_B, image_fused)
248 | 
249 |         fusion_loss = loss_gradient * 20 + grand_IntenLoss * 20
250 | 
251 |         return fusion_loss
252 | 
--------------------------------------------------------------------------------
/networks/util.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def resize(input,
8 |            size=None,
9 |            scale_factor=None,
10 |            mode='nearest',
11 |            align_corners=None,
12 |            warning=True):
13 |     if warning:
14 |         if size is not None and align_corners:
15 |             input_h, input_w = tuple(int(x) for x in input.shape[2:])
16 |             output_h, output_w = tuple(int(x) for x in size)
17 |             if output_h > input_h or output_w > input_w:
18 |                 if ((output_h > 1 and output_w > 1 and input_h > 1
19 |                      and input_w > 1) and (output_h - 1) % (input_h - 1)
20 |                         and (output_w - 1) % (input_w - 1)):
21 |                     warnings.warn(
22 |                         f'When align_corners={align_corners}, '
23 |                         'the output would be more aligned if '
24 |                         f'input size {(input_h, input_w)} is `x+1` and '
25 |                         f'out size {(output_h, output_w)} is `nx+1`')
26 |     if isinstance(size, torch.Size):
27 |         size = tuple(int(x) for x in size)
28 |     return F.interpolate(input, size, scale_factor, mode, align_corners)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     B, C, W, H = 2, 3, 1024, 1024
33 |     x = torch.randn(B, C, H, W)
34 | 
35 |     kernel_size = 128
36 |     stride = 64
37 |     patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
38 |     print(patches.shape)  # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]
39 | 
40 |     # perform the operations on each patch
41 |     # ...
42 | 
43 |     # reshape output to match F.fold input
44 |     patches = patches.contiguous().view(B, C, -1, kernel_size * kernel_size)
45 |     print(patches.shape)  # [B, C, nb_patches_all, kernel_size*kernel_size]
46 |     patches = patches.permute(0, 1, 3, 2)
47 |     print(patches.shape)  # [B, C, kernel_size*kernel_size, nb_patches_all]
48 |     patches = patches.contiguous().view(B, C * kernel_size * kernel_size, -1)
49 |     print(patches.shape)  # [B, C*prod(kernel_size), L] as expected by Fold
50 |     # https://pytorch.org/docs/stable/nn.html#torch.nn.Fold
51 | 
52 |     output = F.fold(
53 |         patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
54 |     print(output.shape)  # [B, C, H, W]
55 | 
--------------------------------------------------------------------------------
/options/__pycache__/options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/options/__pycache__/options.cpython-39.pyc
--------------------------------------------------------------------------------
/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from datetime import datetime
4 | import json
5 | 
6 | import torch
7 | 
8 | from utils import util
9 | 
10 | 
11 | def get_timestamp():
12 |     return datetime.now().strftime('%y%m%d-%H%M%S')
13 | 
14 | 
15 | def parse(opt_path):
16 |     # remove comments starting with '//'
17 |     json_str = ''
18 |     with open(opt_path, 'r', encoding='utf-8') as f:
19 |         for line in f:
20 |             line = line.split('//')[0] + '\n'
21 |             json_str += line
22 |     opt = json.loads(json_str, object_pairs_hook=OrderedDict)
23 | 
24 |     opt['timestamp'] = get_timestamp()
25 |     scale = opt['scale']
26 |     rgb_range = opt['rgb_range']
27 | 
28 |     # export CUDA_VISIBLE_DEVICES
29 |     if torch.cuda.is_available():
30 |         gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
31 |         os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
32 |         print('===> Export CUDA_VISIBLE_DEVICES = [' + gpu_list + ']')
33 |     else:
34 |         print('===> CPU mode is set (NOTE: GPU is recommended)')
35 | 
36 |     # datasets
37 |     for phase, dataset in opt['datasets'].items():
38 |         phase = phase.split('_')[0]
39 |         dataset['phase'] = phase
40 |         dataset['scale'] = scale
41 |         dataset['rgb_range'] = rgb_range
42 | 
43 |     # for network initialize
44 |     opt['networks']['scale'] = opt['scale']
45 |     network_opt = opt['networks']
46 | 
47 |     config_str = f'{network_opt["which_model"].upper()}_in{network_opt["in_channels"]}_x{opt["scale"]}'
48 |     exp_path = os.path.join(os.getcwd(), 'experiments', config_str)
49 | 
50 |     if opt['is_train'] and opt['solver']['pretrain']:
51 |         if 'pretrained_path' not in list(opt['solver'].keys()): raise ValueError(
52 |             "[Error] The 'pretrained_path' is not declared in *.json")
53 |         exp_path = os.path.dirname(os.path.dirname(opt['solver']['pretrained_path']))
54 |         if opt['solver']['pretrain'] == 'finetune': exp_path += '_finetune'
55 | 
56 |     exp_path = os.path.relpath(exp_path)
57 | 
58 |     path_opt = OrderedDict()
59 |     path_opt['exp_root'] = exp_path
60 |     path_opt['epochs'] = os.path.join(exp_path, 'epochs')
61 |     path_opt['visual'] = os.path.join(exp_path, 'visual')
62 |     path_opt['records'] = os.path.join(exp_path, 'records')
63 |     opt['path'] = path_opt
64 | 
65 |     if opt['is_train']:
66 |         # create folders
67 |         if opt['solver']['pretrain'] == 'resume':
68 |             opt = dict_to_nonedict(opt)
69 |         else:
70 |             util.mkdir_and_rename(opt['path']['exp_root'])  # rename old experiments if they exist
71 |             util.mkdirs((path for key, path in opt['path'].items() if not key == 'exp_root'))
72 |             save(opt)
73 |             opt = dict_to_nonedict(opt)
74 | 
75 |     print("===> Experimental DIR: [%s]" % exp_path)
76 | 
77 |     return opt
78 | 
79 | 
80 | def save(opt):
81 |     dump_dir = opt['path']['exp_root']
82 |     dump_path = os.path.join(dump_dir, 'options.json')
83 |     with open(dump_path, 'w') as dump_file:
84 |         json.dump(opt, dump_file, indent=2)
85 | 
86 | 
87 | class NoneDict(dict):
88 |     def __missing__(self, key):
89 |         return None
90 | 
91 | 
92 | # convert to NoneDict, which returns None for missing keys.
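# Illustrative behaviour of the helper below:
#     d = dict_to_nonedict({'solver': {'lr': 1e-4}})
#     d['solver']['lr']      -> 0.0001
#     d['missing_key']       -> None   (instead of raising KeyError)
#     d['solver']['missing'] -> None   (nesting is converted recursively)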
93 | def dict_to_nonedict(opt): 94 | if isinstance(opt, dict): 95 | new_opt = dict() 96 | for key, sub_opt in opt.items(): 97 | new_opt[key] = dict_to_nonedict(sub_opt) 98 | return NoneDict(**new_opt) 99 | elif isinstance(opt, list): 100 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 101 | else: 102 | return opt 103 | -------------------------------------------------------------------------------- /options/test/test_ATFuse.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "fu", 3 | "use_cl": false, 4 | "gpu_ids": [0], 5 | 6 | "scale": 4, 7 | "degradation": "BI", 8 | "is_train": false, 9 | "use_chop": false, 10 | "rgb_range": 2047, 11 | "self_ensemble": false, 12 | 13 | 14 | 15 | "datasets": { 16 | "test_set1": { 17 | "mode": "IrVi", 18 | "dataroot_Vi": "./dataset/train/vi", 19 | "dataroot_Ir": "./dataset/train/ir", 20 | "data_type": "png", 21 | "useContinueLearning": false, 22 | "shift_pace": 36 23 | 24 | } 25 | }, 26 | 27 | "networks": { 28 | "which_model": "ATFuse", 29 | "in_channels": 1, 30 | "out_channels": 1, 31 | "img_size": 4, 32 | "num_heads": 4, 33 | "n_feats": 256, 34 | "linear_dim": 256, 35 | 36 | "num_stage": 4, 37 | "embed_dims": [64, 64, 256, 512], 38 | "num_paths": [4, 1, 3, 3], 39 | "mlp_ratio": 3, 40 | "use_aggregate": true, 41 | "feature_sum": true 42 | }, 43 | 44 | "solver": { 45 | "q": "vi", 46 | "pretrained_path": "./model/best_ckp.pth" 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /options/train/train_ATFuse.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "fu", 3 | "gpu_ids": [0], 4 | "scale": 4, 5 | "is_train": true, 6 | "rgb_range": 2047, 7 | "save_image": true, 8 | 9 | "datasets": { 10 | "train": { 11 | "mode": "IrVi", 12 | "dataroot_Vi": "./dataset/train/vi", 13 | "dataroot_Ir": "./dataset/train/ir", 14 | "data_type": "png", 15 | "n_workers": 4, 16 | "batch_size": 16, 17 | "LR_size": 64, 18 | "use_flip": true, 19 | "use_rot": true, 20 | "noise": ".", 21 | "useContinueLearning": false, 22 | "shift_pace": 16 23 | }, 24 | "val": { 25 | "mode": "IrVi", 26 | "dataroot_Vi": "./dataset/val/vi", 27 | "dataroot_Ir": "./dataset/val/ir", 28 | "data_type": "png", 29 | "useContinueLearning": false, 30 | "shift_pace": 16 31 | } 32 | }, 33 | 34 | "networks": { 35 | "which_model": "ATFuse", 36 | "in_channels": 1, 37 | "out_channels": 1, 38 | "img_size": 4, 39 | "num_heads": 4, 40 | "n_feats": 256, 41 | "linear_dim": 256, 42 | 43 | "num_stage": 4, 44 | "embed_dims": [64, 64, 256, 512], 45 | "num_paths": [4, 1, 3, 3], 46 | "mlp_ratio": 3, 47 | "use_aggregate": true, 48 | "feature_sum": true 49 | }, 50 | 51 | "solver": { 52 | "type": "ADAMW", 53 | "learning_rate": 0.0001, 54 | "weight_decay": 0, 55 | "lr_scheme": "MultiStepLR", 56 | "lr_steps": [50,100,200,400], 57 | "lr_gamma": 0.5, 58 | "loss_type": "loss", 59 | "q": "vi", 60 | "spatial_loss": true, 61 | "PerceptualLoss": true, 62 | "manual_seed": 0, 63 | "num_epochs": 2000, 64 | "skip_threshold": 4, 65 | "split_batch": 1, 66 | "save_ckp_step": 50, 67 | "save_vis_step": 50, 68 | "pretrain": null, 69 | "pretrained_path": "./experiments/HYP_in4_x4/epochs/last_ckp.pth", 70 | "cl_weights": [0.1, 0.1, 0.4, 0.4] 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /solvers/FuSolver.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 
3 | 
4 | import cv2
5 | import numpy as np
6 | import pandas as pd
7 | import scipy
8 | import spectral as spy
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torchvision.utils as thutil
14 | 
15 | 
16 | # from models.vgg import vgg16
17 | from networks import create_model
18 | 
19 | from .base_solver import BaseSolver
20 | from networks import init_weights
21 | import torch.nn.functional as F
22 | from utils import util
23 | 
24 | 
25 | 
26 | class FuSolver(BaseSolver):
27 |     def __init__(self, opt):
28 |         super(FuSolver, self).__init__(opt)
29 |         self.train_opt = opt['solver']
30 |         self.Vi = self.Tensor()
31 | 
32 |         # self.HR = self.Tensor()
33 |         self.Ir = self.Tensor()
34 | 
35 |         self.fused_img_cr = self.Tensor()
36 |         self.fused_img_cb = self.Tensor()
37 |         # self.PAN_unalign = self.Tensor()
38 |         self.lossWeight = None
39 |         self.Fu = None
40 | 
41 |         self.records = {'train_loss': [],
42 |                         'val_loss': [],
43 |                         'psnr': [],
44 |                         'ssim': [],
45 |                         'lr': []}
46 | 
47 |         self.model = create_model(opt)
48 |         self.print_network()
49 | 
50 |         if self.is_train:
51 |             self.model.train()
52 |             # set loss
53 |             self._set_loss()
54 |             # set optimizer
55 |             self._set_optimizer()
56 |             # set lr_scheduler
57 |             self._set_scheduler()
58 | 
59 |         self._load()
60 | 
61 |         print(f'===> Solver Initialized : [{self.__class__.__name__}] || Use GPU : [{self.use_gpu}]')
62 | 
63 |     def feed_data(self, batch):
64 |         input = batch['Vi']
65 |         input_ = batch['Ir']
66 |         # a = input.shape[2] % 4
67 |         # b = input.shape[3] % 4
68 |         # left = nn.ReplicationPad2d((1, 0, 0, 0))
69 |         # upper = nn.ReplicationPad2d((0, 0, 1, 0))
70 |         # double = nn.ReplicationPad2d((0, 1, 1, 0))
71 |         # while a % 4 != 0 and b % 4 != 0:
72 |         #     input_ = double(input_)
73 |         #     input = double(input)
74 |         #     a += 1
75 |         #     b += 1
76 |         # while a % 4 != 0:
77 |         #     input_ = upper(input_)
78 |         #     input = upper(input)
79 |         #     a += 1
80 |         # while b % 4 != 0:
81 |         #     input_ = left(input_)
82 |         #     input = left(input)
83 |         #     b += 1
84 |         # input_pan_unalign = batch["pan_unalign"]
85 |         self.Vi.resize_(input.size()).copy_(input)  # copy the batch onto the pre-allocated device tensor
86 |         self.Ir.resize_(input_.size()).copy_(input_)
87 |         # self.fused_img_cr = batch['fused_img_cr']
88 |         # self.fused_img_cb = batch['fused_img_cb']
89 |         # self.PAN_unalign.resize_(input_pan_unalign.size()).copy_(input_pan_unalign)
90 | 
91 |     def train_step(self):
92 |         # self.weight_list = self.generatLossWeight()
93 | 
94 |         self.model.train()
95 |         self.optimizer.zero_grad()
96 | 
97 |         loss_batch = 0.0
98 |         sub_batch_size = int(self.Vi.size(0) / self.split_batch)
99 |         for i in range(self.split_batch):
100 |             loss_sbatch = 0.0
101 |             split_Vi = self.Vi.narrow(0, i * sub_batch_size, sub_batch_size)  # slice the batch; identical to the full batch when split_batch == 1
102 |             # split_HR = self.HR.narrow(0, i * sub_batch_size, sub_batch_size)
103 |             split_Ir = self.Ir.narrow(0, i * sub_batch_size, sub_batch_size)
104 |             # split_PAN_unalign = self.PAN_unalign.narrow(0, i * sub_batch_size, sub_batch_size)
105 | 
106 |             q = self.train_opt['q']
107 |             if q == 'vi':
108 |                 output = self.model(split_Vi, split_Ir)
109 |             elif q == 'ir':
110 |                 output = self.model(split_Ir, split_Vi)
111 |             else:
112 |                 output = self.model(split_Ir, split_Vi)
113 | 
114 |             loss_sbatch = self.criterion_pix(output, split_Vi, split_Ir)  # output, split_HR, split_LR, split_PAN
115 | 
116 |             loss_sbatch /= self.split_batch
117 |             loss_sbatch.backward()
118 | 
119 |             loss_batch += (loss_sbatch.item())
120 | 
121 |         # for stable training
122 |         if loss_batch < self.skip_threshold * self.last_epoch_loss:
123 |             self.optimizer.step()
124 |             self.last_epoch_loss = loss_batch
125 |         else:
126 |             print(f'[Warning] Skip this batch! (Loss: {loss_batch})')
127 | 
128 |         self.model.eval()
129 |         return loss_batch
130 | 
131 |     # def generatLossWeight(self):
132 |     #     self.feature_model = vgg16().cuda()
133 |     #     self.feature_model.load_state_dict(torch.load('./models/vgg16.pth'))
134 |     #     c = 3500
135 |     #     vi_a = (self.Vi + 1) / 2
136 |     #     vi_a = vi_a.cuda()
137 |     #     Ir_a = (self.Ir + 1) / 2
138 |     #     Ir_a = Ir_a.cuda()
139 |     #     with torch.no_grad():
140 |     #         feat_1 = torch.cat((vi_a, vi_a, vi_a), dim=1)
141 |     #         feat_1 = self.feature_model(feat_1)
142 |     #         feat_2 = torch.cat((Ir_a, Ir_a, Ir_a), dim=1)
143 |     #         feat_2 = self.feature_model(feat_2)
144 |     #
145 |     #     for i in range(len(feat_1)):
146 |     #         m1 = torch.mean(features_grad(feat_1[i]).pow(2), dim=[1, 2, 3])
147 |     #         m2 = torch.mean(features_grad(feat_2[i]).pow(2), dim=[1, 2, 3])
148 |     #         if i == 0:
149 |     #             w1 = torch.unsqueeze(m1, dim=-1)
150 |     #             w2 = torch.unsqueeze(m2, dim=-1)
151 |     #         else:
152 |     #             w1 = torch.cat((w1, torch.unsqueeze(m1, dim=-1)), dim=-1)
153 |     #             w2 = torch.cat((w2, torch.unsqueeze(m2, dim=-1)), dim=-1)
154 |     #         weight_1 = torch.mean(w1, dim=-1) / c
155 |     #         weight_2 = torch.mean(w2, dim=-1) / c
156 |     #         weight_list = torch.cat((weight_1.unsqueeze(-1), weight_2.unsqueeze(-1)), -1)
157 |     #         weight_list = F.softmax(weight_list, dim=-1)
158 |     #         return weight_list
159 |     def get_LossWeight(self, output):
160 |         min_v = np.min(output["weight"])
161 |         max_v = np.max(output["weight"])
162 |         img_v = (output["weight"] - min_v) / (max_v - min_v)
163 |         return img_v
164 | 
165 |     def test(self):
166 |         self.model.eval()
167 |         with torch.no_grad():
168 |             q = self.train_opt['q']
169 |             if q == 'vi':
170 |                 self.Fu = self.model(self.Vi, self.Ir)
171 |             elif q == 'ir':
172 |                 self.Fu = self.model(self.Ir, self.Vi)
173 |             else:
174 |                 self.Fu = self.model(self.Ir, self.Vi)
175 | 
176 |         self.model.train()
177 |         if self.is_train:
178 |             loss_pix = self.criterion_pix(self.Fu, self.Vi, self.Ir)
179 |             return loss_pix.item()
180 | 
181 |     def visualization(self):
182 |         return
183 | 
184 |     def save_checkpoint(self, epoch, is_best):
185 |         """
186 |         save checkpoint to experimental dir
187 |         """
188 |         filename = os.path.join(self.checkpoint_dir, 'last_ckp.pth')
189 |         print('===> Saving last checkpoint to [%s] ...' % filename)
190 |         ckp = {
191 |             'epoch': epoch,
192 |             'state_dict': self.model.state_dict(),
193 |             'optimizer': self.optimizer.state_dict(),
194 |             'best_pred': self.best_pred,
195 |             'best_epoch': self.best_epoch,
196 |             'records': self.records
197 |         }
198 |         torch.save(ckp, filename)
199 |         if is_best:
200 |             print('===> Saving best checkpoint to [%s] ...' % filename.replace('last_ckp', 'best_ckp'))
201 |             torch.save(ckp, filename.replace('last_ckp', 'best_ckp'))
202 | 
203 |         if epoch % self.train_opt['save_ckp_step'] == 0:
204 |             print('===> Saving checkpoint [%d] to [%s] ...' % (epoch,
205 |                                                                filename.replace('last_ckp',
206 |                                                                                 'epoch_%d_ckp.pth' % epoch)))
207 | 
208 |             torch.save(ckp, filename.replace('last_ckp', 'epoch_%d_ckp.pth' % epoch))
209 | 
210 |     def _load(self):
211 |         """
212 |         load or initialize network
213 |         """
214 | 
215 |         if self.is_train and not self.opt['solver']['pretrain']:  # fresh training, not resuming
216 |             self._net_init()
217 |         elif self.is_train and self.opt['solver']['pretrain'] == 'resume':  # resuming a previous run
218 |             model_path = self.opt['solver']['pretrained_path']
219 |             if not os.path.exists(model_path):
220 |                 raise ValueError("[Error] The 'pretrained_path' in *.json does not exist")
221 |             checkpoint = torch.load(model_path)
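            # The resume checkpoint written by save_checkpoint() above contains
            # 'epoch', 'state_dict', 'optimizer', 'best_pred', 'best_epoch' and
            # 'records', all of which are restored below so a resumed run keeps
            # its optimizer state and logging history.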
222 |             self.model.load_state_dict(checkpoint['state_dict'])
223 |             self.cur_epoch = checkpoint['epoch'] + 1
224 |             self.optimizer.load_state_dict(checkpoint['optimizer'])
225 |             self.best_pred = checkpoint['best_pred']
226 |             self.best_epoch = checkpoint['best_epoch']
227 |             self.records = checkpoint['records']
228 |         else:
229 |             model_path = self.opt['solver']['pretrained_path']
230 |             if not os.path.exists(model_path): raise ValueError(
231 |                 "[Error] The 'pretrained_path' in *.json does not exist")
232 |             checkpoint = torch.load(model_path)
233 |             if 'state_dict' in checkpoint.keys(): checkpoint = checkpoint['state_dict']
234 |             # note: a plain (non-DataParallel) module has no .module attribute
235 |             load_func = self.model.load_state_dict
236 |             load_func(checkpoint)
237 | 
238 |     def get_K_V(self):
239 |         kv_dict = OrderedDict()
240 |         kv_dict['k'] = self.Fu["k"].data[0].float().cpu()
241 |         kv_dict['v'] = self.Fu["v"].data[0].float().cpu()
242 |         return kv_dict
243 | 
244 |     def get_pre(self):
245 |         kv_dict = OrderedDict()
246 |         kv_dict['outPre'] = self.Fu["outPre"].data[0].float().cpu()
247 |         return kv_dict
248 | 
249 |     def save_K_V(self, visuals, save_kv_path, index):
250 |         # visuals = self.get_K_V()
251 |         length = len(visuals['k'].numpy())
252 |         vis_k = visuals['k'].numpy().transpose(1, 2, 0)
253 |         vis_v = visuals['v'].numpy().transpose(1, 2, 0)
254 |         if not os.path.exists(save_kv_path + f"/K/{index}"): os.makedirs(save_kv_path + f"/K/{index}")
255 |         if not os.path.exists(save_kv_path + f"/V/{index}"): os.makedirs(save_kv_path + f"/V/{index}")
256 |         for i in range(length):
257 |             k = vis_k[:, :, i]
258 |             v = vis_v[:, :, i]
259 |             k_path = save_kv_path + f"/K/{index}/{i}k.png"
260 |             v_path = save_kv_path + f"/V/{index}/{i}v.png"
261 |             cv2.imwrite(k_path, k)
262 |             cv2.imwrite(v_path, v)
263 | 
264 |     def get_current_visual(self, need_np=False, need_HR=True):
265 |         """
266 |         return LR SR (HR) images
267 |         """
268 |         out_dict = OrderedDict()
269 |         out_dict['Vi'] = self.Vi.data[0].float().cpu()
270 |         out_dict['Ir'] = self.Ir.data[0].float().cpu()
271 |         out_dict["Fu"] = self.Fu["pred"].data[0].float().cpu()
272 |         # out_dict['fused_img_cr'] = self.fused_img_cr.data[0].float().cpu()
273 |         # out_dict['fused_img_cb'] = self.fused_img_cb.data[0].float().cpu()
274 | 
275 |         # out_dict["weight"] = self.Fu["weight"].data[0].float().cpu()
276 |         # out_dict['SR'] = self.SR.data[0].float().cpu()
277 |         # out_dict['HR'] = self.HR.data[0].float().cpu()
278 | 
279 |         return out_dict
280 | 
281 |     def save_current_visual(self, epoch, img_num):
282 |         """
283 |         save visual results for comparison
284 |         """
285 |         # if epoch % self.save_vis_step == 0:
286 |         visuals = self.get_current_visual(need_np=False)
287 |         Vi = visuals['Vi'].numpy().transpose(1, 2, 0)
288 |         Fu = visuals['Fu'].numpy().transpose(1, 2, 0)
289 |         Ir = visuals['Ir'].numpy().transpose(1, 2, 0)
290 |         # HR = visuals['HR'].numpy().transpose(1, 2, 0)
291 | 
292 | 
293 |         # Vi = (Vi + 1) * 127.5
294 |         # Ir = (Ir + 1) * 127.5
295 |         # Fu = (Fu + 1) * 127.5
296 | 
297 |         if not os.path.exists(self.visual_dir + "/" + str(epoch) + "/Vi/"): os.makedirs(self.visual_dir + "/" + str(epoch) + "/Vi/")
298 |         if not os.path.exists(self.visual_dir + "/" + str(epoch) + "/Fu/"): os.makedirs(self.visual_dir + "/" + str(epoch) + "/Fu/")
299 |         if not os.path.exists(self.visual_dir + "/" + str(epoch) + "/Ir/"): os.makedirs(self.visual_dir + "/" + str(epoch) + "/Ir/")
300 | 
301 |         Vi_path = self.visual_dir + "/" + str(epoch) + "/Vi/" + str(img_num) + ".png"
302 |         # spy.save_rgb(Vi_path, Vi, bands=[2, 1, 0])
303 |         #
spy.save_rgb(Vi_path, Vi, bands=[0]) 304 | cv2.imwrite(Vi_path, Vi) 305 | 306 | Fu_path = self.visual_dir + "/"+str(epoch)+"/Fu/"+str(img_num)+".png" 307 | # spy.save_rgb(Fu_path, Fu, bands=[0]) 308 | cv2.imwrite(Fu_path, Fu) 309 | 310 | Ir_path = self.visual_dir + "/"+str(epoch)+"/Ir/"+str(img_num)+".png" 311 | # spy.save_rgb(Ir_path, Ir, bands=[0]) 312 | cv2.imwrite(Ir_path, Ir) 313 | # HR_path = self.visual_dir + f"/{epoch}/{img_num}/HR.bmp" 314 | # spy.save_rgb(HR_path, HR, bands=[2, 1, 0]) 315 | 316 | def get_current_learning_rate(self): 317 | return self.optimizer.param_groups[0]['lr'] 318 | 319 | def update_learning_rate(self, epoch): 320 | self.scheduler.step(epoch) 321 | 322 | def get_current_log(self): 323 | log = OrderedDict() 324 | log['epoch'] = self.cur_epoch 325 | log['best_pred'] = self.best_pred 326 | log['best_epoch'] = self.best_epoch 327 | log['records'] = self.records 328 | return log 329 | 330 | def set_current_log(self, log): 331 | self.cur_epoch = log['epoch'] 332 | self.best_pred = log['best_pred'] 333 | self.best_epoch = log['best_epoch'] 334 | self.records = log['records'] 335 | 336 | def save_current_log(self): 337 | data_frame = pd.DataFrame( 338 | data={'train_loss': self.records['train_loss'][-1] 339 | , 'val_loss': self.records['val_loss'][-1] 340 | , 'psnr': self.records['psnr'][-1].item() 341 | , 'ssim': self.records['ssim'][-1].item() 342 | , 'lr': self.records['lr'][-1] 343 | }, 344 | index=range(self.cur_epoch, self.cur_epoch + 1) 345 | ) 346 | is_need_header = True if self.cur_epoch == 1 else False 347 | data_frame.to_csv(os.path.join(self.records_dir, 'train_records.csv'), mode="a", 348 | index_label='epoch', header=is_need_header) 349 | 350 | def print_network(self): 351 | """ 352 | print network summary including module and number of parameters 353 | """ 354 | s, n = self.get_network_description(self.model) 355 | if isinstance(self.model, nn.DataParallel): 356 | net_struc_str = '{} - {}'.format(self.model.__class__.__name__, 357 | self.model.module.__class__.__name__) 358 | else: 359 | net_struc_str = '{}'.format(self.model.__class__.__name__) 360 | 361 | print("==================================================") 362 | print("===> Network Summary\n") 363 | net_lines = [] 364 | line = s + '\n' 365 | print(line) 366 | net_lines.append(line) 367 | line = 'Network structure: [{}], with parameters: [{:,d}]'.format(net_struc_str, n) 368 | print(line) 369 | net_lines.append(line) 370 | 371 | if self.is_train: 372 | with open(os.path.join(self.exp_root, 'network_summary.txt'), 'w') as f: 373 | f.writelines(net_lines) 374 | 375 | print("==================================================") 376 | 377 | def _set_loss(self): 378 | loss_type = self.train_opt['loss_type'] 379 | if loss_type == 'l1': 380 | self.criterion_pix = nn.L1Loss() 381 | elif loss_type == "loss": 382 | from networks.loss import fusion_loss_med 383 | self.criterion_pix = fusion_loss_med() 384 | 385 | else: 386 | raise NotImplementedError('Loss type [%s] is not implemented!' 
% loss_type) 387 | 388 | if self.use_gpu: 389 | self.criterion_pix = self.criterion_pix.cuda() 390 | 391 | def _set_optimizer(self): 392 | weight_decay = self.train_opt['weight_decay'] if self.train_opt['weight_decay'] else 0 393 | optim_type = self.train_opt['type'].upper() 394 | if optim_type == "ADAM": 395 | self.optimizer = optim.Adam(self.model.parameters(), 396 | lr=self.train_opt['learning_rate'], weight_decay=weight_decay) 397 | elif optim_type == "ADAMW": 398 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.train_opt['learning_rate'], 399 | weight_decay=weight_decay) 400 | else: 401 | raise NotImplementedError('Optimizer type [%s] is not implemented!' % optim_type) 402 | 403 | def _set_scheduler(self): 404 | if self.train_opt['lr_scheme'].lower() == 'multisteplr': 405 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, 406 | self.train_opt['lr_steps'], 407 | self.train_opt['lr_gamma']) 408 | 409 | else: 410 | raise NotImplementedError('Only MultiStepLR scheme is supported!') 411 | print("optimizer: ", self.optimizer) 412 | print(f"lr_scheduler milestones: {self.scheduler.milestones} gamma: {self.scheduler.gamma:.3f}") 413 | 414 | def _net_init(self, init_type='normal'): 415 | print('==> Initializing the network using [%s]' % init_type) 416 | init_weights(self.model, init_type) 417 | -------------------------------------------------------------------------------- /solvers/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .FuSolver import FuSolver 4 | 5 | 6 | 7 | def create_solver(opt): 8 | if opt['mode'] == 'fu': 9 | solver = FuSolver(opt) 10 | else: 11 | raise NotImplementedError 12 | 13 | return solver 14 | 15 | 16 | # def create_solver(opt): 17 | # if opt['mode'] == 'fu': 18 | # solver = FuSolver(opt) 19 | # else: 20 | # raise NotImplementedError 21 | # 22 | # return solver 23 | -------------------------------------------------------------------------------- /solvers/__pycache__/FuSolver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/solvers/__pycache__/FuSolver.cpython-39.pyc -------------------------------------------------------------------------------- /solvers/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/solvers/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /solvers/__pycache__/base_solver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/solvers/__pycache__/base_solver.cpython-39.pyc -------------------------------------------------------------------------------- /solvers/base_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseSolver(object): 6 | def __init__(self, opt): 7 | self.opt = opt 8 | self.scale = opt['scale'] 9 | self.is_train = opt['is_train'] 10 | 11 | # GPU verify 12 | self.use_gpu = torch.cuda.is_available() 13 | # self.use_gpu = False 14 | self.Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor 15 | 16 | # for better 
training (stabilization and less GPU memory usage)
17 |         self.last_epoch_loss = 1e8
18 |         self.skip_threshold = opt['solver']['skip_threshold']
19 |         # save GPU memory during training
20 |         self.split_batch = opt['solver']['split_batch']
21 | 
22 |         # experimental dirs
23 |         self.exp_root = opt['path']['exp_root']
24 |         self.checkpoint_dir = opt['path']['epochs']
25 |         self.records_dir = opt['path']['records']
26 |         self.visual_dir = opt['path']['visual']
27 | 
28 |         # log and vis scheme
29 |         self.save_ckp_step = opt['solver']['save_ckp_step']
30 |         self.save_vis_step = opt['solver']['save_vis_step']
31 | 
32 |         self.best_epoch = 0
33 |         self.cur_epoch = 1
34 |         self.best_pred = 0.0
35 | 
36 |     def feed_data(self, batch):
37 |         pass
38 | 
39 |     def train_step(self):
40 |         pass
41 | 
42 |     def test(self):
43 |         pass
44 | 
45 |     def _forward_x8(self, x, forward_function):
46 |         pass
47 | 
48 |     def _overlap_crop_forward(self, upscale):
49 |         pass
50 | 
51 |     def get_current_log(self):
52 |         pass
53 | 
54 |     def get_current_visual(self):
55 |         pass
56 | 
57 |     def get_current_learning_rate(self):
58 |         pass
59 | 
60 |     def set_current_log(self, log):
61 |         pass
62 | 
63 |     def update_learning_rate(self, epoch):
64 |         pass
65 | 
66 |     def save_checkpoint(self, epoch, is_best):
67 |         pass
68 | 
69 |     def load(self):
70 |         pass
71 | 
72 |     def save_current_visual(self, epoch, iter):
73 |         pass
74 | 
75 |     def save_current_log(self):
76 |         pass
77 | 
78 |     def print_network(self):
79 |         pass
80 | 
81 |     def get_network_description(self, network):
82 |         '''Get the string and total parameters of the network'''
83 |         if isinstance(network, nn.DataParallel):
84 |             network = network.module
85 |         s = str(network)
86 |         n = sum(map(lambda x: x.numel(), network.parameters()))
87 | 
88 |         return s, n
--------------------------------------------------------------------------------
/testViIr.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
5 | 
6 | 
7 | import argparse, time
8 | 
9 | import options.options as option
10 | from utils import util
11 | from solvers import create_solver
12 | from data import create_dataloader
13 | from data import create_dataset
14 | from PIL import Image
15 | 
16 | 
17 | def main():
18 | 
19 |     parser = argparse.ArgumentParser(description='Test Fusion Models')
20 |     parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
21 |     # opt = option.parse(parser.parse_args().opt)
22 |     opt = option.parse('options/test/test_ATFuse.json')
23 |     opt = option.dict_to_nonedict(opt)
24 | 
25 |     # initial configure
26 |     scale = opt['scale']
27 |     degrad = opt['degradation']
28 |     network_opt = opt['networks']
29 |     model_name = network_opt['which_model'].upper()
30 |     if opt['self_ensemble']: model_name += 'plus'
31 | 
32 |     # create test dataloader
33 |     bm_names = []
34 |     test_loaders = []
35 |     for _, dataset_opt in sorted(opt['datasets'].items()):
36 |         test_set = create_dataset(dataset_opt)
37 |         test_loader = create_dataloader(test_set, dataset_opt)
38 |         test_loaders.append(test_loader)
39 |         print(f'===> Test Dataset: [{test_set.name()}] Number of images: [{len(test_set)}]')
40 |         bm_names.append(test_set.name())
41 | 
42 |     # create solver (and load model)
43 |     solver = create_solver(opt)  ### load train and test model
44 |     # Test phase
45 |     print('===> Start Test')
46 |     print("==================================================")
47 |     print(f"Method: {model_name} || Scale: {scale} || Degradation: {degrad}")
48 | 
test_loader in zip(bm_names, test_loaders): 50 | print(f"Test set : [{bm}]") 51 | 52 | Fu_list = [] 53 | # KV_list = [] 54 | path_list = [] 55 | 56 | total_psnr = [] 57 | total_ssim = [] 58 | total_time = [] 59 | Img_cr=[] 60 | Img_cb=[] 61 | need_HR = False 62 | 63 | for iter, batch in enumerate(test_loader): 64 | solver.feed_data(batch) 65 | 66 | # calculate forward time 67 | t0 = time.time() 68 | solver.test() 69 | t1 = time.time() 70 | total_time.append((t1 - t0)) 71 | 72 | visuals = solver.get_current_visual(need_HR=need_HR) 73 | Fu_list.append(visuals['Fu']) 74 | # Img_cr.append(visuals['fused_img_cr']) 75 | # Img_cb.append(visuals['fused_img_cb']) 76 | # KV = solver.get_K_V() 77 | # KV_list.append(KV) 78 | # calculate PSNR/SSIM metrics on Python 79 | if need_HR: 80 | cc, ssim = util.calc_metrics(visuals['Fu'], visuals['HR'], crop_border=scale) 81 | total_psnr.append(cc) 82 | total_ssim.append(ssim) 83 | path_list.append(os.path.basename(batch['HR_path'][0]).replace('HR', model_name)) 84 | print(f"[{iter + 1}/{len(test_loader)}] {os.path.basename(batch['vi_path'][0])} " 85 | f"|| CC/SSIM: {cc:.4f}/{ssim:.4f} || Timer: {t1 - t0:.4f} sec .") 86 | else: 87 | path_list.append(os.path.basename(batch['vi_path'][0])) 88 | print(f"[{iter + 1}/{len(test_loader)}] {os.path.basename(batch['vi_path'][0])} " 89 | f"|| Timer: {t1 - t0:.4f} sec.") 90 | 91 | if need_HR: 92 | print(f"---- Average PSNR(dB) /SSIM /Speed(s) for [{bm}] ----") 93 | print(f"CC: {sum(total_psnr) / len(total_psnr):.4f} PSNR: {sum(total_ssim) / len(total_ssim):.4f}" 94 | f" Speed: {sum(total_time) / len(total_time):.4f}.") 95 | else: 96 | print(f"---- Average Speed(s) for [{bm}] is {sum(total_time) / len(total_time)} sec ----") 97 | 98 | # save SR results for further evaluation on MATLAB 99 | if need_HR: 100 | save_img_path = os.path.join('./results/Fu/' + degrad, model_name, bm, "x%d" % scale) 101 | else: 102 | save_img_path = os.path.join('./results/Fu/' + bm, model_name, "x%d/" % scale) 103 | # save_kv_path = os.path.join('./results/KV/' + bm, model_name, "x%d" % scale) 104 | print(f"===> Saving Fu images of [{bm}]... 
Save Path: [{save_img_path}]\n") 105 | 106 | # if not os.path.exists(save_kv_path): os.makedirs(save_kv_path) 107 | if not os.path.exists(save_img_path): os.makedirs(save_img_path) 108 | index = 1 109 | for img, name in zip(Fu_list, path_list): 110 | # img = img.numpy().transpose(1, 2, 0) 111 | # cv2.imwrite(save_img_path + name, img) 112 | for i in range(img.shape[0]): 113 | img = img[i, :, :].numpy() 114 | # min_v = np.min(img_v) 115 | # max_v = np.max(img_v) 116 | # img_v = (img_v - min_v) / (max_v - min_v) 117 | 118 | # img = (img + 1) * 127.5 119 | # img_v = img_v * 255 120 | 121 | im = Image.fromarray(img) 122 | im = im.convert("L") 123 | im.save(save_img_path + name) 124 | 125 | # im.save(save_img_path + str(index)+'.png') 126 | # index += 1 127 | 128 | # for img, name,kv in zip(Fu_list, path_list,KV_list): 129 | # img = img.numpy().transpose(1, 2, 0) 130 | # cv2.imwrite(save_img_path+name, img) 131 | # solver.save_K_V(kv,save_kv_path,index) 132 | # index += 1 133 | # 134 | 135 | print("==================================================") 136 | print("===> Finished !") 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /train_ViIr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse, random 3 | from tqdm import tqdm 4 | 5 | import torch 6 | 7 | from visualization import get_local 8 | 9 | get_local.activate() 10 | 11 | import options.options as option 12 | from utils import util 13 | from solvers import create_solver 14 | from data import create_dataloader 15 | from data import create_dataset 16 | import os 17 | 18 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 19 | 20 | 21 | def parse_options(option_file_path): 22 | parser = argparse.ArgumentParser(description='Train Super Resolution Models') 23 | parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.') 24 | opt = option.parse(option_file_path) 25 | return opt 26 | 27 | 28 | def set_random_seed(opt): 29 | seed = opt['solver']['manual_seed'] 30 | if seed is None: seed = random.randint(1, 10000) # 这里尽量直接把seed写死 31 | print(f"===> Random Seed: [{seed}]") 32 | random.seed(seed) 33 | torch.manual_seed(seed) 34 | 35 | 36 | def create_Dataloader(opt): 37 | train_set, train_loader, val_set, val_loader = None, None, None, None 38 | for phase, dataset_opt in sorted(opt['datasets'].items()): 39 | if phase == 'train': 40 | train_set = create_dataset(dataset_opt) 41 | train_loader = create_dataloader(train_set, dataset_opt) 42 | print(f'===> Train Dataset: {train_set.name()} Number of images: [{len(train_set)}]') 43 | if train_loader is None: raise ValueError("[Error] The training data does not exist") 44 | 45 | elif phase == 'val': 46 | val_set = create_dataset(dataset_opt) 47 | val_loader = create_dataloader(val_set, dataset_opt) 48 | print(f'===> Val Dataset: {val_set.name()} Number of images: [{len(val_set)}]') 49 | 50 | else: 51 | raise NotImplementedError(f"[Error] Dataset phase [{phase}] in *.json is not recognized.") 52 | return train_set, train_loader, val_set, val_loader 53 | 54 | 55 | def train_step(): 56 | pass 57 | 58 | 59 | def main(): 60 | option_file_path = 'options/train/train_ATFuse.json' 61 | opt = parse_options(option_file_path) 62 | # random seed 63 | set_random_seed(opt) 64 | 65 | # create train and val dataloader 66 | train_set, train_loader, val_set, val_loader = create_Dataloader(opt) 67 | 68 | solver = create_solver(opt) 
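# NOTE: the FuSolver built here (see solvers/__init__.py) wraps the ATFuse network together with its pixel loss, the Adam/AdamW optimizer and the MultiStepLR scheduler, so the training loop below only drives it through feed_data / train_step / test / save_checkpoint.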
69 | 70 | scale = opt['scale'] 71 | model_name = opt['networks']['which_model'].upper() 72 | solver_log = solver.get_current_log() 73 | 74 | NUM_EPOCH = int(opt['solver']['num_epochs']) 75 | start_epoch = solver_log['epoch'] 76 | 77 | print('===> Start Train') 78 | print("==================================================") 79 | print(f"Method: {model_name} || Scale: {scale} || Epoch Range: ({start_epoch} ~ {NUM_EPOCH})") 80 | solver_log['best_pred'] = 10000000000 81 | for epoch in range(start_epoch, NUM_EPOCH + 1): 82 | print( 83 | f'\n===> Training Epoch: [{epoch}/{NUM_EPOCH}]... Learning Rate: {solver.get_current_learning_rate():.2e}') 84 | 85 | # Initialization 86 | solver_log['epoch'] = epoch 87 | 88 | # Train model 89 | get_local.clear() 90 | 91 | train_loss_list = [] 92 | with tqdm(total=len(train_loader), desc=f'Epoch: [{epoch}/{NUM_EPOCH}]', miniters=1) as t: 93 | for iter, batch in enumerate(train_loader): 94 | solver.feed_data(batch) 95 | iter_loss = solver.train_step() 96 | batch_size = batch['Vi'].size(0) 97 | train_loss_list.append(iter_loss * batch_size) 98 | t.set_postfix_str("Batch Loss: %.4f" % iter_loss) 99 | t.update() 100 | cache = get_local.cache 101 | # print(cache) 102 | 103 | solver_log['records']['train_loss'].append(sum(train_loss_list) / len(train_set)) 104 | solver_log['records']['lr'].append(solver.get_current_learning_rate()) 105 | 106 | print(f'\nEpoch: [{epoch}/{NUM_EPOCH}] Avg Train Loss: {sum(train_loss_list) / len(train_set):.6f}') 107 | 108 | print('===> Validating...', ) 109 | 110 | psnr_list = [] 111 | ssim_list = [] 112 | val_loss_list = [] 113 | 114 | for iter, batch in enumerate(val_loader): 115 | solver.feed_data(batch) 116 | iter_loss = solver.test() 117 | val_loss_list.append(iter_loss) 118 | 119 | # calculate evaluation metrics 120 | visuals = solver.get_current_visual() 121 | # 评估指标需要改,这是监督学习的 122 | psnr, ssim = util.calc_metrics(visuals['Fu'], visuals['Vi'], visuals['Ir'],crop_border=scale) 123 | # psnrIr, ssimIr = util.calc_metrics(visuals['Fu'], visuals['Ir'], crop_border=scale) 124 | # psnr_list.append((psnrVi+psnrIr)/2) 125 | psnr_list.append(psnr) 126 | ssim_list.append(ssim) 127 | 128 | if opt["save_image"] and epoch % 50 == 0: 129 | solver.save_current_visual(epoch, iter) 130 | 131 | solver_log['records']['val_loss'].append(sum(val_loss_list) / len(val_loss_list)) 132 | solver_log['records']['psnr'].append(sum(psnr_list) / len(psnr_list)) 133 | solver_log['records']['ssim'].append(sum(ssim_list) / len(ssim_list)) 134 | 135 | # record the best epoch 136 | # epoch_is_best = False 137 | # if solver_log['best_pred'] < (sum(psnr_list) / len(psnr_list)): 138 | # solver_log['best_pred'] = (sum(psnr_list) / len(psnr_list)) 139 | # epoch_is_best = True 140 | # solver_log['best_epoch'] = epoch 141 | #best,损失函数最小 142 | epoch_is_best = False 143 | if solver_log['best_pred'] > (sum(val_loss_list) / len(val_loss_list)): 144 | solver_log['best_pred'] = (sum(val_loss_list) / len(val_loss_list)) 145 | epoch_is_best = True 146 | solver_log['best_epoch'] = epoch 147 | 148 | # print( 149 | # f"[{val_set.name()}] CC: {sum(psnr_list) / len(psnr_list):.4f} RMSE: {sum(ssim_list) / len(ssim_list):.4f}" 150 | # f"Loss: {sum(val_loss_list) / len(val_loss_list):.6f} Best CC: {solver_log['best_pred']:.4f} in Epoch: " 151 | # f"[{solver_log['best_epoch']:d}]" 152 | # ) 153 | print( 154 | f"[{val_set.name()}] CC: {sum(psnr_list) / len(psnr_list):.4f} Qabf: {sum(ssim_list) / len(ssim_list):.4f}" 155 | f"Loss: {sum(val_loss_list) / len(val_loss_list):.6f} Best : 
{solver_log['best_pred']:.6f} in Epoch: " 156 | f"[{solver_log['best_epoch']:d}]" 157 | ) 158 | 159 | solver.set_current_log(solver_log) 160 | solver.save_checkpoint(epoch, epoch_is_best) 161 | solver.save_current_log() 162 | # update lr 163 | solver.update_learning_rate(epoch) 164 | 165 | print('===> Finished !') 166 | 167 | 168 | if __name__ == '__main__': 169 | torch.backends.cudnn.enabled = False 170 | main() 171 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/utils/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /utils/hist_adjust.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.exposure import histogram 3 | 4 | 5 | def hist_line_stretch(img, nbins, bound=[0.01, 0.99]): 6 | def _line_strectch(img): 7 | # img = img.astype(np.uint16) 8 | ori = img 9 | img = img.reshape(-1) 10 | hist1, bins1 = histogram(img, nbins=nbins, normalize=True) 11 | cumhist = np.cumsum(hist1) 12 | lowThreshold = np.where(cumhist >= bound[0])[0][0] 13 | highThreshold = np.where(cumhist >= bound[1])[0][0] 14 | lowThreshold = bins1[lowThreshold] 15 | highThreshold = bins1[highThreshold] 16 | ori[np.where(ori < lowThreshold)] = lowThreshold 17 | ori[np.where(ori > highThreshold)] = highThreshold 18 | ori = (ori - lowThreshold) / (highThreshold - lowThreshold + np.finfo(np.float).eps) 19 | return ori, lowThreshold, highThreshold 20 | 21 | if img.ndim > 2: 22 | lowThreshold = np.zeros(img.shape[2]) 23 | highThreshold = np.zeros(img.shape[2]) 24 | for i in range(img.shape[2]): 25 | img[:, :, i], lowThreshold[i], highThreshold[i] = _line_strectch(img[:, :, i].squeeze()) 26 | else: 27 | img, lowThreshold, highThreshold = _line_strectch(img) 28 | return img, lowThreshold, highThreshold 29 | 30 | 31 | def hist_line_stretchv2(img, nbins, bound=[0.01, 0.99]): 32 | def _line_strectch(img): 33 | max_img = img.max() 34 | min_img = img.min() 35 | gap_img = max_img - min_img 36 | gap_img = max(np.finfo(np.float).eps, gap_img) 37 | return (img - min_img) / gap_img, min_img, max_img 38 | 39 | if img.ndim > 2: 40 | lowThreshold = np.zeros(img.shape[2]) 41 | highThreshold = np.zeros(img.shape[2]) 42 | for i in range(img.shape[2]): 43 | img[:, :, i], lowThreshold[i], highThreshold[i] = _line_strectch(img[:, :, i].squeeze()) 44 | else: 45 | img, lowThreshold, highThreshold = _line_strectch(img) 46 | return img, lowThreshold, highThreshold 47 | -------------------------------------------------------------------------------- /utils/hyper_plot.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # Ablation results: for each hyper-parameter, 'label' holds the tested settings and 'metrics' the corresponding scores. 4 | named_data = { 5 | 'P': {'label': [1, 4, 8, 16, 18], 6 | 'metrics': 7 | {'ERGAS': [1.342, 1.233, 1.231, 1.205, 1.243], 8 | 'SAM': [2.148, 2.053, 2.095, 2.049, 2.11], 9 | 'Q2n': [0.963, 0.962, 0.962, 0.965, 0.962], 10 | '#Params': [3.252, 3.295, 3.917, 13.779, 20.1]} 11 | }, 12 | "H": {'label': [1, 2, 4, 6, 8, 120], 13 | "metrics": { 14 | "ERGAS": [1.264, 1.341, 1.207, 1.306, 1.233, 7.614], 15 | "SAM": [2.197, 2.273, 2.102, 2.092, 2.053, 4.022], 16 | "Q2n": [0.962, 0.958, 0.963, 0.962, 0.962, 0.757], 17 | '#Params': [3.257, 3.263, 3.274, 3.285, 3.295, 3.905]} 18 | }, 19 | "K": {'label': [3, 5, 7, 9, 11], 20 | "metrics": { 21 | "ERGAS": [1.253, 1.233, 1.352, 1.233, 1.254], 22 | "SAM": [2.069, 2.105, 2.187, 2.053, 2.142], 23 | "Q2n": [0.961, 0.961, 0.962, 0.962, 0.957], 24 | "#Params": [3.209, 3.228, 3.257, 3.295, 3.343], 25 | } 26 | }, 27 | 28 | "D": {'label': [30, 60, 90, 120, 180], 29 | 'metrics': { 30 | "ERGAS": [3.379, 1.566, 1.303, 1.233, 1.325], 31 | "SAM": [3.527, 2.385, 2.12, 2.053, 2.138], 32 | 'Q2n': [0.918, 0.957, 0.963, 0.962, 0.962], 33 | '#Params': [0.292, 0.917, 1.918, 3.295, 7.179] 34 | } 35 | } 36 | } 37 | 38 | for metric in ('ERGAS', 'SAM', 'Q2n'): 39 | plt.figure() 40 | for hyper_para, value in named_data.items(): 41 | # plot the metric against the tested hyper-parameter settings 42 | plt.plot(value['label'], value['metrics'][metric], marker='o', label=hyper_para) 43 | plt.xlabel('hyper-parameter setting') 44 | plt.ylabel(metric) 45 | plt.legend() 46 | plt.show() -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.scheduler.cosine_lr import CosineLRScheduler 3 | from timm.scheduler.step_lr import StepLRScheduler 4 | from timm.scheduler.scheduler import Scheduler 5 | 6 | 7 | def build_scheduler(config, optimizer, n_iter_per_epoch): 8 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 9 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 10 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 11 | 12 | lr_scheduler = None 13 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 14 | lr_scheduler = CosineLRScheduler( 15 | optimizer, 16 | t_initial=num_steps, 17 | t_mul=1., 18 | lr_min=config.TRAIN.MIN_LR, 19 | warmup_lr_init=config.TRAIN.WARMUP_LR, 20 | warmup_t=warmup_steps, 21 | cycle_limit=1, 22 | t_in_epochs=False, 23 | ) 24 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 25 | lr_scheduler = LinearLRScheduler( 26 | optimizer, 27 | t_initial=num_steps, 28 | lr_min_rate=0.01, 29 | warmup_lr_init=config.TRAIN.WARMUP_LR, 30 | warmup_t=warmup_steps, 31 | t_in_epochs=False, 32 | ) 33 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 34 | lr_scheduler = StepLRScheduler( 35 | optimizer, 36 | decay_t=decay_steps, 37 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 38 | warmup_lr_init=config.TRAIN.WARMUP_LR, 39 | warmup_t=warmup_steps, 40 | t_in_epochs=False, 41 | ) 42 | 43 | return lr_scheduler 44 | 45 | 46 | class LinearLRScheduler(Scheduler): 47 | def __init__(self, 48 | optimizer: torch.optim.Optimizer, 49 | t_initial: int, 50 | lr_min_rate: float, 51 | warmup_t=0, 52 | warmup_lr_init=0., 53 | t_in_epochs=True, 54 | noise_range_t=None, 55 | noise_pct=0.67, 56 | noise_std=1.0, 57 | noise_seed=42, 58 | initialize=True, 59 | ) -> None: 60 | super().__init__( 61 | optimizer, param_group_field="lr", 62 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std,
noise_seed=noise_seed, 63 | initialize=initialize) 64 | 65 | self.t_initial = t_initial 66 | self.lr_min_rate = lr_min_rate 67 | self.warmup_t = warmup_t 68 | self.warmup_lr_init = warmup_lr_init 69 | self.t_in_epochs = t_in_epochs 70 | if self.warmup_t: 71 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 72 | super().update_groups(self.warmup_lr_init) 73 | else: 74 | self.warmup_steps = [1 for _ in self.base_values] 75 | 76 | def _get_lr(self, t): 77 | if t < self.warmup_t: 78 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 79 | else: 80 | t = t - self.warmup_t 81 | total_t = self.t_initial - self.warmup_t 82 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 83 | return lrs 84 | 85 | def get_epoch_values(self, epoch: int): 86 | if self.t_in_epochs: 87 | return self._get_lr(epoch) 88 | else: 89 | return None 90 | 91 | def get_update_values(self, num_updates: int): 92 | if not self.t_in_epochs: 93 | return self._get_lr(num_updates) 94 | else: 95 | return None -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: LihuiChen 3 | E-mail: lihuichen@126.com 4 | Note: The reduced-resolution metrics match the MATLAB code released by [Vivone20]. 5 | The full-resolution metrics differ slightly from the code released by [Vivone20]. 6 | 7 | Reference: Pansharpening Tool ver 1.3 and Pansharpening Toolbox for Distribution 8 | 9 | Pansharpening metrics: the same implementation of CC, SAM, ERGAS, Q2n as the one in the MATLAB code published by: 10 | [Vivone15] G. Vivone, L. Alparone, J. Chanussot, M. Dalla Mura, A. Garzelli, G. Licciardi, R. Restaino, and L. Wald, 11 | "A Critical Comparison Among Pansharpening Algorithms", IEEE Transactions on Geoscience and Remote Sensing, vol. 53, no. 5, pp. 2565-2586, May 2015. 12 | [Vivone20] G. Vivone, M. Dalla Mura, A. Garzelli, R. Restaino, G. Scarpa, M.O. Ulfarsson, L. Alparone, and J. Chanussot, 13 | "A New Benchmark Based on Recent Advances in Multispectral Pansharpening: Revisiting pansharpening with classical and 14 | emerging pansharpening methods", IEEE Geoscience and Remote Sensing Magazine, doi: 10.1109/MGRS.2020.3019315. 15 | """ 16 | from scipy.ndimage import sobel 17 | import numpy as np 18 | from scipy import signal, ndimage, misc 19 | import cv2 20 | from numpy.linalg import norm 21 | from PIL import Image 22 | from skimage.metrics import peak_signal_noise_ratio 23 | 24 | ########################################################## 25 | # Full Reference metrics for Reduced Resolution Assessment 26 | ########################################################## 27 | 28 | def SAM(ms,ps,degs = True): 29 | result = np.double(ps) 30 | target = np.double(ms) 31 | if result.shape != target.shape: 32 | raise ValueError('Result and target arrays must have the same shape!') 33 | 34 | bands = target.shape[2] 35 | rnorm = np.sqrt((result ** 2).sum(axis=2)) 36 | tnorm = np.sqrt((target ** 2).sum(axis=2)) 37 | dotprod = (result * target).sum(axis=2) 38 | cosines = (dotprod / (rnorm * tnorm)) 39 | sam2d = np.arccos(cosines) 40 | sam2d[np.invert(np.isfinite(sam2d))] = 0. # NaNs (cosines numerically outside [-1, 1]) are treated as arccos(1.) = 0 41 | if degs: 42 | sam2d = np.rad2deg(sam2d) 43 | return sam2d[np.isfinite(sam2d)].mean() 44 | 45 | 46 | def CC(img1, img2): 47 | """SCC for 2D (H, W) or 3D (H, W, C) image; uint or float [0, 1]""" 48 | if not img1.shape == img2.shape: 49 | raise ValueError('Input images must have the same dimensions.') 50 | img1_ = img1.astype(np.float64) 51 | img2_ = img2.astype(np.float64) 52 | if img1_.ndim == 2: 53 | return np.corrcoef(img1_.reshape(1, -1), img2_.reshape(1, -1))[0, 1] 54 | elif img1_.ndim == 3: 55 | ccs = [np.corrcoef(img1_[..., i].reshape(1, -1), img2_[..., i].reshape(1, -1))[0, 1] 56 | for i in range(img1_.shape[2])] 57 | return np.mean(ccs) 58 | else: 59 | raise ValueError('Wrong input image dimensions.') 60 | 61 | def sCC(ms, ps): 62 | ps_sobel = sobel(ps, mode='constant') 63 | ms_sobel = sobel(ms, mode='constant') 64 | return (np.sum(ps_sobel*ms_sobel)/np.sqrt(np.sum(ps_sobel*ps_sobel))/np.sqrt(np.sum(ms_sobel*ms_sobel))) 65 | 66 | 67 | def _qindex(img1, img2, block_size=8): 68 | """Q-index for 2D (one-band) image, shape (H, W); uint or float [0, 1]""" 69 | assert block_size > 1, 'block_size should be greater than 1!' 70 | img1_ = img1.astype(np.float64) 71 | img2_ = img2.astype(np.float64) 72 | window = np.ones((block_size, block_size)) / (block_size**2) 73 | # window_size = block_size**2 74 | # filter, valid 75 | pad_topleft = int(np.floor(block_size/2)) 76 | pad_bottomright = block_size - 1 - pad_topleft 77 | mu1 = cv2.filter2D(img1_, -1, window)[pad_topleft:-pad_bottomright, pad_topleft:-pad_bottomright] 78 | mu2 = cv2.filter2D(img2_, -1, window)[pad_topleft:-pad_bottomright, pad_topleft:-pad_bottomright] 79 | mu1_sq = mu1**2 80 | mu2_sq = mu2**2 81 | mu1_mu2 = mu1 * mu2 82 | 83 | sigma1_sq = cv2.filter2D(img1_**2, -1, window)[pad_topleft:-pad_bottomright, pad_topleft:-pad_bottomright] - mu1_sq 84 | sigma2_sq = cv2.filter2D(img2_**2, -1, window)[pad_topleft:-pad_bottomright, pad_topleft:-pad_bottomright] - mu2_sq 85 | sigma12 = cv2.filter2D(img1_ * img2_, -1, window)[pad_topleft:-pad_bottomright, pad_topleft:-pad_bottomright] - mu1_mu2 86 | 87 | # all = 1, include the case of sigma == mu == 0 88 | qindex_map = np.ones(sigma12.shape) 89 | # sigma == 0 and mu != 0 90 | idx = ((sigma1_sq + sigma2_sq) == 0) * ((mu1_sq + mu2_sq) != 0) 91 | qindex_map[idx] = 2 * mu1_mu2[idx] / (mu1_sq + mu2_sq)[idx] 92 | # sigma != 0 and mu == 0 93 | idx = ((sigma1_sq + sigma2_sq) != 0) * ((mu1_sq + mu2_sq) == 0) 94 | qindex_map[idx] = 2 * sigma12[idx] / (sigma1_sq + sigma2_sq)[idx] 95 | # sigma != 0 and mu != 0 96 | idx = ((sigma1_sq + sigma2_sq)!=0) * ((mu1_sq + mu2_sq) != 0) 97 | qindex_map[idx] =((2 * mu1_mu2[idx]) * (2 * sigma12[idx])) / ( 98 | (mu1_sq + mu2_sq)[idx] * (sigma1_sq + sigma2_sq)[idx]) 99 | return np.mean(qindex_map) 100 | 101 | 102 | def Q_AVE(img1, img2, block_size=8): 103 | """Q-index for 2D (H, W) or 3D (H, W, C) image; uint or float [0, 1]""" 104 | if not img1.shape == img2.shape: 105 | raise ValueError('Input images must have the same dimensions.') 106 | if img1.ndim == 2: 107 | return _qindex(img1, img2, block_size) 108 | elif img1.ndim == 3: 109 | qindexs = [_qindex(img1[..., i], img2[..., i], block_size) for i in range(img1.shape[2])] 110 | return np.array(qindexs).mean() 111 | else: 112 | raise ValueError('Wrong input image dimensions.') 113 | 114 | def ERGAS(img_fake, img_real, scale=4): 115 | """ERGAS for 2D (H, W) or 3D (H, W, C) image; uint or float [0, 1].
116 | scale = spatial resolution of PAN / spatial resolution of MUL, default 4.""" 117 | if not img_fake.shape == img_real.shape: 118 | raise ValueError('Input images must have the same dimensions.') 119 | img_fake_ = img_fake.astype(np.float64) 120 | img_real_ = img_real.astype(np.float64) 121 | if img_fake_.ndim == 2: 122 | mean_real = img_real_.mean() 123 | mse = np.mean((img_fake_ - img_real_)**2) 124 | return 100 / scale * np.sqrt(mse / (mean_real**2 + np.finfo(np.float64).eps)) 125 | elif img_fake_.ndim == 3: 126 | means_real = img_real_.reshape(-1, img_real_.shape[2]).mean(axis=0) 127 | means_real = means_real**2 128 | means_real[np.where(means_real==0)] = np.finfo(np.float64).eps 129 | mses = ((img_fake_ - img_real_)**2).reshape(-1, img_fake_.shape[2]).mean(axis=0) 130 | return 100 / scale * np.sqrt((mses/means_real).mean()) 131 | else: 132 | raise ValueError('Wrong input image dimensions.') 133 | 134 | def Q2n(I_GT, I_F, Q_blocks_size=32, Q_shift=32): 135 | N1,N2,N3 = I_GT.shape 136 | ori_N3 = N3 137 | size2 = Q_blocks_size 138 | stepx = int(np.ceil(float(N1)/Q_shift)) 139 | stepy = int(np.ceil(float(N2)/Q_shift)) 140 | # stepy = N2//Q_shift 141 | if stepy<=0: stepx, stepy = 1, 1 142 | est1 = (stepx-1)*Q_shift+Q_blocks_size-N1 143 | est2 = (stepy-1)*Q_shift+Q_blocks_size-N2 144 | 145 | if sum([est1!=0, est2!=0])>0: 146 | refref = np.zeros((N1+est1, N2+est2, N3)) 147 | fusfus = np.zeros((N1+est1, N2+est2, N3)) 148 | refref[:N1, :N2,:] = I_GT 149 | refref[:N1, N2:,:] = I_GT[:,N2-1:N2-est2-1:-1,:] 150 | refref[N1:,:,:] = refref[N1-1:N1-est1-1:-1,:,:] 151 | 152 | fusfus[:N1, :N2,:] = I_F 153 | fusfus[:N1,N2:,:] = I_F[:,N2-1:N2-est2-1:-1,:] 154 | fusfus[N1:,:,:] = fusfus[N1-1:N1-est1-1:-1,:,:] 155 | I_GT, I_F = refref, fusfus 156 | I_GT, I_F = I_GT.astype(np.uint16), I_F.astype(np.uint16) 157 | N1,N2,N3 = I_GT.shape 158 | if (np.ceil(np.log2(np.array(N3)))-np.log2(np.array(N3)))!=0: 159 | Ndif = np.power(2,np.ceil(np.log2(np.array(N3)))) - N3 160 | dif = np.zeros((N1,N2,int(Ndif))) 161 | I_GT = np.concatenate((I_GT, dif), axis=2) 162 | I_F = np.concatenate((I_F, dif), axis=2) 163 | N3 = I_GT.shape[2] 164 | 165 | valori = np.zeros((stepx, stepy, N3)) 166 | for j in range(stepx): 167 | for i in range(stepy): 168 | tmp_gt = I_GT[j*Q_shift:j*Q_shift+Q_blocks_size,i*Q_shift:i*Q_shift+size2,:] 169 | tmp_f = I_F[j*Q_shift:j*Q_shift+Q_blocks_size,i*Q_shift:i*Q_shift+size2,:] 170 | o = onions_quality(tmp_gt, tmp_f, Q_blocks_size) 171 | valori[j,i,:] = o 172 | valori = valori[:,:,:ori_N3] 173 | Q2n_index_map = np.sqrt((valori*valori).sum(axis=2)) 174 | Q2n_index = Q2n_index_map.mean() 175 | return Q2n_index 176 | 177 | def onions_quality(dat1, dat2, size1): 178 | dat1, dat2 = dat1.astype(np.double), dat2.astype(np.double) 179 | dat2[:,:,1:] = -dat2[:,:,1:] 180 | N3 = dat1.shape[2] 181 | size2 = size1 182 | # Block normalization 183 | for i in range(N3): 184 | tmp = dat1[:,:,i] 185 | s, t = tmp.mean(), tmp.std() 186 | if t==0: t=np.finfo(np.float64).eps 187 | dat1[:,:,i] = (tmp-s)/t + 1 188 | if s==0: 189 | if i==0: 190 | dat2[:,:,i] = dat2[:,:,i]-s+1 191 | else: 192 | dat2[:,:,i]=-(-dat2[:,:,i]-s+1) 193 | else: 194 | if i==0: 195 | dat2[:,:,i] = (dat2[:,:,i]-s)/t + 1 196 | else: 197 | dat2[:,:,i]=-((-dat2[:,:,i]-s)/t+1) 198 | 199 | m1 = dat1.reshape(-1, N3).mean(axis=0, keepdims=True) 200 | mod_q1m =((m1*m1).sum()) 201 | m2 = dat2.reshape(-1, N3).mean(axis=0, keepdims=True) 202 | mod_q2m = ((m2*m2).sum()) 203 | 204 | mod_q1= np.sqrt((dat1*dat1).sum(axis=2)) 205 | mod_q2= 
np.sqrt((dat2*dat2).sum(axis=2)) 206 | 207 | 208 | 209 | mod_q1m = np.sqrt(mod_q1m) 210 | mod_q2m = np.sqrt(mod_q2m) 211 | termine2 = (mod_q1m*mod_q2m) 212 | termine4 = ((mod_q1m**2)+(mod_q2m**2)) 213 | 214 | int1=(size1*size2)/((size1*size2)-1)*(mod_q1*mod_q1).mean() 215 | int2=(size1*size2)/((size1*size2)-1)*(mod_q2*mod_q2).mean() 216 | 217 | termine3=int1+int2-(size1*size2)/((size1*size2)-1)*((mod_q1m**2)+(mod_q2m**2)) 218 | 219 | mean_bias = 2*termine2/termine4 220 | if termine3==0: 221 | # q = np.zeros(1, 1, N3) 222 | # q[:,:,N3] = mean_bias 223 | q = mean_bias 224 | else: 225 | cbm = 2/termine3 226 | qu = onion_mult2D(dat1, dat2) 227 | qm = onion_mult(m1, m2) 228 | qv = (size1*size2)/((size1*size2)-1)*(qu.reshape(-1, N3).mean(axis=0, keepdims=True)) 229 | q = qv-(size1*size2)/((size1*size2)-1)*qm 230 | q = q*mean_bias*cbm 231 | return q 232 | 233 | def onion_mult2D(onion1, onion2): 234 | while onion1.ndim<3: 235 | onion1 = np.expand_dims(onion1, axis=0) 236 | onion2 = np.expand_dims(onion2, axis=0) 237 | N3 = onion1.shape[2] 238 | if N3>1: 239 | L = N3//2 240 | a=onion1[:,:,:L] 241 | b=onion1[:,:,L:] 242 | b[:,:,1:] = -b[:,:,1:] 243 | 244 | c=onion2[:,:,:L] 245 | d=onion2[:,:,L:] 246 | d[:,:,1:] = -d[:,:,1:] 247 | if N3==2: 248 | ris = np.concatenate((a*c-d*b, a*d+c*b), axis=2) 249 | else: 250 | ris1=onion_mult2D(a,c) 251 | ris2=onion_mult2D(d,np.concatenate((b[:,:,0:1],-b[:,:,1:]), axis=2)) 252 | ris3=onion_mult2D(np.concatenate((a[:,:,0:1],-a[:,:,1:]), axis=2),d) 253 | ris4=onion_mult2D(c,b) 254 | 255 | aux1=ris1-ris2 256 | aux2=ris3+ris4 257 | ris = np.concatenate((aux1, aux2), axis=2) 258 | 259 | else: 260 | ris = onion1*onion2 261 | return ris 262 | 263 | def onion_mult(onion1,onion2): 264 | 265 | N=(onion1.shape[1]) 266 | 267 | if N>1: 268 | L=N//2 269 | a=onion1[:,:L] 270 | b=onion1[:, L:] 271 | b[:,1:] = -b[:, 1:] 272 | c=onion2[:,:L] 273 | d=onion2[:, L:] 274 | d[:,1:] = -d[:,1:] 275 | if N==2: 276 | ris=np.concatenate((a*c-d*b,a*d+c*b), axis=1) 277 | else: 278 | ris1=onion_mult(a,c) 279 | ris2=onion_mult(d,np.concatenate((b[:,0:1],-b[:,1:]), axis=1)) 280 | ris3=onion_mult(np.concatenate((a[:,0:1],-a[:,1:]), axis=1),d) 281 | ris4=onion_mult(c,b) 282 | aux1=ris1-ris2 283 | aux2=ris3+ris4 284 | ris=np.concatenate([aux1,aux2], axis=1) 285 | else: 286 | ris = onion1*onion2 287 | return ris 288 | ########################################################## 289 | # 23-taps interpolation 290 | ########################################################## 291 | ''' 292 | interpolation with 23-taps 293 | ''' 294 | 295 | from scipy import ndimage 296 | def upsample_mat_interp23(image, ratio=4): 297 | '''2 pixel shift compare with original matlab version''' 298 | shift=2 299 | h,w,c = image.shape 300 | basecoeff = np.array([[-4.63495665e-03, -3.63442646e-03, 3.84904063e-18, 301 | 5.76678319e-03, 1.08358664e-02, 1.01980790e-02, 302 | -9.31747402e-18, -1.75033181e-02, -3.17660068e-02, 303 | -2.84531643e-02, 1.85181518e-17, 4.42450253e-02, 304 | 7.71733386e-02, 6.70554910e-02, -2.85299239e-17, 305 | -1.01548683e-01, -1.78708388e-01, -1.60004642e-01, 306 | 3.61741232e-17, 2.87940558e-01, 6.25431459e-01, 307 | 8.97067600e-01, 1.00107877e+00, 8.97067600e-01, 308 | 6.25431459e-01, 2.87940558e-01, 3.61741232e-17, 309 | -1.60004642e-01, -1.78708388e-01, -1.01548683e-01, 310 | -2.85299239e-17, 6.70554910e-02, 7.71733386e-02, 311 | 4.42450253e-02, 1.85181518e-17, -2.84531643e-02, 312 | -3.17660068e-02, -1.75033181e-02, -9.31747402e-18, 313 | 1.01980790e-02, 1.08358664e-02, 5.76678319e-03, 314 | 
3.84904063e-18, -3.63442646e-03, -4.63495665e-03]]) 315 | coeff = np.dot(basecoeff.T, basecoeff) 316 | I1LRU = np.zeros((ratio*h, ratio*w, c)) 317 | I1LRU[shift::ratio, shift::ratio, :]=image 318 | for i in range(c): 319 | temp = I1LRU[:, :, i] 320 | temp = ndimage.convolve(temp, coeff, mode='wrap') 321 | I1LRU[:, :, i]=temp 322 | return I1LRU 323 | 324 | 325 | ########################################################## 326 | # Use a Gaussian filter matched to the sensor MTF to degrade HRMS images 327 | ########################################################## 328 | def MTF_Filter(hrms, scale, sensor, GNyq=None): 329 | # while hrms.ndim<4: 330 | # hrms = np.expand_dims(hrms, axis=0) 331 | h,w,c = hrms.shape 332 | if GNyq is not None: 333 | pass # a user-supplied GNyq takes precedence over the sensor presets below 334 | elif sensor == 'random': 335 | GNyq = np.random.normal(loc=0.3, scale=0.03, size=c) 336 | elif sensor=='QB': 337 | GNyq = [0.34, 0.32, 0.30, 0.22] # Band Order: B,G,R,NIR 338 | elif sensor=='IK': 339 | GNyq = [0.26,0.28,0.29,0.28] # Band Order: B,G,R,NIR 340 | elif sensor=='GE' or sensor == 'WV4': 341 | GNyq = [0.23,0.23,0.23,0.23] # Band Order: B,G,R,NIR 342 | elif sensor=='WV3': 343 | GNyq = [0.325, 0.355, 0.360, 0.350, 0.365, 0.360, 0.335, 0.315] 344 | elif sensor=='WV2': 345 | GNyq = ([0.35]*7+[0.27]) 346 | else: 347 | GNyq = [0.3]*c 348 | mtf = [GNyq2win(GNyq=tmp) for tmp in GNyq] # GNyq2win (Gaussian MTF window) is expected from the Pansharpening Toolbox cited in the header; it is not defined in this file 349 | ms_lr = [ndimage.convolve(hrms[:,:,idx], tmp_mtf, mode='wrap') for idx, tmp_mtf in enumerate(mtf)] 350 | ms_lr = np.stack(ms_lr, axis=2) 351 | return ms_lr 352 | 353 | 354 | ########################################################## 355 | # No-reference metrics for Full Resolution Assessment. 356 | ########################################################## 357 | 358 | def HQNR(ps_ms, ms, pan, S=32, sensor=None, ratio=4): 359 | msexp = upsample_mat_interp23(ms, ratio) 360 | Dl = D_lambda_K(ps_ms, msexp, ratio, sensor, S) 361 | Ds = D_s(ps_ms, ms, pan, ratio, S, 1) 362 | HQNR_value = (1-Dl)*(1-Ds) 363 | return Dl, Ds, HQNR_value 364 | 365 | 366 | def D_lambda_K(fused, msexp, ratio, sensor, S): 367 | if fused.shape != msexp.shape: 368 | raise ValueError('The two images must have the same dimensions') 369 | N, M, _ = fused.shape 370 | if N % S != 0 or M % S != 0: 371 | raise ValueError('numbers of rows and columns must be multiples of the block size.') 372 | 373 | fused_degraded = MTF_Filter(fused, ratio, sensor, GNyq=None) # argument order follows MTF_Filter(hrms, scale, sensor) 374 | q2n = Q2n(msexp, fused_degraded, S, S) 375 | return 1-q2n 376 | 377 | 378 | def D_s(img_fake, img_lm, pan,ratio, S, q=1): 379 | """Spatial distortion 380 | img_fake, generated HRMS 381 | img_lm, LRMS 382 | pan, HRPan""" 383 | # fake and lm 384 | assert img_fake.ndim == img_lm.ndim == 3, 'MS images must be 3D!' 385 | H_f, W_f, C_f = img_fake.shape 386 | H_r, W_r, C_r = img_lm.shape 387 | assert H_f // H_r == W_f // W_r == ratio, 'Spatial resolution should be compatible with scale' 388 | assert C_f == C_r, 'Fake and lm should have the same number of bands!'
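# Spatial distortion: each fused band is scored with the local Q-index against the full-resolution PAN, each LRMS band against a bicubic-downsampled PAN, and D_s is the mean absolute gap between the two sets of scores.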
389 | # fake and pan 390 | # if pan.ndim == 2: pan = np.expand_dims(pan, axis=2) 391 | if pan.ndim==3: pan = pan.squeeze(2) 392 | H_p, W_p = pan.shape 393 | assert H_f == H_p and W_f == W_p, "Pan's and fake's spatial resolution should be the same" 394 | # get LRPan, 2D 395 | pan_lr = Image.fromarray(pan).resize((int(1/ratio*H_p),int(1/ratio*W_p)), resample=Image.BICUBIC) 396 | pan_lr = np.array(pan_lr) 397 | 398 | Q_hr = [] 399 | Q_lr = [] 400 | for i in range(C_f): 401 | # for HR fake 402 | band1 = img_fake[..., i] 403 | Q_hr.append(_qindex(band1, pan, block_size=S)) 404 | band1 = img_lm[..., i] 405 | Q_lr.append(_qindex(band1, pan_lr, block_size=S)) 406 | Q_hr = np.array(Q_hr) 407 | Q_lr = np.array(Q_lr) 408 | D_s_index = (np.abs(Q_hr - Q_lr) ** q).mean() 409 | return D_s_index ** (1/q) 410 | 411 | 412 | 413 | def pan_calc_metrics_rr(PS, GT, scale, img_range): 414 | GT = np.array(GT).astype(np.float) 415 | PS = np.array(PS).astype(np.float) 416 | RMSE = (GT - PS)/img_range 417 | RMSE = np.sqrt((RMSE*RMSE).mean()) 418 | # cc = CC(GT,PS) 419 | sam = SAM(GT, PS) 420 | ergas = ERGAS(PS, GT, scale=scale) 421 | # Qave = Q_AVE(GT, PS) 422 | # scc = sCC(GT, PS) 423 | q2n = Q2n(GT,PS) 424 | psnr = peak_signal_noise_ratio(GT, PS, data_range=img_range) 425 | # return {'SAM':sam, 'ERGAS':ergas, 'Q2n':q2n, 'CC': cc, 'RMSE':RMSE} 426 | return {'PSNR': psnr, 'SAM':sam, 'ERGAS':ergas, 'Q2n': q2n} 427 | 428 | if __name__ == '__main__': 429 | 430 | # import numpy as np 431 | # import os 432 | # ArbRPN_dir = '/home/ser606/Documents/LihuiChen/ArbRPN_20200916/results/SR/RNN_RESIDUAL_BI_PAN_FB_MASK/QB-FIX-4/x4/' 433 | # GT_dir = '/home/ser606/Documents/LihuiChen/PanSharp_dataset/QB/test/MTF/4bands/HRMS_npy' 434 | # LRMS_dir = '/home/ser606/Documents/LihuiChen/PanSharp_dataset/QB/test/MTF/4bands/LRMS_npy' 435 | # PAN_dir = '/home/ser606/Documents/LihuiChen/PanSharp_dataset/QB/test/MTF/4bands/LRPAN_npy' 436 | # # save_root = '/home/ser606/Documents/LihuiChen/compare/extra/reduced_resolution/test' 437 | 438 | # ArbRPN_files = os.listdir(ArbRPN_dir) 439 | # ArbRPN_files.sort() 440 | # GT_files = os.listdir(GT_dir) 441 | # GT_files.sort() 442 | # LRMS_files = os.listdir(LRMS_dir) 443 | # LRMS_files.sort() 444 | # PAN_files = os.listdir((PAN_dir)) 445 | # PAN_files.sort() 446 | # cc = [] 447 | # sam = [] 448 | # ergas = [] 449 | # q_ave = [] 450 | # q2n = [] 451 | # for i in range(len(ArbRPN_files)): 452 | # ps = np.load(os.path.join(ArbRPN_dir, ArbRPN_files[i])) 453 | # ps = ps.astype(np.float) 454 | # gt = np.load(os.path.join(GT_dir, GT_files[i])).astype(np.float) 455 | # pan = np.load(os.path.join(PAN_dir, PAN_files[i])).astype(np.float) 456 | # # print('%s || %s \n'%(ArbRPN_files[i], PAN_files[i])) 457 | # # print((ps.dtype)) 458 | # cc.append(CC(ps, gt)) 459 | # sam.append(SAM(gt, ps)) 460 | # ergas.append(ERGAS(ps, gt)) 461 | # q_ave.append(Q_AVE(ps, gt)) 462 | # q2n.append(Q2n(gt, ps, 32, 32)) 463 | 464 | # print(mean(q2n)) 465 | 466 | # import os 467 | # TFNET_dir = '/home/ser606/Documents/LihuiChen/ArbRPN_20200916/results/SR/PARABIRNN/QB-Vanilla-BiRNN-FR/x4' 468 | # LRMS_dir = '/home/ser606/Documents/LihuiChen/PanSharp_dataset/QB/test/MS_full_resolution/MS_npy' 469 | # PAN_dir = '/home/ser606/Documents/LihuiChen/PanSharp_dataset/QB/test/PAN_full_resolution/PAN_npy' 470 | # # save_root = '/home/ser606/Documents/LihuiChen/compare/extra/full_resolution/test'ArbRPN_files = os.listdir(ArbRPN_dir) 471 | # TFNET_files = os.listdir(TFNET_dir) 472 | # TFNET_files.sort() 473 | # # GT_files = os.listdir(GT_dir) 474 
| # # GT_files.sort() 475 | # LRMS_files = os.listdir(LRMS_dir) 476 | # LRMS_files.sort() 477 | # PAN_files = os.listdir((PAN_dir)) 478 | # PAN_files.sort() 479 | # Dl_results = [] 480 | # Ds_results = [] 481 | # Qnr_results = [] 482 | # for i in range(len(TFNET_files)): 483 | # print('processing the %d-th image.\n'%i) 484 | # ps = np.load(os.path.join(TFNET_dir, TFNET_files[i])) 485 | # ps = ps.astype(np.float) 486 | # lrms = np.load(os.path.join(LRMS_dir, LRMS_files[i])).astype(np.float) 487 | # pan = np.load(os.path.join(PAN_dir, PAN_files[i])).astype(np.float) 488 | # # print('%s || %s \n'%(ArbRPN_files[i], PAN_files[i])) 489 | # # print((ps.dtype)) 490 | # msexp = upsample_mat_interp23(lrms, 4) 491 | # dl, ds, hqnr = HQNR(ps, lrms, msexp, pan, 32, 'QB', 4) 492 | # Dl_results.append(dl) 493 | # Ds_results.append(ds) 494 | # Qnr_results.append(hqnr) 495 | # print(sum(Qnr_results)/len(Qnr_results)) 496 | a = np.random.randn(240, 240, 128) 497 | b = np.random.randn(240, 240, 128) 498 | Q2n(a,b) 499 | -------------------------------------------------------------------------------- /utils/plot_subfigs_colorbar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: Lihui Chen 3 | copywrite: Lihui Chen 4 | email: lihuichenscu@foxmail.com 5 | ''' 6 | from matplotlib import cm 7 | from matplotlib.colors import LinearSegmentedColormap 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | ################ { for colorbar } ################ 11 | from matplotlib import cm 12 | from matplotlib.colors import LinearSegmentedColormap 13 | color_list = ['#0000FF', '#00FF33', '#FFFF33', '#FF0000', '#FF00FF'] 14 | my_cmap = LinearSegmentedColormap.from_list('rain', color_list) 15 | cm.register_cmap(cmap=my_cmap) 16 | ################ { for colorbar } ################ 17 | 18 | def plot_subfigs_colorbar(imgDict, plotsize=[1, 6], bound=[0, 1], save_name=None, ifpdf=False, fontsize=6): 19 | ''' 20 | inputs: 21 | imgDict: dict of images 22 | plotsize: size of subplots 23 | bound: the low- and up- bound for image 24 | return: 25 | ''' 26 | plt.rcParams['font.family'] = 'Times' 27 | fig = plt.figure() 28 | for idx, key in enumerate(imgDict.keys()): 29 | ax = fig.add_subplot(plotsize[0],plotsize[1], idx+1) 30 | data = imgDict[key] 31 | if 'res' in key.lower(): 32 | im = ax.imshow(data , vmin = bound[0], vmax = bound[1], cmap='rain') 33 | else: 34 | data = np.clip(data, 0, 1) 35 | ax.imshow(data) 36 | ax.set_title(key, fontsize=fontsize) 37 | ax.set_xticks([]) 38 | ax.set_yticks([]) 39 | cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.02,ax.get_position().height]) 40 | cb = plt.colorbar(im, cax=cax) 41 | # cb = plt.colorbar(im, ax = axes.ravel().tolist()) 42 | # cb = plt.colorbar(im, ax = ax) 43 | cb.ax.tick_params(labelsize=fontsize) 44 | # plt.show() 45 | if save_name is not None: 46 | plt.savefig(save_name) 47 | 48 | if ifpdf: 49 | plt.savefig(save_name.replace('.png', '.svg'), dpi=300) 50 | # plt.close() 51 | return fig 52 | 53 | -------------------------------------------------------------------------------- /utils/pyCartooTexture.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from imageio import imread 3 | import matplotlib.pyplot as plt 4 | from scipy.ndimage.filters import gaussian_filter1d as Gf1 5 | 6 | from train import create_Dataloader, parse_options 7 | 8 | 9 | def ComputeGradient(img): 10 | gd = np.gradient(img) 11 | g1 = np.sqrt(np.power(gd[0], 2) + 
np.power(gd[1], 2)) 12 | return g1 13 | 14 | 15 | def SepConvol(grad, sigma): 16 | v = Gf1(grad, sigma, axis=-1) 17 | v = Gf1(v, sigma, axis=0) 18 | return v 19 | 20 | 21 | def low_pass_filter(img, sigma, niter): 22 | gconvolved = SepConvol(img, sigma) 23 | imdifference = img - gconvolved 24 | 25 | for i in range(0, niter): 26 | imconvolved = SepConvol(imdifference, sigma) 27 | imdifference = imdifference - imconvolved 28 | gconvolved = img - imdifference 29 | return gconvolved 30 | 31 | 32 | def WeightingFunction(r1, r2): 33 | difference = r1 - r2 34 | ar1 = np.abs(r1) 35 | 36 | mask_ar = np.argwhere(ar1 <= 1) 37 | difference /= ar1 38 | difference[mask_ar[:, 0], mask_ar[:, 1]] = 0.0 39 | 40 | cmin = 0.25 41 | cmax = 0.5 42 | 43 | weight = (difference - cmin) / (cmax - cmin) 44 | mask_min = np.argwhere(difference < cmin) 45 | weight[mask_min[:, 0], mask_min[:, 1]] = 0.0 46 | mask_max = np.argwhere(difference > cmax) 47 | weight[mask_max[:, 0], mask_max[:, 1]] = 1 48 | 49 | return weight 50 | 51 | 52 | def fast_CT(img, sigma, niter): 53 | grad = ComputeGradient(img) 54 | ratio1 = SepConvol(grad, sigma) 55 | gconvolved = low_pass_filter(img, sigma, niter) 56 | grad = ComputeGradient(gconvolved) 57 | ratio2 = SepConvol(grad, sigma) 58 | weight = WeightingFunction(ratio1, ratio2) 59 | return weight * gconvolved + (1 - weight) * img 60 | 61 | 62 | if __name__ == '__main__': 63 | # img = imread('test.png', pilmode='F') 64 | # 65 | # sigma = 3 66 | # niter = 5 67 | # 68 | # material = fast_CT(img, sigma, niter) 69 | # 70 | # plt.figure(1) 71 | # plt.subplot(1, 3, 1) 72 | # plt.imshow(img, cmap="gray") 73 | # plt.title('image') 74 | # 75 | # plt.subplot(1, 3, 2) 76 | # plt.imshow(material, cmap="gray") 77 | # plt.title('material') 78 | # 79 | # texture = (img - material) 80 | # plt.subplot(1, 3, 3) 81 | # plt.imshow(texture, cmap='gray') 82 | # plt.title('texture') 83 | # 84 | # plt.show() 85 | 86 | option_file_path = '../options/train/train_MFPS_example.json' 87 | opt = parse_options(option_file_path) 88 | train_set, train_loader, val_set, val_loader = create_Dataloader(opt) 89 | sigma = 1 90 | niter = 5 91 | 92 | for batch in train_loader: 93 | images = batch 94 | lr = images['LR'].numpy()[0] / 2047 95 | hr = images['HR'].numpy()[0] / 2047 96 | pan = images['PAN'].numpy()[0] / 2047 97 | 98 | img = pan[0] 99 | 100 | material = fast_CT(img, sigma, niter) 101 | 102 | plt.figure(1) 103 | plt.subplot(1, 3, 1) 104 | plt.imshow(img, cmap="gray") 105 | plt.title('image') 106 | 107 | plt.subplot(1, 3, 2) 108 | plt.imshow(material, cmap="gray") 109 | plt.title('material') 110 | 111 | texture = (img - material) 112 | plt.subplot(1, 3, 3) 113 | plt.imshow(texture, cmap='gray') 114 | plt.title('texture') 115 | 116 | plt.show() 117 | -------------------------------------------------------------------------------- /utils/reamdme.md: -------------------------------------------------------------------------------- 1 | Copyright@ Lihui Chen, email: lihuichenscu@foxmail.com 2 | 3 | # -------------------------------------------------------------------------------- /utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/utils/test.png -------------------------------------------------------------------------------- /utils/unionNormImg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: Lihui Chen 3 | copywrite: Lihui Chen 4 | 
email: lihuichenscu@foxmail.com 5 | ''' 6 | import numpy as np 7 | import glob 8 | import sys 9 | import scipy.io as sio 10 | # sys.path.append('../dataProc4manuscript/') 11 | from hist_adjust import hist_line_stretch, hist_line_stretchv2 12 | from plot_subfigs_colorbar import plot_subfigs_colorbar 13 | from metrics import pan_calc_metrics_rr 14 | # %matplotlib inline 15 | import matplotlib.pyplot as plt 16 | import matplotlib 17 | matplotlib.use('TkAgg') 18 | import os 19 | # from collections import OrderedDict 20 | def main(): 21 | results_dir = './SOTA_HS_Sharpen' 22 | dataset = 'result_chikusei' #'result_HyperALi' 'result_chikusei' 23 | method_dir = { 24 | # 'Ours': '', 25 | # 'GSA': '%s/%s/GSA/*.npy'%(results_dir, dataset), 26 | 'SFIMHS': '%s/%s/SFIMHS/*.npy'%(results_dir, dataset), 27 | 'GLPHS': '%s/%s/GLPHS/*.npy'%(results_dir, dataset), 28 | 'CNMF': '%s/%s/CNMF/*.npy'%(results_dir, dataset), 29 | # 'ICCV15': '%s/%s/ICCV15/*.npy'%(results_dir, dataset), 30 | 'FUSE': '%s/%s/FUSE/*.npy'%(results_dir, dataset), 31 | 'HySure': '%s/%s/HySure/*.npy'%(results_dir, dataset), 32 | 'MHFnet': '%s/%s/MHFnet/*.npy'%(results_dir, dataset), 33 | 'HSRnet':'%s/%s/HSRnet/*.npy'%(results_dir, dataset), 34 | # 'MoGCDN':'%s/%s/MoGCDN/*.mat'%(results_dir, dataset), 35 | 'Ours': '%s/%s/Ours/*.npy'%(results_dir, dataset), 36 | } 37 | 38 | GT_dir = '%s/%s/GT/*.npy'%(results_dir, dataset) 39 | # MSHR = './' 40 | GT_file = glob.glob(GT_dir) 41 | GT_file.sort() 42 | savedir = '%s/%s/visualization_final/'%(results_dir, dataset) 43 | if not os.path.isdir(savedir): 44 | os.makedirs(savedir) 45 | rgb_band = (30, 10, 4) 46 | rad_res = 4095 47 | 48 | # dl_method_files = {key:glob.glob(value) for key, value in DL_method.items()} 49 | # for key in dl_method_files.keys(): dl_method_files[key].sort() 50 | 51 | # dl_method_files = { 52 | # 'MHFnet':sio.loadmat(dl_method_files['MHFnet'][0])['chikusei'].astype(np.float), 53 | # 'HSRnet':sio.loadmat(dl_method_files['HSRnet'][0])['output'].astype(np.float)/12.0*rad_res, 54 | # } 55 | 56 | 57 | method_files = {key:glob.glob(value) for key, value in method_dir.items()} 58 | for key in method_files.keys(): method_files[key].sort() 59 | key_list = list(method_files.keys()) 60 | imgdict = dict() 61 | for idx_img in range(len(GT_file)): 62 | # imgdict = {key: np.load(value[idx_img]).astype(np.float) for key, value in method_files.items()} 63 | for idx,(key, value) in enumerate(method_files.items()): 64 | if key in ['MogCDN']: 65 | imgdict[key] = sio.loadmat(value[idx_img])['data'].astype(np.float).squeeze() 66 | else: 67 | imgdict[key] = np.load(value[idx_img]).astype(np.float).squeeze() 68 | # if key=='MHFnet': 69 | # dlimgdict[key] = sio.loadmat(value[idx_img])['HyperALi%d'%(idx_img+1)].astype(np.float).squeeze() 70 | # # dlimgdict[dlkey] = dlvalue[idx_img,...] 71 | # if key=='HSRnet': 72 | # dlimgdict[key] = sio.loadmat(value[idx_img])['data'].astype(np.float).squeeze()#/12*rad_res 73 | # # dlimgdict[dlkey] = dlvalue[idx_img,...]
74 | # if key == 'Ours': 75 | # dlimgdict[key] = sio.loadmat(value[idx_img])['data'].astype(np.float).squeeze() 76 | 77 | GT= np.load(GT_file[idx_img]).astype(np.float) 78 | # GT= sio.loadmat(GT_file[idx_img])['HSHR'].astype(np.float) 79 | # MSHR = sio.loadmat(GT_file[idx_img])['MSHR'].astype(np.float) 80 | # MSHR,_,_ = hist_line_stretchv2(MSHR[:,:,(5,3,1)], rad_res, bound=[0.01, 0.99]) 81 | 82 | res_imgdict = {'%s Res.'%key:np.abs(value-GT).mean(axis=2)/rad_res for key, value in imgdict.items()} 83 | 84 | imgdict = {key: value[:,:, rgb_band] for key, value in imgdict.items()} 85 | GT, lowThe, highThe = hist_line_stretch(GT[:,:,rgb_band], rad_res, bound=[0.01, 0.995]) 86 | # print('%s'%str(tmpMetric)) 87 | for key, value in imgdict.items(): 88 | imgdict[key] = (value-lowThe)/(highThe-lowThe) 89 | # for idx, (low, high) in enumerate(zip(lowThe, highThe)): 90 | # tmp = value[:,:,idx] 91 | # tmp[np.where(tmp>high)]=high 92 | # tmp[np.where(tmpgF[i,j]): 309 | GAF[i,j] = gF[i,j]/gA[i,j] 310 | elif(gA[i,j]==gF[i,j]): 311 | GAF[i, j] = gF[i, j] 312 | else: 313 | GAF[i, j] = gA[i,j]/gF[i, j] 314 | AAF[i,j] = 1-np.abs(aA[i,j]-aF[i,j])/(math.pi/2) 315 | 316 | QgAF[i,j] = Tg/(1+math.exp(kg*(GAF[i,j]-Dg))) 317 | QaAF[i,j] = Ta/(1+math.exp(ka*(AAF[i,j]-Da))) 318 | 319 | QAF[i,j] = QgAF[i,j]*QaAF[i,j] 320 | 321 | return QAF 322 | 323 | # QAF = getQabf(aA,gA,aF,gF) 324 | # QBF = getQabf(aB,gB,aF,gF) 325 | # 326 | # 327 | # #计算QABF 328 | # deno = np.sum(gA+gB) 329 | # nume = np.sum(np.multiply(QAF,gA)+np.multiply(QBF,gB)) 330 | # output = nume/deno 331 | 332 | def qabfMetric(imgFu, imgA, imgB): 333 | gA, aA = getArray(imgA) 334 | gB, aB = getArray(imgB) 335 | gF, aF = getArray(imgFu) 336 | QAF = getQabf(aA, gA, aF, gF) 337 | QBF = getQabf(aB, gB, aF, gF) 338 | deno = np.sum(gA + gB) 339 | nume = np.sum(np.multiply(QAF, gA) + np.multiply(QBF, gB)) 340 | output = nume / deno 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /visualization/MsVisualizer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import spectral as spy 3 | import numpy as np 4 | 5 | 6 | def visualize_dump(ms, save_path, bands=None): 7 | spy.save_rgb(save_path, ms, bands=bands) 8 | 9 | 10 | if __name__ == '__main__': 11 | # 获取mat格式的数据,loadmat输出的是dict,所以需要进行定位 12 | num = 500 13 | hr_ms = np.load(f"../WV2/Augment/train_HR_aug_npy/{num}_HR.npy") 14 | lr_ms = np.load(f"../WV2/Augment/train_LR_aug_npy/{num}_LR.npy") 15 | input_image_PAN = np.load(f"../WV2/Augment/train_PAN_aug_npy/{num}_PAN.npy") 16 | 17 | View2_lr = spy.imshow(data=lr_ms, bands=[2, 1, 0], title="img_LR") # 图像显示 18 | view_pan = spy.imshow(data=input_image_PAN, title="img_pan") 19 | 20 | spy.save_rgb("1.bmp", hr_ms, bands=[0]) 21 | 22 | plt.pause(60) 23 | k = "F:\BaiduNetdiskDownload\GF-2" 24 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualizer import get_local 2 | 3 | -------------------------------------------------------------------------------- /visualization/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/visualization/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /visualization/__pycache__/visualizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songlei-xiong/ATFuse/1fd64810843ae872ab4998105e2fd3b2b1deafb4/visualization/__pycache__/visualizer.cpython-39.pyc -------------------------------------------------------------------------------- /visualization/attaVisualizer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def grid_show(to_shows, cols): 7 | rows = (len(to_shows) - 1) // cols + 1 8 | it = iter(to_shows) 9 | fig, axs = plt.subplots(rows, cols, figsize=(rows * 8.5, cols * 2)) 10 | for i in range(rows): 11 | for j in range(cols): 12 | try: 13 | image, title = next(it) 14 | except StopIteration: 15 | image = np.zeros_like(to_shows[0][0]) 16 | title = 'pad' 17 | axs[i, j].imshow(image) 18 | axs[i, j].set_title(title) 19 | axs[i, j].set_yticks([]) 20 | axs[i, j].set_xticks([]) 21 | plt.show() 22 | 23 | 24 | def visualize_head(att_map): 25 | ax = plt.gca() 26 | # Plot the heatmap 27 | im = ax.imshow(att_map) 28 | # Create colorbar 29 | cbar = ax.figure.colorbar(im, ax=ax) 30 | plt.show() 31 | 32 | 33 | def visualize_heads(att_map, cols): 34 | to_shows = [] 35 | att_map = att_map.squeeze() 36 | for i in range(att_map.shape[0]): 37 | to_shows.append((att_map[i], f'Head {i}')) 38 | average_att_map = att_map.mean(axis=0) 39 | to_shows.append((average_att_map, 'Head Average')) 40 | grid_show(to_shows, cols=cols) 41 | 42 | 43 | def gray2rgb(image): 44 | return np.repeat(image[..., np.newaxis], 3, 2) 45 | 46 | 47 | def cls_padding(image, mask, cls_weight, grid_size): 48 | if not isinstance(grid_size, tuple): 49 | grid_size = (grid_size, grid_size) 50 | 51 | image = np.array(image) 52 | 53 | H, W = image.shape[:2] 54 | delta_H = int(H / grid_size[0]) 55 | delta_W = int(W / grid_size[1]) 56 | 57 | padding_w = delta_W 58 | padding_h = H 59 | padding = np.ones_like(image) * 255 60 | padding = padding[:padding_h, :padding_w] 61 | 62 | padded_image = np.hstack((padding, image)) 63 | padded_image = Image.fromarray(padded_image) 64 | draw = ImageDraw.Draw(padded_image) 65 | draw.text((int(delta_W / 4), int(delta_H / 4)), 'CLS', fill=(0, 0, 0)) # PIL.Image.size = (W,H) not (H,W) 66 | 67 | mask = mask / max(np.max(mask), cls_weight) 68 | cls_weight = cls_weight / max(np.max(mask), cls_weight) 69 | 70 | if len(padding.shape) == 3: 71 | padding = padding[:, :, 0] 72 | padding[:, :] = np.min(mask) 73 | mask_to_pad = np.ones((1, 1)) * cls_weight 74 | mask_to_pad = Image.fromarray(mask_to_pad) 75 | mask_to_pad = mask_to_pad.resize((delta_W, delta_H)) 76 | mask_to_pad = np.array(mask_to_pad) 77 | 78 | padding[:delta_H, :delta_W] = mask_to_pad 79 | padded_mask = np.hstack((padding, mask)) 80 | padded_mask = padded_mask 81 | 82 | meta_mask = np.zeros((padded_mask.shape[0], padded_mask.shape[1], 4)) 83 | meta_mask[delta_H:, 0: delta_W, :] = 1 84 | 85 | return padded_image, padded_mask, meta_mask 86 | 87 | 88 | def visualize_grid_to_grid_with_cls(att_map, grid_index, image, grid_size=14, alpha=0.6): 89 | if not isinstance(grid_size, tuple): 90 | grid_size = (grid_size, grid_size) 91 | 92 | attention_map = att_map[grid_index] 93 | cls_weight = attention_map[0] 94 | 95 | mask = attention_map[1:].reshape(grid_size[0], grid_size[1]) 96 | mask = 
Image.fromarray(mask).resize((image.size)) 97 | 98 | padded_image, padded_mask, meta_mask = cls_padding(image, mask, cls_weight, grid_size) 99 | 100 | if grid_index != 0: # adjust grid_index since we pad our image 101 | grid_index = grid_index + (grid_index - 1) // grid_size[1] 102 | 103 | grid_image = highlight_grid(padded_image, [grid_index], (grid_size[0], grid_size[1] + 1)) 104 | 105 | fig, ax = plt.subplots(1, 2, figsize=(10, 7)) 106 | fig.tight_layout() 107 | 108 | ax[0].imshow(grid_image) 109 | ax[0].axis('off') 110 | 111 | ax[1].imshow(grid_image) 112 | ax[1].imshow(padded_mask, alpha=alpha, cmap='rainbow') 113 | ax[1].imshow(meta_mask) 114 | ax[1].axis('off') 115 | 116 | 117 | def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6): 118 | if not isinstance(grid_size, tuple): 119 | grid_size = (grid_size, grid_size) 120 | 121 | H, W = att_map.shape 122 | with_cls_token = False 123 | 124 | grid_image = highlight_grid(image, [grid_index], grid_size) 125 | 126 | mask = att_map[grid_index].reshape(grid_size[0], grid_size[1]) 127 | mask = Image.fromarray(mask).resize((image.size)) 128 | 129 | fig, ax = plt.subplots(1, 2, figsize=(10, 7)) 130 | fig.tight_layout() 131 | 132 | ax[0].imshow(grid_image) 133 | ax[0].axis('off') 134 | 135 | ax[1].imshow(grid_image) 136 | ax[1].imshow(mask / np.max(mask), alpha=alpha, cmap='rainbow') 137 | ax[1].axis('off') 138 | plt.show() 139 | 140 | 141 | def highlight_grid(image, grid_indexes, grid_size=14): 142 | if not isinstance(grid_size, tuple): 143 | grid_size = (grid_size, grid_size) 144 | 145 | W, H = image.size 146 | h = H / grid_size[0] 147 | w = W / grid_size[1] 148 | image = image.copy() 149 | for grid_index in grid_indexes: 150 | x, y = np.unravel_index(grid_index, (grid_size[0], grid_size[1])) 151 | a = ImageDraw.ImageDraw(image) 152 | a.rectangle([(y * w, x * h), (y * w + w, x * h + h)], fill=None, outline='red', width=2) 153 | return image -------------------------------------------------------------------------------- /visualization/visualizer.py: -------------------------------------------------------------------------------- 1 | from bytecode import Bytecode, Instr 2 | 3 | 4 | class get_local(object): 5 | cache = {} 6 | is_activate = False 7 | 8 | def __init__(self, varname): 9 | self.varname = varname 10 | 11 | def __call__(self, func): 12 | if not type(self).is_activate: 13 | return func 14 | 15 | type(self).cache[func.__qualname__] = [] 16 | c = Bytecode.from_code(func.__code__) 17 | extra_code = [ 18 | Instr('STORE_FAST', '_res'), 19 | Instr('LOAD_FAST', self.varname), 20 | Instr('STORE_FAST', '_value'), 21 | Instr('LOAD_FAST', '_res'), 22 | Instr('LOAD_FAST', '_value'), 23 | Instr('BUILD_TUPLE', 2), 24 | Instr('STORE_FAST', '_result_tuple'), 25 | Instr('LOAD_FAST', '_result_tuple'), 26 | ] 27 | c[-1:-1] = extra_code 28 | func.__code__ = c.to_code() 29 | 30 | def wrapper(*args, **kwargs): 31 | res, values = func(*args, **kwargs) 32 | type(self).cache[func.__qualname__].append(values.detach().cpu().numpy()) 33 | return res 34 | 35 | return wrapper 36 | 37 | @classmethod 38 | def clear(cls): 39 | for key in cls.cache.keys(): 40 | cls.cache[key] = [] 41 | 42 | @classmethod 43 | def activate(cls): 44 | cls.is_activate = True 45 | --------------------------------------------------------------------------------
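A minimal usage sketch for the get_local decorator defined in visualization/visualizer.py above (this snippet is not part of the repository; ToyAttention and the captured variable name attn are illustrative only). get_local.activate() has to run before the decorated code is defined or imported, because the bytecode patch is applied at decoration time -- which is why train_ViIr.py calls it at the very top of the script:

import torch

from visualization import get_local

get_local.activate()  # must precede the definition/import of any decorated function


class ToyAttention(torch.nn.Module):
    @get_local('attn')  # capture the local variable named 'attn'
    def forward(self, q, k):
        attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
        return attn @ k


layer = ToyAttention()
out = layer(torch.randn(1, 4, 8), torch.randn(1, 4, 8))
maps = get_local.cache['ToyAttention.forward']  # list of captured numpy arrays
get_local.clear()  # reset the cache between epochs, as train_ViIr.py does

Rewriting the bytecode to return the captured local alongside the real result keeps attention-map bookkeeping out of the network code itself.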