├── LICENSE.md ├── README.md ├── data ├── Readme.md ├── multi_exposure_dataset.py ├── multi_focus_dataset.py ├── self_mixpretrain_dataset.py └── visir_fusion_dataset.py ├── loss ├── mix_fp_loss.py └── readme.md ├── models ├── Readme.md ├── UCSharedModelPro.py ├── UCSharedModelProCommon.py ├── UCTestShareModelProCommon.py ├── UCTestSharedModelPro.py ├── resnest.py ├── resnet.py └── splat.py ├── option ├── options.py ├── test │ ├── EMFF_Test_Dataset.yaml │ ├── IVF_Test_Dataset.yaml │ ├── MEF_Test_Dataset.yaml │ ├── MFF_Test_Dadaset.yaml │ ├── SMEF_Test_Dataset.yaml │ └── TIVF_Test_Dataset.yaml └── train │ └── SelfTrained_SDataset.yaml ├── selftrain.py ├── test.py └── utils ├── Readme.md ├── build_code_arch.py └── util.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2021 Scott Chacon and others 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fusion from Decomposition 2 | ---------- 3 | This repository is an official implementation of **Fusion from Decomposition: A Self-Supervised Decomposition Approach for Image Fusion** (ECCV 2022). 4 | 5 | ## Prerequisites 6 | ---------- 7 | - Linux 8 | - Python 3 9 | - NVIDIA GPU + CUDA cuDNN 10 | - PyTorch 1.9 11 | - torchvision 0.8 12 | - Pillow 8.1 13 | - Opencv 4.4 14 | 15 | 16 | ## Getting Started 17 | ---------- 18 | ### Installation 19 | ---------- 20 | - Install python libraries and requests. 21 | - Clone this repo: 22 | ```bash 23 | git clone https://github.com/erfect2020/DecompositionForFusion.git 24 | cd DecompositionForFusion 25 | ``` 26 | 27 | 28 | 29 | 30 | ### Start run 31 | ---------- 32 | 1. Download [COCO](https://github.com/cocodataset/cocoapi): https://cocodataset.org/ 33 | 2. Put your training images into any floder and modify the `option/train/SelfTrained_SDataset.yaml' to retarget the path. 34 | 3. Train DeFusion 35 | ```bash 36 | python selftrain.py --opt options/train/SelfTrained_SDataset.yaml 37 | ``` 38 | 39 | ### Start evaluation 40 | ---------- 41 | 42 | 1. Download test dataset: 43 | 1. Multi-exposure image fusion: [MEFB](https://github.com/xingchenzhang/MEFB):https://github.com/xingchenzhang/MEFB, [SICE](https://github.com/csjcai/SICE):https://github.com/csjcai/SICE. 44 | 2. 
Multi-focus image fusion: [Real-MFF](https://githubmemory.com/repo/Zancelot/Real-MFF):https://githubmemory.com/repo/Zancelot/Real-MFF, [MFIFB](https://github.com/xingchenzhang/MFIFB):https://github.com/xingchenzhang/MFIFB. 45 | 3. Visible-infrared image fusion: [RoadScene](https://github.com/jiayi-ma/RoadScene):https://github.com/jiayi-ma/RoadScene, [TNO](https://figshare.com/articles/dataset/TNO_Image_Fusion_Dataset/1008029):https://figshare.com/articles/dataset/TNO_Image_Fusion_Dataset/1008029. 46 | 2. Modify [test.py](test.py) to select the data preprocessing files for different tasks. 47 | 3. (Optional) Our pretrained model is available at [Google Drive](https://drive.google.com/file/d/1CUoFLiV3mugvbfBcMcwgXbDF6bWPhdd9/view?usp=sharing). 48 | 4. Test DeFusion 49 | 1. Test the multi-exposure image fusion task on [MEFB](https://github.com/xingchenzhang/MEFB) or [SICE](https://github.com/csjcai/SICE) 50 | ```bash 51 | python test.py --opt options/test/MEF_Test_Dataset.yaml # or 52 | python test.py --opt options/test/SMEF_Test_Dataset.yaml 53 | ``` 54 | 2. Test the multi-focus image fusion task on [Real-MFF](https://githubmemory.com/repo/Zancelot/Real-MFF) or [MFIFB](https://github.com/xingchenzhang/MFIFB) 55 | ```bash 56 | python test.py --opt options/test/MFF_Test_Dataset.yaml # or 57 | python test.py --opt options/test/EMFF_Test_Dataset.yaml 58 | ``` 59 | 3. Test the visible-infrared image fusion task on [RoadScene](https://github.com/jiayi-ma/RoadScene) or [TNO](https://figshare.com/articles/dataset/TNO_Image_Fusion_Dataset/1008029) 60 | ```bash 61 | python test.py --opt options/test/IVF_Test_Dataset.yaml # or 62 | python test.py --opt options/test/TIVF_Test_Dataset.yaml 63 | ``` 64 | 65 | 66 | 67 | ## License 68 | ---------- 69 | Distributed under the MIT License. See `LICENSE.md` for more information. 70 | 71 | ## Citations 72 | ---------- 73 | If DeFusion helps your research or work, please consider citing DeFusion.
74 | 75 | ``` 76 | @InProceedings{Liang2022ECCV, 77 | author = {Liang, Pengwei and Jiang, Junjun and Liu, Xianming and Ma, Jiayi}, 78 | title = {Fusion from Decomposition: A Self-Supervised Decomposition Approach for Image Fusion}, 79 | booktitle = {European Conference on Computer Vision (ECCV)}, 80 | year = {2022}, 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /data/Readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/multi_exposure_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import torch 5 | import torchvision.transforms.functional as TF 6 | from torchvision.transforms import Compose, RandomResizedCrop, ToTensor 7 | import cv2 8 | from tqdm import tqdm 9 | from PIL import Image 10 | import random 11 | 12 | class TestDataset(Dataset): 13 | def __init__(self, valopt): 14 | super(TestDataset, self).__init__() 15 | path = valopt['dataroot'] 16 | self.img_transform = ToTensor() 17 | part_imgs = os.path.join(os.path.expanduser(path), valopt['part_name']) 18 | 19 | def not_dir(x): 20 | return '_MDF' not in x and '.DS_Store' not in x and '.txt' not in x 21 | 22 | part_imgs = [os.path.join(part_imgs, os_dir) for os_dir in filter(not_dir, os.listdir(part_imgs))] 23 | 24 | over_imgs = [] 25 | under_imgs = [] 26 | for img_seq in tqdm(part_imgs): 27 | ov_seq = {} 28 | for img_name in filter(not_dir, os.listdir(img_seq)): 29 | img = cv2.imread(os.path.join(img_seq, img_name), 1) 30 | ov_seq[img_name] = img.mean() 31 | over_imgs.append(os.path.join(img_seq, max(ov_seq, key=ov_seq.get))) 32 | under_imgs.append(os.path.join(img_seq, min(ov_seq, key=ov_seq.get))) 33 | 34 | over_imgs.sort() 35 | under_imgs.sort() 36 | 37 | self.iget_imgs = {} 38 | for o_img, u_img in zip(over_imgs, under_imgs): 39 | self.iget_imgs[o_img] = [o_img, u_img] 40 | 41 | self.iget_imgs = [(key, values) for key, values in self.iget_imgs.items()] 42 | self.iget_imgs = sorted(self.iget_imgs, key=lambda x: x[0]) 43 | 44 | def __len__(self): 45 | return len(self.iget_imgs) 46 | 47 | def __getitem__(self, index): 48 | c_img, (o_img, u_img) = self.iget_imgs[index] 49 | 50 | o_img = Image.open(o_img) 51 | u_img = Image.open(u_img) 52 | 53 | o_img = self.img_transform(o_img) 54 | u_img = self.img_transform(u_img) 55 | 56 | 57 | c_img = os.path.split(os.path.split(c_img)[-2])[-1] 58 | return o_img, u_img, c_img 59 | 60 | -------------------------------------------------------------------------------- /data/multi_focus_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import torch 5 | import torchvision.transforms.functional as TF 6 | from torchvision.transforms import Compose, RandomResizedCrop, ToTensor 7 | import torch.nn.functional as F 8 | import cv2 9 | from tqdm import tqdm 10 | from PIL import Image 11 | 12 | 13 | class TestDataset(Dataset): 14 | def __init__(self, valopt): 15 | super(TestDataset, self).__init__() 16 | path = valopt['dataroot'] 17 | self.img_transform = ToTensor() 18 | part_imgs = os.path.join(os.path.expanduser(path), valopt['part_name']) 19 | 20 | def not_dir(x): 21 | return '_MDF' not in x and '.DS_Store' not in x and '.txt' not in x 22 | 23 | part_imgs = [os.path.join(part_imgs, os_dir) 
for os_dir in filter(not_dir, os.listdir(part_imgs))] 24 | 25 | up_imgs = [] 26 | low_imgs = [] 27 | gt_imgs = [] 28 | for img_seq in tqdm(part_imgs): 29 | gt, up, low = sorted(os.listdir(img_seq)) 30 | # print('gt: ', gt, 'up: ', up, 'low: ', low) 31 | up_imgs.append(os.path.join(img_seq, up)) 32 | low_imgs.append(os.path.join(img_seq, low)) 33 | gt_imgs.append(os.path.join(img_seq, gt)) 34 | 35 | up_imgs.sort() 36 | low_imgs.sort() 37 | gt_imgs.sort() 38 | 39 | self.iget_imgs = {} 40 | for o_img, u_img, g_img in zip(up_imgs, low_imgs, gt_imgs): 41 | self.iget_imgs[o_img] = [o_img, u_img, g_img] 42 | 43 | self.iget_imgs = [(key, values) for key, values in self.iget_imgs.items()] 44 | self.iget_imgs = sorted(self.iget_imgs, key=lambda x: x[0]) 45 | 46 | def __len__(self): 47 | return len(self.iget_imgs) 48 | 49 | def __getitem__(self, index): 50 | c_img, (up_img, low_img, gt_img) = self.iget_imgs[index] 51 | 52 | up_img = Image.open(up_img) 53 | low_img = Image.open(low_img) 54 | 55 | up_img = self.img_transform(up_img) 56 | low_img = self.img_transform(low_img) 57 | 58 | c_img = os.path.split(c_img)[-1].split('.')[0] 59 | return up_img, low_img, c_img 60 | 61 | 62 | 63 | class TestMFFDataset(Dataset): 64 | def __init__(self, valopt): 65 | super(TestMFFDataset, self).__init__() 66 | path = valopt['dataroot'] 67 | self.img_transform = ToTensor() 68 | part_imgs = os.path.join(os.path.expanduser(path), valopt['part_name']) 69 | 70 | def not_dir(x): 71 | return '_MDF' not in x and '.DS_Store' not in x and '.txt' not in x 72 | 73 | part_imgs = [os.path.join(part_imgs, os_dir) for os_dir in filter(not_dir, os.listdir(part_imgs))] 74 | 75 | up_imgs = [] 76 | low_imgs = [] 77 | gt_imgs = [] 78 | for img_seq in tqdm(part_imgs): 79 | up, low = sorted(os.listdir(img_seq)) 80 | up_imgs.append(os.path.join(img_seq, up)) 81 | low_imgs.append(os.path.join(img_seq, low)) 82 | 83 | up_imgs.sort() 84 | low_imgs.sort() 85 | 86 | self.iget_imgs = {} 87 | for o_img, u_img in zip(up_imgs, low_imgs): 88 | self.iget_imgs[o_img] = [o_img, u_img] 89 | 90 | self.iget_imgs = [(key, values) for key, values in self.iget_imgs.items()] 91 | self.iget_imgs = sorted(self.iget_imgs, key=lambda x: x[0]) 92 | 93 | def __len__(self): 94 | return len(self.iget_imgs) 95 | 96 | def __getitem__(self, index): 97 | c_img, (up_img, low_img) = self.iget_imgs[index] 98 | 99 | up_img = Image.open(up_img) 100 | low_img = Image.open(low_img) 101 | 102 | up_img = self.img_transform(up_img) 103 | low_img = self.img_transform(low_img) 104 | 105 | 106 | c_img = os.path.split(c_img)[-1].split('.')[0] 107 | return up_img, low_img, c_img 108 | 109 | -------------------------------------------------------------------------------- /data/self_mixpretrain_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import torch 5 | from torchvision.transforms import Compose, RandomResizedCrop, ToTensor, RandomCrop, ColorJitter 6 | from torch.distributions.bernoulli import Bernoulli 7 | import torch.nn.functional as F 8 | import cv2 9 | from tqdm import tqdm 10 | from PIL import Image 11 | from kornia.filters import GaussianBlur2d, gaussian_blur2d 12 | import random 13 | 14 | 15 | class TrainDataset(Dataset): 16 | def __init__(self, trainopt): 17 | super(TrainDataset, self).__init__() 18 | path = trainopt['dataroot'] 19 | self.img_transform = Compose([ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2), ToTensor()]) 20 | 
self.piltotensor = ToTensor() 21 | self.colorjitter = ColorJitter(brightness=(0.1, 1.0), contrast=(0.1, 1.0), saturation=(0.1, 1.0), hue=0) 22 | # self.colorjitter = ColorJitter() 23 | # self.brightness = ColorJitter(brightness=(0.4, 1)) 24 | self.brightness = ColorJitter() 25 | self.img_resize = RandomResizedCrop(trainopt['image_size']) 26 | self.image_size = trainopt['image_size'] 27 | self.batch_size = trainopt['iter_size'] 28 | self.max_iter = trainopt["max_iter"] 29 | self.epoch_num = 0 30 | self.same_radio = 0.2 31 | self.noise_radio = 0.2 32 | self.grid_num1 = 14 #14 33 | self.rand_dist1 = Bernoulli(probs=torch.ones(self.grid_num1 ** 2) * 0.99) 34 | self.grid_num2 = 28 #56 # 28 35 | self.rand_dist2 = Bernoulli(probs=torch.ones(self.grid_num2 ** 2) * 0.95) 36 | self.grid_num3 = 28 # 224 # 28 37 | self.rand_dist3 = Bernoulli(probs=torch.ones(self.grid_num3 ** 2) * 0.5) 38 | self.train_type = ['self_supervised_mixup', 'self_supervised_common_easy', 39 | 'self_supervised_upper', 'self_supervised_lower'] 40 | 41 | trainpairs_forreading = str(trainopt['trainpairs']) 42 | if not os.path.exists(trainpairs_forreading): 43 | self.iget_imgs = {} 44 | extra_imgs = [] 45 | extra_imgs = os.path.join(os.path.expanduser(path), trainopt["train_name"]) 46 | extra_imgs = [os.path.join(extra_imgs, os_dir) for os_dir in os.listdir(extra_imgs)] 47 | extra_imgs = random.sample(extra_imgs, k=50000) 48 | 49 | lowlight_imgs = [] 50 | 51 | def not_dir(x): 52 | return '_MDF' not in x and '.DS_Store' not in x and '.txt' not in x 53 | 54 | lowlight_imgs = list(filter(not_dir, lowlight_imgs)) 55 | extra_imgs = extra_imgs + lowlight_imgs 56 | 57 | for self_img in extra_imgs: 58 | rand_state= torch.rand(1) 59 | if rand_state < 1.0: 60 | self.iget_imgs[str(self_img) + ':' + self.train_type[0]] = [self_img, self_img, self.train_type[0]] 61 | elif rand_state < 0.8: 62 | self.iget_imgs[str(self_img) + ':' + self.train_type[1]] = [self_img, self_img, self.train_type[1]] 63 | elif rand_state < 0.9: 64 | self.iget_imgs[str(self_img) + ':' + self.train_type[2]] = [self_img, self_img, self.train_type[2]] 65 | else: 66 | self.iget_imgs[str(self_img) + ':' + self.train_type[3]] = [self_img, self_img, self.train_type[3]] 67 | 68 | with open(trainpairs_forreading, 'w') as f: 69 | json.dump(self.iget_imgs, f) 70 | else: 71 | with open(trainpairs_forreading, 'r') as f: 72 | self.iget_imgs = json.load(f) 73 | self.iget_imgs = [(key, values) for key, values in self.iget_imgs.items()] 74 | self.iget_imgs = sorted(self.iget_imgs, key=lambda x: x[0]) 75 | 76 | def __len__(self): 77 | return len(self.iget_imgs) 78 | 79 | def reassign_mask(self, radio): 80 | self.rand_dist1 = Bernoulli(probs=torch.ones(self.grid_num1 ** 2) * radio) 81 | self.rand_dist2 = Bernoulli(probs=torch.ones(self.grid_num2 ** 2) * radio) 82 | self.rand_dist3 = Bernoulli(probs=torch.ones(self.grid_num3 ** 2) * radio) 83 | 84 | def __getitem__(self, index): 85 | c_img, (o_img, u_img, train_type) = self.iget_imgs[index] 86 | 87 | 88 | o_img = Image.open(o_img).convert("RGB") 89 | # o_img = self.img_transform(o_img) 90 | o_img = self.piltotensor(o_img) 91 | o_img = self.colorjitter(o_img.unsqueeze(0)).squeeze() 92 | 93 | u_img = o_img.clone() 94 | # print("size", o_img.shape, u_img.shape) 95 | combime_img = self.img_resize(torch.cat((o_img, u_img), dim=0)) 96 | o_img = combime_img[:3, :, :] 97 | u_img = combime_img[3:, :, :] 98 | gt_img = torch.cat((torch.zeros_like(o_img[:2, :, :]), combime_img), dim=0) 99 | 100 | if train_type == self.train_type[0]: 101 | 
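# Self-supervised mix-up branch (descriptive note on the block that follows):
# two binary masks are drawn by sampling coarse Bernoulli grids (about 13x13 and
# 26x26-27x27, with a keep probability re-drawn per sample) and upsampling them to
# the crop size with nearest-neighbour interpolation. Pixels rejected by both masks
# are split between them via a third random grid, so mask1 and mask2 jointly cover
# the whole image. Each view is then corrupted with clamped Gaussian noise outside
# its own mask, and the 8-channel target stacks [mask1, mask2, clean view 1,
# clean view 2] for the decomposition losses in loss/mix_fp_loss.py.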
self.grid_num1, self.grid_num2, self.grid_num3 = torch.randint(13, 14, [1]).item(), \ 102 | torch.randint(26, 28, [1]).item(), \ 103 | torch.randint(26, 28, [1]).item() ## 222 226 # 27 28 104 | self.reassign_mask(0.1 + torch.rand(1).item()/1.12) 105 | grid1 = F.interpolate(self.rand_dist1.sample().reshape(1, 1, self.grid_num1, self.grid_num1), 106 | size=self.image_size, mode='nearest').squeeze() 107 | grid1 *= F.interpolate(self.rand_dist2.sample().reshape(1, 1, self.grid_num2, self.grid_num2), 108 | size=self.image_size, mode='nearest').squeeze() 109 | # grid1 *= F.interpolate(self.rand_dist3.sample().reshape(1, 1, self.grid_num3, self.grid_num3), 110 | # size=self.image_size, mode='nearest').squeeze() 111 | grid2 = F.interpolate(self.rand_dist1.sample().reshape(1, 1, self.grid_num1, self.grid_num1), 112 | size=self.image_size, mode='nearest').squeeze() 113 | grid2 *= F.interpolate(self.rand_dist2.sample().reshape(1, 1, self.grid_num2, self.grid_num2), 114 | size=self.image_size, mode='nearest').squeeze() 115 | # grid2 *= F.interpolate(self.rand_dist3.sample().reshape(1, 1, self.grid_num3, self.grid_num3), 116 | # size=self.image_size, mode='nearest').squeeze() 117 | 118 | sample_rand = torch.rand(1).item() 119 | if sample_rand < 0.33: 120 | grid3 = F.interpolate(self.rand_dist1.sample().reshape(1, 1, self.grid_num1, self.grid_num1), 121 | size=self.image_size, mode='nearest').squeeze() 122 | elif sample_rand < 0.66: 123 | grid3 = F.interpolate(self.rand_dist2.sample().reshape(1, 1, self.grid_num2, self.grid_num2), 124 | size=self.image_size, mode='nearest').squeeze() 125 | else: 126 | grid3 = F.interpolate(self.rand_dist3.sample().reshape(1, 1, self.grid_num3, self.grid_num3), 127 | size=self.image_size, mode='nearest').squeeze() 128 | 129 | none_grid = ((grid1 == 0.0) & (grid2 == 0.0)).float() 130 | none_grid1 = grid3 * none_grid 131 | none_grid2 = (1 - grid3) * none_grid 132 | grid1 += none_grid1 133 | grid2 += none_grid2 134 | 135 | if torch.rand(1).item() < 0.5: 136 | mask1, mask2 = grid1, grid2 137 | else: 138 | mask1, mask2 = grid2, grid1 139 | 140 | o_rand = torch.randn_like(o_img).abs().clamp(0,1) 141 | u_rand = torch.randn_like(u_img).abs().clamp(0,1) 142 | o_img = o_img * mask1 + o_rand * (1.0 - mask1) #* torch.rand(1).item() 143 | u_img = u_img * mask2 + u_rand * (1.0 - mask2) #* torch.rand(1).item() 144 | gt_img = torch.cat((mask1.unsqueeze(0), mask2.unsqueeze(0), combime_img), dim=0) 145 | # print("gt_img shape", gt_img.shape) 146 | 147 | 148 | return o_img, u_img, gt_img, train_type 149 | -------------------------------------------------------------------------------- /data/visir_fusion_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import torch 5 | import torchvision.transforms.functional as TF 6 | from torchvision.transforms import Compose, RandomResizedCrop, ToTensor 7 | import torch.nn.functional as F 8 | import cv2 9 | from tqdm import tqdm 10 | from PIL import Image 11 | 12 | 13 | class TestDataset(Dataset): 14 | def __init__(self, valopt): 15 | super(TestDataset, self).__init__() 16 | path = valopt['dataroot'] 17 | self.img_transform = ToTensor() 18 | inf_imgs = os.path.join(os.path.expanduser(path), valopt['infrare_name']) 19 | vis_imgs = os.path.join(os.path.expanduser(path), valopt['visible_name']) 20 | 21 | def not_dir(x): 22 | return '_MDF' not in x and '.DS_Store' not in x and '.txt' not in x 23 | 24 | inf_imgs = [os.path.join(inf_imgs, 
os_dir) for os_dir in filter(not_dir, os.listdir(inf_imgs))] 25 | vis_imgs = [os.path.join(vis_imgs, os_dir) for os_dir in filter(not_dir, os.listdir(vis_imgs))] 26 | 27 | inf_imgs.sort() 28 | vis_imgs.sort() 29 | 30 | self.iget_imgs = {} 31 | for i_img, v_img in zip(inf_imgs, vis_imgs): 32 | self.iget_imgs[v_img] = [i_img, v_img] 33 | 34 | self.iget_imgs = [(key, values) for key, values in self.iget_imgs.items()] 35 | self.iget_imgs = sorted(self.iget_imgs, key=lambda x: x[0]) 36 | 37 | def __len__(self): 38 | return len(self.iget_imgs) 39 | 40 | def __getitem__(self, index): 41 | c_img, (up_img, low_img) = self.iget_imgs[index] 42 | 43 | up_img = Image.open(up_img).convert('L').convert('RGB') 44 | low_img = Image.open(low_img).convert('L').convert('RGB') 45 | 46 | up_img = self.img_transform(up_img) 47 | low_img = self.img_transform(low_img) 48 | 49 | c_img = os.path.split(c_img)[-1].split('.')[0] 50 | return up_img, low_img, c_img 51 | 52 | 53 | class TestTNODataset(Dataset): 54 | def __init__(self, valopt): 55 | super(TestTNODataset, self).__init__() 56 | path = valopt['dataroot'] 57 | self.img_transform = ToTensor() 58 | part_imgs = os.path.expanduser(path) 59 | 60 | def not_dir(x): 61 | return '_MDF' not in x and '.DS_Store' not in x and '.txt' not in x 62 | 63 | part_imgs = [os.path.join(part_imgs, os_dir) for os_dir in filter(not_dir, os.listdir(part_imgs))] 64 | 65 | ir_imgs = [] 66 | vi_imgs = [] 67 | gt_imgs = [] 68 | for img_seq in tqdm(part_imgs): 69 | up, low = sorted(os.listdir(img_seq)) 70 | if 'ir' in up.lower(): 71 | ir_imgs.append(os.path.join(img_seq, up)) 72 | vi_imgs.append(os.path.join(img_seq, low)) 73 | else: 74 | vi_imgs.append(os.path.join(img_seq, up)) 75 | ir_imgs.append(os.path.join(img_seq, low)) 76 | 77 | ir_imgs.sort() 78 | vi_imgs.sort() 79 | 80 | self.iget_imgs = {} 81 | for o_img, u_img in zip(ir_imgs, vi_imgs): 82 | self.iget_imgs[o_img] = [o_img, u_img] 83 | 84 | self.iget_imgs = [(key, values) for key, values in self.iget_imgs.items()] 85 | self.iget_imgs = sorted(self.iget_imgs, key=lambda x: x[0]) 86 | 87 | def __len__(self): 88 | return len(self.iget_imgs) 89 | 90 | def __getitem__(self, index): 91 | c_img, (up_img, low_img) = self.iget_imgs[index] 92 | 93 | up_img = Image.open(up_img).convert('RGB') 94 | low_img = Image.open(low_img).convert("RGB") 95 | 96 | up_img = self.img_transform(up_img) 97 | low_img = self.img_transform(low_img) 98 | 99 | c_img = os.path.split(os.path.split(c_img)[-2])[-1] 100 | return up_img, low_img, c_img 101 | 102 | -------------------------------------------------------------------------------- /loss/mix_fp_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class SelfTrainLoss(nn.Module): 5 | def __init__(self): 6 | super(SelfTrainLoss, self).__init__() 7 | self.l1_loss = nn.L1Loss() 8 | self.mse_loss = nn.MSELoss() 9 | self.is_train = False 10 | self.iteres = { 11 | 'self_supervised_common_mix': 0, 12 | 'self_supervised_upper_mix': 0, 13 | 'self_supervised_lower_mix': 0, 14 | 'self_supervised_fusion_mix': 0, 15 | 'total_loss': 0 16 | } 17 | 18 | def inital_losses(self, b_input_type, losses, compute_num): 19 | if 'self_supervised_mixup' in b_input_type: 20 | tmp = b_input_type.count('self_supervised_mixup') 21 | losses['self_supervised_common_mix'] = 0 22 | compute_num['self_supervised_common_mix'] = tmp 23 | losses['self_supervised_upper_mix'] = 0 24 | compute_num['self_supervised_upper_mix'] = tmp 25 | 
losses['self_supervised_lower_mix'] = 0 26 | compute_num['self_supervised_lower_mix'] = tmp 27 | losses['self_supervised_fusion_mix'] = 0 28 | compute_num['self_supervised_fusion_mix'] = tmp 29 | 30 | 31 | def forward(self, img1, img2, gt_img, common_part, upper_part, lower_part, fusion_part, b_input_type): 32 | losses = {} 33 | compute_num = {} 34 | losses['total_loss'] = 0 35 | 36 | self.inital_losses(b_input_type, losses, compute_num) 37 | for index, input_type in enumerate(b_input_type): 38 | common_part_i = common_part[index].unsqueeze(0) 39 | upper_part_i = upper_part[index].unsqueeze(0) 40 | lower_part_i = lower_part[index].unsqueeze(0) 41 | fusion_part_i = fusion_part[index].unsqueeze(0) 42 | img1_i = img1[index].unsqueeze(0) 43 | img2_i = img2[index].unsqueeze(0) 44 | gt_i = gt_img[index].unsqueeze(0) 45 | if input_type == 'self_supervised_mixup': 46 | mask1 = gt_i[:, 0:1, :, :] 47 | mask2 = gt_i[:, 1:2, :, :] 48 | gt_img1_i = gt_i[:, 2:5, :, :] 49 | gt_img2_i = gt_i[:, 5:8, :, :] 50 | common_mask = ((mask1 == 1.) & (mask2 == 1.)).float() 51 | gt_common_part = common_mask * gt_img1_i 52 | gt_upper_part = (mask1 - common_mask).abs() * gt_img1_i 53 | gt_lower_part = (mask2 - common_mask).abs() * gt_img2_i 54 | 55 | if self.iteres['total_loss'] < 3000: 56 | common_part_pre = common_part_i * common_mask 57 | upper_part_pre = upper_part_i * (mask1 - common_mask).abs() 58 | lower_part_pre = lower_part_i * (mask2 - common_mask).abs() 59 | common_part_post = 0 60 | upper_part_post = 0 61 | lower_part_post = 0 62 | else: 63 | annel_alpha = min(self.iteres['total_loss'], 7000) / 7000 64 | annel_alpha = annel_alpha ** 2 65 | annel_alpha = annel_alpha * 0.15 66 | lower_annel_beta = 1 67 | if self.iteres['total_loss'] > 40000: 68 | annel_alpha *= 0.1 69 | common_part_pre = common_part_i * annel_alpha + common_part_i * common_mask * (1 - annel_alpha) 70 | upper_part_pre = upper_part_i * annel_alpha + upper_part_i * (mask1 - common_mask).abs() * (1 - annel_alpha) 71 | lower_part_pre = lower_part_i * annel_alpha * lower_annel_beta + lower_part_i * (mask2 - common_mask).abs() * (1 - annel_alpha * lower_annel_beta) 72 | 73 | self_supervised_common_mix_loss = self.l1_loss(common_part_pre, gt_common_part) #\ 74 | losses['self_supervised_common_mix'] += self_supervised_common_mix_loss #+ self_supervised_common_mix_loss_a_channel 75 | self_supervised_upper_mix_loss = self.l1_loss(upper_part_pre, gt_upper_part) #+ 5 * \ 76 | losses['self_supervised_upper_mix'] += self_supervised_upper_mix_loss #+ self_supervised_upper_mix_loss_a_channel 77 | self_supervised_lower_mix_loss = self.l1_loss(lower_part_pre, gt_lower_part) #+ 5 * \ 78 | losses['self_supervised_lower_mix'] += self_supervised_lower_mix_loss #+ self_supervised_lower_mix_loss_a_channel 79 | 80 | if self.iteres['total_loss'] >= 17000: 81 | annel_beta = min(self.iteres['total_loss'] - 10000, 14000) / 14000 82 | annel_beta = annel_beta ** 2 83 | self_supervised_fusion_mix_loss = 1 * self.l1_loss(gt_img1_i, fusion_part_i) * annel_beta 84 | #+ 0 * self.ssim_loss(gt_img1_i, fusion_part_i)) 85 | else: 86 | self_supervised_fusion_mix_loss = torch.tensor(0.0).cuda() 87 | losses['self_supervised_fusion_mix'] += self_supervised_fusion_mix_loss 88 | losses['total_loss'] += self_supervised_common_mix_loss + self_supervised_upper_mix_loss \ 89 | + self_supervised_lower_mix_loss + self_supervised_fusion_mix_loss #\ 90 | 91 | 92 | for k, v in losses.items(): 93 | if k in self.iteres.keys(): 94 | self.iteres[k] += 1 95 | if k != 'total_loss': 96 | losses[k] = v / 
compute_num[k] 97 | 98 | return losses, self.iteres.copy() 99 | -------------------------------------------------------------------------------- /loss/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/Readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/UCSharedModelPro.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.init as init 4 | from .resnet import ResNestLayer, Bottleneck 5 | import math 6 | 7 | 8 | class UCSharedNetPro(nn.Module): 9 | def __init__(self): 10 | super(UCSharedNetPro, self).__init__() 11 | encoder_upper = [nn.Conv2d(3, 16, 3, 1, 1, bias=True), 12 | nn.ReLU(inplace=True), 13 | ResNestLayer(Bottleneck, 8, 6, stem_width=8, norm_layer=None), 14 | ] 15 | self.encoder_upper = nn.Sequential(*encoder_upper) 16 | # self.encoder_upper_in = nn.InstanceNorm2d(64,affine=True) 17 | self.maxpool_upper = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 18 | self.upper_encoder_layer1 = ResNestLayer(Bottleneck, 16, 6, stem_width=16, norm_layer=None, is_first=False) 19 | self.upper_encoder_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=32, stride=2, norm_layer=None) 20 | self.upper_encoder_layer3 = ResNestLayer(Bottleneck, 64, 4, stem_width=64, stride=2, norm_layer=None) 21 | 22 | self.encoder_lower = self.encoder_upper 23 | self.maxpool_lower = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 24 | self.lower_encoder_layer1 = self.upper_encoder_layer1 25 | self.lower_encoder_layer2 = self.upper_encoder_layer2 26 | self.lower_encoder_layer3 = self.upper_encoder_layer3 27 | 28 | encoder_body_fusion = [ 29 | ResNestLayer(Bottleneck, 256, 4, stem_width=256, norm_layer=None, is_first=False) 30 | ] 31 | self.common_encoder = nn.Sequential(*encoder_body_fusion) 32 | 33 | self.decoder_common_layer1 = ResNestLayer(Bottleneck, 64, 2, stem_width=512, avg_down=False, avd=False, stride=1, norm_layer=None) 34 | self.decoder_common_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 35 | self.decoder_common_layer2 = ResNestLayer(Bottleneck, 16, 2, stem_width=128, avg_down=False, avd=False, stride=1, norm_layer=None) 36 | self.decoder_common_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 37 | self.decoder_common_layer3 = ResNestLayer(Bottleneck, 4, 2, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None) 38 | self.decoder_common_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 39 | decoder_common_layer4 = [ 40 | ResNestLayer(Bottleneck, 4, 2, stem_width=8, avg_down=False, avd=False, stride=1, norm_layer=None), 41 | 42 | ] 43 | self.decoder_common_layer4 = nn.Sequential(*decoder_common_layer4) 44 | decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 45 | nn.ReLU(inplace=True)] 46 | self.decoder_common_projection_layer = nn.Sequential(*decoder_projection_layer) 47 | 48 | self.decoder_upper_layer1 = ResNestLayer(Bottleneck, 96, 4, stem_width=640, avg_down=False, avd=False, stride=1, norm_layer=None) 49 | self.decoder_upper_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 50 | # 192,64 -> 128, 32 51 | self.decoder_upper_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=256, avg_down=False, avd=False, stride=1, 
norm_layer=None) 52 | self.decoder_upper_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 53 | self.decoder_upper_layer3 = ResNestLayer(Bottleneck, 16, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 54 | self.decoder_upper_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 55 | decoder_upper_layer4 = [ 56 | ResNestLayer(Bottleneck, 4, 6, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None), 57 | ] 58 | self.decoder_upper_layer4 = nn.Sequential(*decoder_upper_layer4) 59 | upper_decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 60 | nn.ReLU(inplace=True)] 61 | self.decoder_upper_projection_layer = nn.Sequential(*upper_decoder_projection_layer ) 62 | 63 | self.decoder_lower_layer1 = self.decoder_upper_layer1 64 | self.decoder_lower_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 65 | self.decoder_lower_layer2 = self.decoder_upper_layer2 66 | self.decoder_lower_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 67 | self.decoder_lower_layer3 = self.decoder_upper_layer3 68 | self.decoder_lower_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 69 | self.decoder_lower_layer4 = self.decoder_upper_layer4 70 | self.decoder_lower_projection_layer = self.decoder_upper_projection_layer 71 | 72 | self.fusion_rule = nn.Sequential(*[ 73 | nn.Conv2d(16, 3, 3, 1, 1, bias=True), 74 | nn.ReLU(inplace=True) 75 | ]) 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 80 | m.weight.data.normal_(0, math.sqrt(2. / n)) 81 | elif isinstance(m, nn.InstanceNorm2d): 82 | m.weight.data.fill_(1) 83 | m.bias.data.zero_() 84 | 85 | 86 | def forward(self, img1, img2): 87 | 88 | feature_upper = self.encoder_upper(img1) 89 | feature_upper0 = self.maxpool_upper(feature_upper) 90 | feature_upper1 = self.upper_encoder_layer1(feature_upper0) 91 | feature_upper2 = self.upper_encoder_layer2(feature_upper1) 92 | feature_upper3 = self.upper_encoder_layer3(feature_upper2) 93 | 94 | feature_lower = self.encoder_lower(img2) 95 | feature_lower0 = self.maxpool_lower(feature_lower) 96 | feature_lower1 = self.lower_encoder_layer1(feature_lower0) 97 | feature_lower2 = self.lower_encoder_layer2(feature_lower1) 98 | feature_lower3 = self.lower_encoder_layer3(feature_lower2) 99 | 100 | feature_concat = torch.cat((feature_upper3, feature_lower3), dim=1) 101 | feature_common = self.common_encoder(feature_concat) 102 | 103 | common_part = self.decoder_common_layer1(feature_common) 104 | common_part = self.decoder_common_up1(common_part) 105 | common_part = self.decoder_common_layer2(common_part) 106 | common_part = self.decoder_common_up2(common_part) 107 | common_part = self.decoder_common_layer3(common_part) 108 | common_part = self.decoder_common_up3(common_part) 109 | common_part = self.decoder_common_layer4(common_part) 110 | common_part_embedding = common_part 111 | common_part = self.decoder_upper_projection_layer(common_part) 112 | 113 | feature_de_upper = torch.cat((feature_common, feature_upper3), dim=1) 114 | upper_part = self.decoder_upper_layer1(feature_de_upper) 115 | upper_part = self.decoder_upper_up1(upper_part) 116 | upper_part = torch.cat((upper_part, feature_upper2), dim=1) 117 | upper_part = self.decoder_upper_layer2(upper_part) 118 | upper_part = self.decoder_upper_up2(upper_part) 119 | upper_part = torch.cat((upper_part, feature_upper1), dim=1) 120 | upper_part = 
self.decoder_upper_layer3(upper_part) 121 | upper_part = self.decoder_upper_up3(upper_part) 122 | upper_part = self.decoder_upper_layer4(upper_part) 123 | upper_part_embeding = upper_part 124 | upper_part = self.decoder_upper_projection_layer(upper_part) 125 | 126 | feature_de_lower = torch.cat((feature_common, feature_lower3), dim=1) 127 | lower_part = self.decoder_lower_layer1(feature_de_lower) 128 | lower_part = self.decoder_lower_up1(lower_part) 129 | lower_part = torch.cat((lower_part, feature_lower2), dim=1) 130 | lower_part = self.decoder_lower_layer2(lower_part) 131 | lower_part = self.decoder_lower_up2(lower_part) 132 | lower_part = torch.cat((lower_part, feature_lower1), dim=1) 133 | lower_part = self.decoder_lower_layer3(lower_part) 134 | lower_part = self.decoder_lower_up3(lower_part) 135 | lower_part = self.decoder_lower_layer4(lower_part) 136 | lower_part_embeddding = lower_part 137 | lower_part = self.decoder_lower_projection_layer(lower_part) 138 | 139 | fusion_part = self.fusion_rule(upper_part_embeding+ lower_part_embeddding+ common_part_embedding) 140 | 141 | return common_part, upper_part, lower_part, fusion_part 142 | -------------------------------------------------------------------------------- /models/UCSharedModelProCommon.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.init as init 4 | from .resnet import ResNestLayer, Bottleneck 5 | import math 6 | 7 | 8 | class UCSharedNetPro(nn.Module): 9 | def __init__(self): 10 | super(UCSharedNetPro, self).__init__() 11 | encoder_upper = [nn.Conv2d(3, 16, 3, 1, 1, bias=True), 12 | nn.ReLU(inplace=True), 13 | ResNestLayer(Bottleneck, 8, 6, stem_width=8, norm_layer=None), 14 | ] 15 | self.encoder_upper = nn.Sequential(*encoder_upper) 16 | self.maxpool_upper = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 17 | self.upper_encoder_layer1 = ResNestLayer(Bottleneck, 16, 6, stem_width=16, norm_layer=None, is_first=False) 18 | self.upper_encoder_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=32, stride=2, norm_layer=None) 19 | self.upper_encoder_layer3 = ResNestLayer(Bottleneck, 64, 4, stem_width=64, stride=2, norm_layer=None) 20 | 21 | self.encoder_lower = self.encoder_upper 22 | self.maxpool_lower = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 23 | self.lower_encoder_layer1 = self.upper_encoder_layer1 24 | self.lower_encoder_layer2 = self.upper_encoder_layer2 25 | self.lower_encoder_layer3 = self.upper_encoder_layer3 26 | 27 | encoder_body_fusion = [ 28 | ResNestLayer(Bottleneck, 256, 4, stem_width=256, norm_layer=None, is_first=False) 29 | ] 30 | self.common_encoder = nn.Sequential(*encoder_body_fusion) 31 | 32 | self.decoder_common_layer1 = ResNestLayer(Bottleneck, 64, 4, stem_width=768, avg_down=False, avd=False, stride=1, norm_layer=None) 33 | self.decoder_common_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 34 | self.decoder_common_layer2 = ResNestLayer(Bottleneck, 16, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 35 | self.decoder_common_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 36 | self.decoder_common_layer3 = ResNestLayer(Bottleneck, 4, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 37 | self.decoder_common_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 38 | decoder_common_layer4 = [ 39 | ResNestLayer(Bottleneck, 4, 6, stem_width=8, avg_down=False, avd=False, stride=1, 
norm_layer=None), 40 | ] 41 | self.decoder_common_layer4 = nn.Sequential(*decoder_common_layer4) 42 | decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 43 | nn.ReLU(inplace=True)] 44 | self.decoder_common_projection_layer = nn.Sequential(*decoder_projection_layer) 45 | 46 | self.decoder_upper_layer1 = ResNestLayer(Bottleneck, 96, 4, stem_width=640, avg_down=False, avd=False, stride=1, norm_layer=None) 47 | self.decoder_upper_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 48 | # 192,64 -> 128, 32 49 | self.decoder_upper_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 50 | self.decoder_upper_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 51 | self.decoder_upper_layer3 = ResNestLayer(Bottleneck, 16, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 52 | self.decoder_upper_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 53 | decoder_upper_layer4 = [ 54 | ResNestLayer(Bottleneck, 4, 6, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None), 55 | ] 56 | self.decoder_upper_layer4 = nn.Sequential(*decoder_upper_layer4) 57 | upper_decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 58 | nn.ReLU(inplace=True)] 59 | self.decoder_upper_projection_layer = nn.Sequential(*upper_decoder_projection_layer ) 60 | 61 | self.decoder_lower_layer1 = self.decoder_upper_layer1 62 | self.decoder_lower_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 63 | self.decoder_lower_layer2 = self.decoder_upper_layer2 64 | self.decoder_lower_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 65 | self.decoder_lower_layer3 = self.decoder_upper_layer3 66 | self.decoder_lower_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 67 | self.decoder_lower_layer4 = self.decoder_upper_layer4 68 | self.decoder_lower_projection_layer = self.decoder_upper_projection_layer 69 | 70 | self.fusion_rule = nn.Sequential(*[ 71 | nn.Conv2d(16, 3, 3, 1, 1, bias=True), 72 | nn.ReLU(inplace=True) 73 | ]) 74 | 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 78 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 79 | elif isinstance(m, nn.InstanceNorm2d): 80 | m.weight.data.fill_(1) 81 | m.bias.data.zero_() 82 | 83 | def forward(self, img1, img2): 84 | 85 | feature_upper = self.encoder_upper(img1) 86 | feature_upper0 = self.maxpool_upper(feature_upper) 87 | feature_upper1 = self.upper_encoder_layer1(feature_upper0) 88 | feature_upper2 = self.upper_encoder_layer2(feature_upper1) 89 | feature_upper3 = self.upper_encoder_layer3(feature_upper2) 90 | 91 | feature_lower = self.encoder_lower(img2) 92 | feature_lower0 = self.maxpool_lower(feature_lower) 93 | feature_lower1 = self.lower_encoder_layer1(feature_lower0) 94 | feature_lower2 = self.lower_encoder_layer2(feature_lower1) 95 | feature_lower3 = self.lower_encoder_layer3(feature_lower2) 96 | 97 | feature_concat = torch.cat((feature_upper3, feature_lower3), dim=1) 98 | feature_common = self.common_encoder(feature_concat) 99 | 100 | common_part = torch.cat((feature_common, feature_upper3, feature_lower3), dim=1) 101 | common_part = self.decoder_common_layer1(feature_common) 102 | common_part = self.decoder_common_up1(common_part) 103 | common_part = torch.cat((common_part, feature_upper2, feature_lower2), dim=1) 104 | common_part = self.decoder_common_layer2(common_part) 105 | common_part = self.decoder_common_up2(common_part) 106 | common_part = torch.cat((common_part, feature_upper1, feature_lower1), dim=1) 107 | common_part = self.decoder_common_layer3(common_part) 108 | common_part = self.decoder_common_up3(common_part) 109 | common_part = self.decoder_common_layer4(common_part) 110 | common_part_embedding = common_part 111 | common_part = self.decoder_common_projection_layer(common_part) 112 | 113 | feature_de_upper = torch.cat((feature_common, feature_upper3), dim=1) 114 | upper_part = self.decoder_upper_layer1(feature_de_upper) 115 | upper_part = self.decoder_upper_up1(upper_part) 116 | upper_part = torch.cat((upper_part, feature_upper2), dim=1) 117 | upper_part = self.decoder_upper_layer2(upper_part) 118 | upper_part = self.decoder_upper_up2(upper_part) 119 | upper_part = torch.cat((upper_part, feature_upper1), dim=1) 120 | upper_part = self.decoder_upper_layer3(upper_part) 121 | upper_part = self.decoder_upper_up3(upper_part) 122 | upper_part = self.decoder_upper_layer4(upper_part) 123 | upper_part_embeding = upper_part 124 | upper_part = self.decoder_upper_projection_layer(upper_part) 125 | 126 | feature_de_lower = torch.cat((feature_common, feature_lower3), dim=1) 127 | lower_part = self.decoder_lower_layer1(feature_de_lower) 128 | lower_part = self.decoder_lower_up1(lower_part) 129 | lower_part = torch.cat((lower_part, feature_lower2), dim=1) 130 | lower_part = self.decoder_lower_layer2(lower_part) 131 | lower_part = self.decoder_lower_up2(lower_part) 132 | lower_part = torch.cat((lower_part, feature_lower1), dim=1) 133 | lower_part = self.decoder_lower_layer3(lower_part) 134 | lower_part = self.decoder_lower_up3(lower_part) 135 | lower_part = self.decoder_lower_layer4(lower_part) 136 | lower_part_embeddding = lower_part 137 | lower_part = self.decoder_lower_projection_layer(lower_part) 138 | 139 | fusion_part = self.fusion_rule(upper_part_embeding+ lower_part_embeddding+ common_part_embedding) 140 | 141 | return common_part, upper_part, lower_part, fusion_part 142 | -------------------------------------------------------------------------------- /models/UCTestShareModelProCommon.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.init as 
init 4 | from .resnet import ResNestLayer, Bottleneck 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class UCTestSharedNetPro(nn.Module): 10 | def __init__(self): 11 | super(UCTestSharedNetPro, self).__init__() 12 | encoder_upper = [nn.Conv2d(3, 16, 3, 1, 1, bias=True), 13 | nn.ReLU(inplace=True), 14 | ResNestLayer(Bottleneck, 8, 6, stem_width=8, norm_layer=None), 15 | ] 16 | self.encoder_upper = nn.Sequential(*encoder_upper) 17 | # self.encoder_upper_in = nn.InstanceNorm2d(64,affine=True) 18 | self.maxpool_upper = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 19 | self.upper_encoder_layer1 = ResNestLayer(Bottleneck, 16, 6, stem_width=16, norm_layer=None, is_first=False) 20 | self.upper_encoder_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=32, stride=2, norm_layer=None) 21 | self.upper_encoder_layer3 = ResNestLayer(Bottleneck, 64, 4, stem_width=64, stride=2, norm_layer=None) 22 | 23 | self.encoder_lower = self.encoder_upper 24 | self.maxpool_lower = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 25 | self.lower_encoder_layer1 = self.upper_encoder_layer1 26 | self.lower_encoder_layer2 = self.upper_encoder_layer2 27 | self.lower_encoder_layer3 = self.upper_encoder_layer3 28 | 29 | encoder_body_fusion = [ 30 | ResNestLayer(Bottleneck, 256, 4, stem_width=256, norm_layer=None, is_first=False) 31 | ] 32 | self.common_encoder = nn.Sequential(*encoder_body_fusion) 33 | 34 | self.decoder_common_layer1 = ResNestLayer(Bottleneck, 64, 2, stem_width=768, avg_down=False, avd=False, stride=1, norm_layer=None) 35 | self.decoder_common_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 36 | self.decoder_common_layer2 = ResNestLayer(Bottleneck, 16, 2, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 37 | self.decoder_common_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 38 | self.decoder_common_layer3 = ResNestLayer(Bottleneck, 4, 2, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 39 | self.decoder_common_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 40 | decoder_common_layer4 = [ 41 | ResNestLayer(Bottleneck, 4, 2, stem_width=8, avg_down=False, avd=False, stride=1, norm_layer=None), 42 | ] 43 | self.decoder_common_layer4 = nn.Sequential(*decoder_common_layer4) 44 | decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 45 | nn.ReLU(inplace=True)] 46 | self.decoder_common_projection_layer = nn.Sequential(*decoder_projection_layer) 47 | 48 | self.decoder_upper_layer1 = ResNestLayer(Bottleneck, 96, 4, stem_width=640, avg_down=False, avd=False, stride=1, norm_layer=None) 49 | self.decoder_upper_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 50 | self.decoder_upper_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 51 | self.decoder_upper_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 52 | self.decoder_upper_layer3 = ResNestLayer(Bottleneck, 16, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 53 | self.decoder_upper_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 54 | decoder_upper_layer4 = [ 55 | ResNestLayer(Bottleneck, 4, 6, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None), 56 | ] 57 | self.decoder_upper_layer4 = nn.Sequential(*decoder_upper_layer4) 58 | upper_decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 59 | nn.ReLU(inplace=True)] 60 | 
self.decoder_upper_projection_layer = nn.Sequential(*upper_decoder_projection_layer) 61 | 62 | self.decoder_lower_layer1 = self.decoder_upper_layer1 63 | self.decoder_lower_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 64 | self.decoder_lower_layer2 = self.decoder_upper_layer2 65 | self.decoder_lower_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 66 | self.decoder_lower_layer3 = self.decoder_upper_layer3 67 | self.decoder_lower_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 68 | self.decoder_lower_layer4 = self.decoder_upper_layer4 69 | self.decoder_lower_projection_layer = self.decoder_upper_projection_layer 70 | 71 | self.fusion_rule = nn.Sequential(*[ 72 | nn.Conv2d(16, 3, 3, 1, 1, bias=True), 73 | nn.ReLU(inplace=True) 74 | ]) 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. / n)) 80 | elif isinstance(m, nn.InstanceNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | 84 | def forward(self, img1, img2): 85 | 86 | feature_upper = self.encoder_upper(img1) 87 | feature_upper0 = self.maxpool_upper(feature_upper) 88 | feature_upper1 = self.upper_encoder_layer1(feature_upper0) 89 | feature_upper2 = self.upper_encoder_layer2(feature_upper1) 90 | feature_upper3 = self.upper_encoder_layer3(feature_upper2) 91 | 92 | feature_lower = self.encoder_lower(img2) 93 | feature_lower0 = self.maxpool_lower(feature_lower) 94 | feature_lower1 = self.lower_encoder_layer1(feature_lower0) 95 | feature_lower2 = self.lower_encoder_layer2(feature_lower1) 96 | feature_lower3 = self.lower_encoder_layer3(feature_lower2) 97 | 98 | feature_concat = torch.cat((feature_upper3, feature_lower3), dim=1) 99 | feature_common = self.common_encoder(feature_concat) 100 | 101 | common_part = torch.cat((feature_common, feature_upper3, feature_lower3), dim=1) 102 | common_part = self.decoder_common_layer1(common_part) 103 | common_part = self.decoder_common_up1(common_part) 104 | common_part = F.interpolate(common_part, size=feature_upper2.shape[2:]) 105 | common_part = torch.cat((common_part, feature_upper2, feature_lower2), dim=1) 106 | common_part = self.decoder_common_layer2(common_part) 107 | common_part = self.decoder_common_up2(common_part) 108 | common_part = F.interpolate(common_part, size=feature_upper1.shape[2:]) 109 | common_part = torch.cat((common_part, feature_upper1, feature_lower1), dim=1) 110 | common_part = self.decoder_common_layer3(common_part) 111 | common_part = self.decoder_common_up3(common_part) 112 | common_part = F.interpolate(common_part, size=feature_upper.shape[2:]) 113 | common_part = self.decoder_common_layer4(common_part) 114 | common_part_embedding = common_part 115 | 116 | feature_de_upper = torch.cat((feature_common, feature_upper3), dim=1) 117 | upper_part = self.decoder_upper_layer1(feature_de_upper) 118 | upper_part = self.decoder_upper_up1(upper_part) 119 | upper_part = F.interpolate(upper_part, size=feature_upper2.shape[2:]) 120 | upper_part = torch.cat((upper_part, feature_upper2), dim=1) 121 | upper_part = self.decoder_upper_layer2(upper_part) 122 | upper_part = self.decoder_upper_up2(upper_part) 123 | upper_part = F.interpolate(upper_part, size=feature_upper1.shape[2:]) 124 | upper_part = torch.cat((upper_part, feature_upper1), dim=1) 125 | upper_part = self.decoder_upper_layer3(upper_part) 126 | upper_part = self.decoder_upper_up3(upper_part) 127 | upper_part = 
F.interpolate(upper_part, size=feature_upper.shape[2:]) 128 | upper_part = self.decoder_upper_layer4(upper_part) 129 | upper_part_embeding = upper_part 130 | 131 | feature_de_lower = torch.cat((feature_common, feature_lower3), dim=1) 132 | lower_part = self.decoder_lower_layer1(feature_de_lower) 133 | lower_part = self.decoder_lower_up1(lower_part) 134 | lower_part = F.interpolate(lower_part, size=feature_upper2.shape[2:]) 135 | lower_part = torch.cat((lower_part, feature_lower2), dim=1) 136 | lower_part = self.decoder_lower_layer2(lower_part) 137 | lower_part = self.decoder_lower_up2(lower_part) 138 | lower_part = F.interpolate(lower_part, size=feature_upper1.shape[2:]) 139 | lower_part = torch.cat((lower_part, feature_lower1), dim=1) 140 | lower_part = self.decoder_lower_layer3(lower_part) 141 | lower_part = self.decoder_lower_up3(lower_part) 142 | lower_part = F.interpolate(lower_part, size=feature_upper.shape[2:]) 143 | lower_part = self.decoder_lower_layer4(lower_part) 144 | lower_part_embeddding = lower_part 145 | 146 | fusion_part = self.fusion_rule(upper_part_embeding+lower_part_embeddding+ common_part_embedding) 147 | 148 | 149 | common_part = self.decoder_upper_projection_layer(common_part_embedding) 150 | upper_part = self.decoder_upper_projection_layer(upper_part_embeding) 151 | lower_part = self.decoder_lower_projection_layer(lower_part_embeddding) 152 | 153 | return common_part, upper_part, lower_part, fusion_part 154 | # return common_part_embedding, common_part, upper_part_embeding, upper_part, lower_part_embeddding, lower_part, fusion_part 155 | 156 | 157 | class UCTestSharedDeepNetPro(nn.Module): 158 | def __init__(self): 159 | super(UCTestSharedDeepNetPro, self).__init__() 160 | encoder_upper = [nn.Conv2d(3, 16, 3, 1, 1, bias=True), 161 | nn.ReLU(inplace=True), 162 | ResNestLayer(Bottleneck, 8, 6, stem_width=8, norm_layer=None), 163 | ] 164 | self.encoder_upper = nn.Sequential(*encoder_upper) 165 | self.maxpool_upper = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 166 | self.upper_encoder_layer1 = ResNestLayer(Bottleneck, 16, 6, stem_width=16, norm_layer=None, is_first=False) 167 | self.upper_encoder_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=32, stride=2, norm_layer=None) 168 | self.upper_encoder_layer3 = ResNestLayer(Bottleneck, 64, 4, stem_width=64, stride=2, norm_layer=None) 169 | 170 | self.encoder_lower = self.encoder_upper 171 | self.maxpool_lower = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 172 | self.lower_encoder_layer1 = self.upper_encoder_layer1 173 | self.lower_encoder_layer2 = self.upper_encoder_layer2 174 | self.lower_encoder_layer3 = self.upper_encoder_layer3 175 | 176 | encoder_body_fusion = [ 177 | ResNestLayer(Bottleneck, 256, 4, stem_width=256, norm_layer=None, is_first=False) 178 | ] 179 | self.common_encoder = nn.Sequential(*encoder_body_fusion) 180 | 181 | self.decoder_common_layer1 = ResNestLayer(Bottleneck, 64, 4, stem_width=768, avg_down=False, avd=False, stride=1, norm_layer=None) 182 | self.decoder_common_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 183 | self.decoder_common_layer2 = ResNestLayer(Bottleneck, 16, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 184 | self.decoder_common_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 185 | self.decoder_common_layer3 = ResNestLayer(Bottleneck, 4, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 186 | self.decoder_common_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # 
nn.PixelShuffle(2) 187 | decoder_common_layer4 = [ 188 | ResNestLayer(Bottleneck, 4, 6, stem_width=8, avg_down=False, avd=False, stride=1, norm_layer=None), 189 | ] 190 | self.decoder_common_layer4 = nn.Sequential(*decoder_common_layer4) 191 | decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 192 | nn.ReLU(inplace=True)] 193 | self.decoder_common_projection_layer = nn.Sequential(*decoder_projection_layer) 194 | 195 | self.decoder_upper_layer1 = ResNestLayer(Bottleneck, 96, 4, stem_width=640, avg_down=False, avd=False, stride=1, norm_layer=None) 196 | self.decoder_upper_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 197 | self.decoder_upper_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 198 | self.decoder_upper_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 199 | self.decoder_upper_layer3 = ResNestLayer(Bottleneck, 16, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 200 | self.decoder_upper_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 201 | decoder_upper_layer4 = [ 202 | ResNestLayer(Bottleneck, 4, 6, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None), 203 | ] 204 | self.decoder_upper_layer4 = nn.Sequential(*decoder_upper_layer4) 205 | upper_decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 206 | nn.ReLU(inplace=True)] 207 | self.decoder_upper_projection_layer = nn.Sequential(*upper_decoder_projection_layer) 208 | 209 | self.decoder_lower_layer1 = self.decoder_upper_layer1 210 | self.decoder_lower_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 211 | self.decoder_lower_layer2 = self.decoder_upper_layer2 212 | self.decoder_lower_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 213 | self.decoder_lower_layer3 = self.decoder_upper_layer3 214 | self.decoder_lower_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 215 | self.decoder_lower_layer4 = self.decoder_upper_layer4 216 | self.decoder_lower_projection_layer = self.decoder_upper_projection_layer 217 | 218 | self.fusion_rule = nn.Sequential(*[ 219 | nn.Conv2d(16, 3, 3, 1, 1, bias=True), 220 | nn.ReLU(inplace=True) 221 | ]) 222 | 223 | for m in self.modules(): 224 | if isinstance(m, nn.Conv2d): 225 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 226 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 227 | elif isinstance(m, nn.InstanceNorm2d): 228 | m.weight.data.fill_(1) 229 | m.bias.data.zero_() 230 | 231 | def forward(self, img1, img2): 232 | 233 | feature_upper = self.encoder_upper(img1) 234 | feature_upper0 = self.maxpool_upper(feature_upper) 235 | feature_upper1 = self.upper_encoder_layer1(feature_upper0) 236 | # print("feature upper1", feature_upper1.shape) 237 | feature_upper2 = self.upper_encoder_layer2(feature_upper1) 238 | # print("feature upper2", feature_upper2.shape) 239 | feature_upper3 = self.upper_encoder_layer3(feature_upper2) 240 | # print("feature upper3", feature_upper3.shape) 241 | 242 | feature_lower = self.encoder_lower(img2) 243 | feature_lower0 = self.maxpool_lower(feature_lower) 244 | feature_lower1 = self.lower_encoder_layer1(feature_lower0) 245 | feature_lower2 = self.lower_encoder_layer2(feature_lower1) 246 | feature_lower3 = self.lower_encoder_layer3(feature_lower2) 247 | 248 | feature_concat = torch.cat((feature_upper3, feature_lower3), dim=1) 249 | feature_common = self.common_encoder(feature_concat) 250 | 251 | common_part = torch.cat((feature_common, feature_upper3, feature_lower3), dim=1) 252 | common_part = self.decoder_common_layer1(common_part) 253 | common_part = self.decoder_common_up1(common_part) 254 | common_part = F.interpolate(common_part, size=feature_upper2.shape[2:]) 255 | common_part = torch.cat((common_part, feature_upper2, feature_lower2), dim=1) 256 | common_part = self.decoder_common_layer2(common_part) 257 | common_part = self.decoder_common_up2(common_part) 258 | common_part = F.interpolate(common_part, size=feature_upper1.shape[2:]) 259 | common_part = torch.cat((common_part, feature_upper1, feature_lower1), dim=1) 260 | common_part = self.decoder_common_layer3(common_part) 261 | common_part = self.decoder_common_up3(common_part) 262 | common_part = F.interpolate(common_part, size=feature_upper.shape[2:]) 263 | common_part = self.decoder_common_layer4(common_part) 264 | common_part_embedding = common_part 265 | 266 | 267 | feature_de_upper = torch.cat((feature_common, feature_upper3), dim=1) 268 | upper_part = self.decoder_upper_layer1(feature_de_upper) 269 | upper_part = self.decoder_upper_up1(upper_part) 270 | upper_part = F.interpolate(upper_part, size=feature_upper2.shape[2:]) 271 | upper_part = torch.cat((upper_part, feature_upper2), dim=1) 272 | upper_part = self.decoder_upper_layer2(upper_part) 273 | upper_part = self.decoder_upper_up2(upper_part) 274 | upper_part = F.interpolate(upper_part, size=feature_upper1.shape[2:]) 275 | upper_part = torch.cat((upper_part, feature_upper1), dim=1) 276 | upper_part = self.decoder_upper_layer3(upper_part) 277 | upper_part = self.decoder_upper_up3(upper_part) 278 | upper_part = F.interpolate(upper_part, size=feature_upper.shape[2:]) 279 | upper_part = self.decoder_upper_layer4(upper_part) 280 | upper_part_embeding = upper_part 281 | 282 | 283 | 284 | feature_de_lower = torch.cat((feature_common, feature_lower3), dim=1) 285 | lower_part = self.decoder_lower_layer1(feature_de_lower) 286 | lower_part = self.decoder_lower_up1(lower_part) 287 | lower_part = F.interpolate(lower_part, size=feature_upper2.shape[2:]) 288 | lower_part = torch.cat((lower_part, feature_lower2), dim=1) 289 | lower_part = self.decoder_lower_layer2(lower_part) 290 | lower_part = self.decoder_lower_up2(lower_part) 291 | lower_part = F.interpolate(lower_part, size=feature_upper1.shape[2:]) 292 | lower_part = torch.cat((lower_part, feature_lower1), dim=1) 293 | lower_part = self.decoder_lower_layer3(lower_part) 
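# As in the upper and common branches above, each upsampling step is followed by an
# explicit F.interpolate to the matching encoder feature size, so test images whose
# height/width are not exact multiples of the encoder's overall stride still align
# with the skip connections; the lower branch is brought back to the stem resolution
# below before the final refinement layer and the shared projection head.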
294 | lower_part = self.decoder_lower_up3(lower_part) 295 | lower_part = F.interpolate(lower_part, size=feature_upper.shape[2:]) 296 | lower_part = self.decoder_lower_layer4(lower_part) 297 | lower_part_embeddding = lower_part 298 | 299 | 300 | fusion_part = self.fusion_rule(upper_part_embeding+lower_part_embeddding+ common_part_embedding) 301 | 302 | 303 | common_part = self.decoder_upper_projection_layer(common_part) 304 | upper_part = self.decoder_upper_projection_layer(upper_part) 305 | lower_part = self.decoder_lower_projection_layer(lower_part) 306 | 307 | return common_part, upper_part, lower_part, fusion_part 308 | 309 | 310 | class UCTestSharedAblaNetPro(nn.Module): 311 | def __init__(self): 312 | super(UCTestSharedAblaNetPro, self).__init__() 313 | encoder_upper = [nn.Conv2d(3, 16, 3, 1, 1, bias=True), 314 | nn.ReLU(inplace=True), 315 | ResNestLayer(Bottleneck, 8, 6, stem_width=8, norm_layer=None), 316 | ] 317 | self.encoder_upper = nn.Sequential(*encoder_upper) 318 | self.maxpool_upper = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 319 | self.upper_encoder_layer1 = ResNestLayer(Bottleneck, 16, 6, stem_width=16, norm_layer=None, is_first=False) 320 | self.upper_encoder_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=32, stride=2, norm_layer=None) 321 | self.upper_encoder_layer3 = ResNestLayer(Bottleneck, 64, 4, stem_width=64, stride=2, norm_layer=None) 322 | 323 | self.encoder_lower = self.encoder_upper 324 | self.maxpool_lower = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 325 | self.lower_encoder_layer1 = self.upper_encoder_layer1 326 | self.lower_encoder_layer2 = self.upper_encoder_layer2 327 | self.lower_encoder_layer3 = self.upper_encoder_layer3 328 | 329 | encoder_body_fusion = [ 330 | ResNestLayer(Bottleneck, 256, 4, stem_width=256, norm_layer=None, is_first=False) 331 | ] 332 | self.common_encoder = nn.Sequential(*encoder_body_fusion) 333 | 334 | self.decoder_common_layer1 = ResNestLayer(Bottleneck, 64, 4, stem_width=512, avg_down=False, avd=False, stride=1, norm_layer=None) 335 | self.decoder_common_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 336 | self.decoder_common_layer2 = ResNestLayer(Bottleneck, 16, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 337 | self.decoder_common_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 338 | self.decoder_common_layer3 = ResNestLayer(Bottleneck, 4, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 339 | self.decoder_common_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 340 | decoder_common_layer4 = [ 341 | ResNestLayer(Bottleneck, 4, 6, stem_width=8, avg_down=False, avd=False, stride=1, norm_layer=None), 342 | ] 343 | self.decoder_common_layer4 = nn.Sequential(*decoder_common_layer4) 344 | decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 345 | nn.ReLU(inplace=True)] 346 | self.decoder_common_projection_layer = nn.Sequential(*decoder_projection_layer) 347 | 348 | self.decoder_upper_layer1 = ResNestLayer(Bottleneck, 96, 4, stem_width=640, avg_down=False, avd=False, stride=1, norm_layer=None) 349 | self.decoder_upper_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 350 | self.decoder_upper_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 351 | self.decoder_upper_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 352 | self.decoder_upper_layer3 = ResNestLayer(Bottleneck, 16, 6, 
stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 353 | self.decoder_upper_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 354 | decoder_upper_layer4 = [ 355 | ResNestLayer(Bottleneck, 4, 6, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None), 356 | ] 357 | self.decoder_upper_layer4 = nn.Sequential(*decoder_upper_layer4) 358 | upper_decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 359 | nn.ReLU(inplace=True)] 360 | self.decoder_upper_projection_layer = nn.Sequential(*upper_decoder_projection_layer) 361 | 362 | self.decoder_lower_layer1 = self.decoder_upper_layer1 363 | self.decoder_lower_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 364 | self.decoder_lower_layer2 = self.decoder_upper_layer2 365 | self.decoder_lower_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 366 | self.decoder_lower_layer3 = self.decoder_upper_layer3 367 | self.decoder_lower_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 368 | self.decoder_lower_layer4 = self.decoder_upper_layer4 369 | self.decoder_lower_projection_layer = self.decoder_upper_projection_layer 370 | 371 | self.fusion_rule = nn.Sequential(*[ 372 | nn.Conv2d(16, 3, 3, 1, 1, bias=True), 373 | nn.ReLU(inplace=True) 374 | ]) 375 | 376 | for m in self.modules(): 377 | if isinstance(m, nn.Conv2d): 378 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 379 | m.weight.data.normal_(0, math.sqrt(2. / n)) 380 | elif isinstance(m, nn.InstanceNorm2d): 381 | m.weight.data.fill_(1) 382 | m.bias.data.zero_() 383 | 384 | def forward(self, img1, img2): 385 | 386 | feature_upper = self.encoder_upper(img1) 387 | feature_upper0 = self.maxpool_upper(feature_upper) 388 | feature_upper1 = self.upper_encoder_layer1(feature_upper0) 389 | feature_upper2 = self.upper_encoder_layer2(feature_upper1) 390 | feature_upper3 = self.upper_encoder_layer3(feature_upper2) 391 | 392 | feature_lower = self.encoder_lower(img2) 393 | feature_lower0 = self.maxpool_lower(feature_lower) 394 | feature_lower1 = self.lower_encoder_layer1(feature_lower0) 395 | feature_lower2 = self.lower_encoder_layer2(feature_lower1) 396 | feature_lower3 = self.lower_encoder_layer3(feature_lower2) 397 | 398 | feature_concat = torch.cat((feature_upper3, feature_lower3), dim=1) 399 | feature_common = self.common_encoder(feature_concat) 400 | 401 | common_part = feature_common 402 | common_part = self.decoder_common_layer1(common_part) 403 | common_part = self.decoder_common_up1(common_part) 404 | common_part = F.interpolate(common_part, size=feature_upper2.shape[2:]) 405 | common_part = torch.cat((common_part, feature_upper2, feature_lower2), dim=1) 406 | common_part = self.decoder_common_layer2(common_part) 407 | common_part = self.decoder_common_up2(common_part) 408 | common_part = F.interpolate(common_part, size=feature_upper1.shape[2:]) 409 | common_part = torch.cat((common_part, feature_upper1, feature_lower1), dim=1) 410 | common_part = self.decoder_common_layer3(common_part) 411 | common_part = self.decoder_common_up3(common_part) 412 | common_part = F.interpolate(common_part, size=feature_upper.shape[2:]) 413 | common_part = self.decoder_common_layer4(common_part) 414 | common_part_embedding = common_part 415 | 416 | 417 | feature_de_upper = torch.cat((feature_common, feature_upper3), dim=1) 418 | upper_part = self.decoder_upper_layer1(feature_de_upper) 419 | upper_part = self.decoder_upper_up1(upper_part) 420 | upper_part = F.interpolate(upper_part, 
size=feature_upper2.shape[2:]) 421 | upper_part = torch.cat((upper_part, feature_upper2), dim=1) 422 | upper_part = self.decoder_upper_layer2(upper_part) 423 | upper_part = self.decoder_upper_up2(upper_part) 424 | upper_part = F.interpolate(upper_part, size=feature_upper1.shape[2:]) 425 | upper_part = torch.cat((upper_part, feature_upper1), dim=1) 426 | upper_part = self.decoder_upper_layer3(upper_part) 427 | upper_part = self.decoder_upper_up3(upper_part) 428 | upper_part = F.interpolate(upper_part, size=feature_upper.shape[2:]) 429 | upper_part = self.decoder_upper_layer4(upper_part) 430 | upper_part_embeding = upper_part 431 | 432 | 433 | 434 | feature_de_lower = torch.cat((feature_common, feature_lower3), dim=1) 435 | lower_part = self.decoder_lower_layer1(feature_de_lower) 436 | lower_part = self.decoder_lower_up1(lower_part) 437 | lower_part = F.interpolate(lower_part, size=feature_upper2.shape[2:]) 438 | lower_part = torch.cat((lower_part, feature_lower2), dim=1) 439 | lower_part = self.decoder_lower_layer2(lower_part) 440 | lower_part = self.decoder_lower_up2(lower_part) 441 | lower_part = F.interpolate(lower_part, size=feature_upper1.shape[2:]) 442 | lower_part = torch.cat((lower_part, feature_lower1), dim=1) 443 | lower_part = self.decoder_lower_layer3(lower_part) 444 | lower_part = self.decoder_lower_up3(lower_part) 445 | lower_part = F.interpolate(lower_part, size=feature_upper.shape[2:]) 446 | lower_part = self.decoder_lower_layer4(lower_part) 447 | lower_part_embeddding = lower_part 448 | 449 | fusion_part = self.fusion_rule(upper_part_embeding+lower_part_embeddding+ common_part_embedding) 450 | 451 | common_part = self.decoder_common_projection_layer(common_part) 452 | upper_part = self.decoder_upper_projection_layer(upper_part) 453 | lower_part = self.decoder_lower_projection_layer(lower_part) 454 | 455 | return common_part, upper_part, lower_part, fusion_part 456 | -------------------------------------------------------------------------------- /models/UCTestSharedModelPro.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.init as init 4 | from .resnet import ResNestLayer, Bottleneck 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class UCTestSharedNetPro(nn.Module): 10 | def __init__(self): 11 | super(UCTestSharedNetPro, self).__init__() 12 | encoder_upper = [nn.Conv2d(3, 16, 3, 1, 1, bias=True), 13 | nn.ReLU(inplace=True), 14 | ResNestLayer(Bottleneck, 8, 6, stem_width=8, norm_layer=None), 15 | ] 16 | self.encoder_upper = nn.Sequential(*encoder_upper) 17 | # self.encoder_upper_in = nn.InstanceNorm2d(64,affine=True) 18 | self.maxpool_upper = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 19 | self.upper_encoder_layer1 = ResNestLayer(Bottleneck, 16, 6, stem_width=16, norm_layer=None, is_first=False) 20 | self.upper_encoder_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=32, stride=2, norm_layer=None) 21 | self.upper_encoder_layer3 = ResNestLayer(Bottleneck, 64, 4, stem_width=64, stride=2, norm_layer=None) 22 | 23 | self.encoder_lower = self.encoder_upper 24 | self.maxpool_lower = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 25 | self.lower_encoder_layer1 = self.upper_encoder_layer1 26 | self.lower_encoder_layer2 = self.upper_encoder_layer2 27 | self.lower_encoder_layer3 = self.upper_encoder_layer3 28 | 29 | encoder_body_fusion = [ 30 | ResNestLayer(Bottleneck, 256, 4, stem_width=256, norm_layer=None, is_first=False) 31 | ] 32 | self.common_encoder = 
nn.Sequential(*encoder_body_fusion) 33 | 34 | self.decoder_common_layer1 = ResNestLayer(Bottleneck, 64, 2, stem_width=512, avg_down=False, avd=False, stride=1, norm_layer=None) 35 | self.decoder_common_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 36 | self.decoder_common_layer2 = ResNestLayer(Bottleneck, 16, 2, stem_width=128, avg_down=False, avd=False, stride=1, norm_layer=None) 37 | self.decoder_common_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 38 | self.decoder_common_layer3 = ResNestLayer(Bottleneck, 4, 2, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None) 39 | self.decoder_common_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 40 | decoder_common_layer4 = [ 41 | ResNestLayer(Bottleneck, 4, 2, stem_width=8, avg_down=False, avd=False, stride=1, norm_layer=None), 42 | ] 43 | self.decoder_common_layer4 = nn.Sequential(*decoder_common_layer4) 44 | decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 45 | nn.ReLU(inplace=True)] 46 | self.decoder_common_projection_layer = nn.Sequential(*decoder_projection_layer) 47 | 48 | self.decoder_upper_layer1 = ResNestLayer(Bottleneck, 96, 4, stem_width=640, avg_down=False, avd=False, stride=1, norm_layer=None) 49 | self.decoder_upper_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 50 | self.decoder_upper_layer2 = ResNestLayer(Bottleneck, 32, 4, stem_width=256, avg_down=False, avd=False, stride=1, norm_layer=None) 51 | self.decoder_upper_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 52 | self.decoder_upper_layer3 = ResNestLayer(Bottleneck, 16, 6, stem_width=96, avg_down=False, avd=False, stride=1, norm_layer=None) 53 | self.decoder_upper_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 54 | decoder_upper_layer4 = [ 55 | ResNestLayer(Bottleneck, 4, 6, stem_width=32, avg_down=False, avd=False, stride=1, norm_layer=None), 56 | ] 57 | self.decoder_upper_layer4 = nn.Sequential(*decoder_upper_layer4) 58 | upper_decoder_projection_layer = [nn.Conv2d(16, 3, 3, 1, 1, bias=True), 59 | nn.ReLU(inplace=True)] 60 | self.decoder_upper_projection_layer = nn.Sequential(*upper_decoder_projection_layer) 61 | 62 | self.decoder_lower_layer1 = self.decoder_upper_layer1 63 | self.decoder_lower_up1 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 64 | self.decoder_lower_layer2 = self.decoder_upper_layer2 65 | self.decoder_lower_up2 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 66 | self.decoder_lower_layer3 = self.decoder_upper_layer3 67 | self.decoder_lower_up3 = nn.Upsample(scale_factor=2, mode='bilinear') # nn.PixelShuffle(2) 68 | self.decoder_lower_layer4 = self.decoder_upper_layer4 69 | self.decoder_lower_projection_layer = self.decoder_upper_projection_layer 70 | 71 | self.fusion_rule = nn.Sequential(*[ 72 | nn.Conv2d(16, 3, 3, 1, 1, bias=True), 73 | nn.ReLU(inplace=True) 74 | ]) 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 80 | elif isinstance(m, nn.InstanceNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | 84 | def forward(self, img1, img2): 85 | 86 | feature_upper = self.encoder_upper(img1) 87 | feature_upper0 = self.maxpool_upper(feature_upper) 88 | feature_upper1 = self.upper_encoder_layer1(feature_upper0) 89 | feature_upper2 = self.upper_encoder_layer2(feature_upper1) 90 | feature_upper3 = self.upper_encoder_layer3(feature_upper2) 91 | 92 | feature_lower = self.encoder_lower(img2) 93 | feature_lower0 = self.maxpool_lower(feature_lower) 94 | feature_lower1 = self.lower_encoder_layer1(feature_lower0) 95 | feature_lower2 = self.lower_encoder_layer2(feature_lower1) 96 | feature_lower3 = self.lower_encoder_layer3(feature_lower2) 97 | 98 | feature_concat = torch.cat((feature_upper3, feature_lower3), dim=1) 99 | feature_common = self.common_encoder(feature_concat) 100 | 101 | common_part = self.decoder_common_layer1(feature_common) 102 | common_part = self.decoder_common_up1(common_part) 103 | common_part = F.interpolate(common_part, size=feature_upper2.shape[2:]) 104 | common_part = self.decoder_common_layer2(common_part) 105 | common_part = self.decoder_common_up2(common_part) 106 | common_part = F.interpolate(common_part, size=feature_upper1.shape[2:]) 107 | common_part = self.decoder_common_layer3(common_part) 108 | common_part = self.decoder_common_up3(common_part) 109 | common_part = F.interpolate(common_part, size=feature_upper.shape[2:]) 110 | common_part = self.decoder_common_layer4(common_part) 111 | common_part_embedding = common_part 112 | common_part = self.decoder_upper_projection_layer(common_part) 113 | 114 | feature_de_upper = torch.cat((feature_common, feature_upper3), dim=1) 115 | upper_part = self.decoder_upper_layer1(feature_de_upper) 116 | upper_part = self.decoder_upper_up1(upper_part) 117 | upper_part = F.interpolate(upper_part, size=feature_upper2.shape[2:]) 118 | upper_part = torch.cat((upper_part, feature_upper2), dim=1) 119 | upper_part = self.decoder_upper_layer2(upper_part) 120 | upper_part = self.decoder_upper_up2(upper_part) 121 | upper_part = F.interpolate(upper_part, size=feature_upper1.shape[2:]) 122 | upper_part = torch.cat((upper_part, feature_upper1), dim=1) 123 | upper_part = self.decoder_upper_layer3(upper_part) 124 | upper_part = self.decoder_upper_up3(upper_part) 125 | upper_part = F.interpolate(upper_part, size=feature_upper.shape[2:]) 126 | upper_part = self.decoder_upper_layer4(upper_part) 127 | upper_part_embeding = upper_part 128 | upper_part = self.decoder_upper_projection_layer(upper_part) 129 | 130 | feature_de_lower = torch.cat((feature_common, feature_lower3), dim=1) 131 | lower_part = self.decoder_lower_layer1(feature_de_lower) 132 | lower_part = self.decoder_lower_up1(lower_part) 133 | lower_part = F.interpolate(lower_part, size=feature_upper2.shape[2:]) 134 | lower_part = torch.cat((lower_part, feature_lower2), dim=1) 135 | lower_part = self.decoder_lower_layer2(lower_part) 136 | lower_part = self.decoder_lower_up2(lower_part) 137 | lower_part = F.interpolate(lower_part, size=feature_upper1.shape[2:]) 138 | lower_part = torch.cat((lower_part, feature_lower1), dim=1) 139 | lower_part = self.decoder_lower_layer3(lower_part) 140 | lower_part = self.decoder_lower_up3(lower_part) 141 | lower_part = F.interpolate(lower_part, size=feature_upper.shape[2:]) 142 | lower_part = self.decoder_lower_layer4(lower_part) 143 | lower_part_embeddding = lower_part 144 | lower_part = self.decoder_lower_projection_layer(lower_part) 145 | 146 | fusion_part 
= self.fusion_rule(upper_part_embeding+lower_part_embeddding+ common_part_embedding) 147 | 148 | return common_part_embedding, upper_part_embeding, lower_part_embeddding, fusion_part 149 | -------------------------------------------------------------------------------- /models/resnest.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt models""" 9 | 10 | import torch 11 | from .resnet import ResNet, Bottleneck 12 | 13 | 14 | 15 | 16 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 17 | model = ResNet(Bottleneck, [3, 4, 6, 3], 18 | radix=2, groups=1, bottleneck_width=64, 19 | deep_stem=True, stem_width=32, avg_down=True, 20 | avd=True, avd_first=False, **kwargs) 21 | 22 | return model 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNet variants""" 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | 13 | from .splat import SplAtConv2d 14 | 15 | __all__ = ['ResNet', 'Bottleneck'] 16 | 17 | 18 | class GlobalAvgPool2d(nn.Module): 19 | def __init__(self): 20 | """Global average pooling over the input's spatial dimensions""" 21 | super(GlobalAvgPool2d, self).__init__() 22 | 23 | def forward(self, inputs): 24 | return nn.functional.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1) 25 | 26 | class Bottleneck(nn.Module): 27 | """ResNet Bottleneck 28 | """ 29 | # pylint: disable=unused-argument 30 | expansion = 4 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, 32 | radix=1, cardinality=1, bottleneck_width=64, 33 | avd=False, avd_first=False, dilation=1, is_first=False, 34 | norm_layer=None): 35 | super(Bottleneck, self).__init__() 36 | group_width = int(planes * (bottleneck_width / 64.)) * cardinality 37 | self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) 38 | # self.bn1 = norm_layer(group_width, affine=True) 39 | self.radix = radix 40 | self.avd = avd and (stride > 1 or is_first) 41 | self.avd_first = avd_first 42 | 43 | if self.avd: 44 | self.avd_layer = nn.AvgPool2d(3, stride, padding=1) 45 | stride = 1 46 | 47 | self.conv2 = SplAtConv2d( 48 | group_width, group_width, kernel_size=3, 49 | stride=stride, padding=dilation, 50 | dilation=dilation, groups=cardinality, bias=False, 51 | radix=radix, 52 | norm_layer=norm_layer) 53 | 54 | self.conv3 = nn.Conv2d( 55 | group_width, planes * 4, kernel_size=1, bias=False) 56 | # self.bn3 = norm_layer(planes*4, affine=True) 57 | 58 | self.relu = nn.ReLU(inplace=True) 59 | self.downsample = downsample 60 | self.dilation = dilation 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | # out = self.bn1(out) 68 | 69 | out = self.relu(out) 70 | 71 | if self.avd and 
self.avd_first: 72 | out = self.avd_layer(out) 73 | 74 | out = self.conv2(out) 75 | 76 | if self.avd and not self.avd_first: 77 | out = self.avd_layer(out) 78 | 79 | out = self.conv3(out) 80 | # out = self.bn3(out) 81 | 82 | if self.downsample is not None: 83 | residual = self.downsample(x) 84 | 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | 91 | class ResNestLayer(nn.Module): 92 | def __init__(self, block, planes, blocks, 93 | stride=1, dilation=1, norm_layer=None, is_first=True 94 | ,radix=2, deep_stem=True, stem_width=32, avg_down=True, groups=1 95 | ,bottleneck_width=64, avd=True, avd_first=False): 96 | super(ResNestLayer, self).__init__() 97 | self.inplanes = stem_width * 2 if deep_stem else 64 98 | self.radix = radix 99 | self.avg_down = avg_down 100 | self.cardinality = groups 101 | self.bottleneck_width = bottleneck_width 102 | self.avd = avd 103 | self.avd_first = avd_first 104 | self.rectified_conv = False 105 | self.rectify_avg = False 106 | self.layer = self._make_layer(block=block, planes=planes, blocks=blocks, 107 | stride=stride, dilation=dilation, norm_layer=norm_layer, 108 | is_first=is_first) 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. / n)) 113 | # elif isinstance(m, norm_layer): 114 | # m.weight.data.fill_(1) 115 | # m.bias.data.zero_() 116 | 117 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, is_first=True): 118 | downsample = None 119 | if stride != 1 or self.inplanes != planes * block.expansion: 120 | down_layers = [] 121 | if self.avg_down: 122 | if dilation == 1: 123 | down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, 124 | ceil_mode=True, count_include_pad=False)) 125 | else: 126 | down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, 127 | ceil_mode=True, count_include_pad=False)) 128 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 129 | kernel_size=1, stride=1, bias=False)) 130 | else: 131 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 132 | kernel_size=1, stride=stride, bias=False)) 133 | # down_layers.append(norm_layer(planes * block.expansion, affine=True)) 134 | downsample = nn.Sequential(*down_layers) 135 | 136 | layers = [] 137 | if dilation == 1 or dilation == 2: 138 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 139 | radix=self.radix, cardinality=self.cardinality, 140 | bottleneck_width=self.bottleneck_width, 141 | avd=self.avd, avd_first=self.avd_first, 142 | dilation=1, is_first=is_first, 143 | norm_layer=norm_layer)) 144 | elif dilation == 4: 145 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 146 | radix=self.radix, cardinality=self.cardinality, 147 | bottleneck_width=self.bottleneck_width, 148 | avd=self.avd, avd_first=self.avd_first, 149 | dilation=2, is_first=is_first, 150 | norm_layer=norm_layer)) 151 | else: 152 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 153 | 154 | self.inplanes = planes * block.expansion 155 | for i in range(1, blocks): 156 | layers.append(block(self.inplanes, planes, 157 | radix=self.radix, cardinality=self.cardinality, 158 | bottleneck_width=self.bottleneck_width, 159 | avd=self.avd, avd_first=self.avd_first, 160 | dilation=dilation, 161 | norm_layer=norm_layer)) 162 | 163 | return nn.Sequential(*layers) 164 | 165 | def forward(self, x): 166 | out = self.layer(x) 167 | return out 
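A minimal shape-check sketch for `ResNestLayer` (illustrative only, not part of `resnet.py`; it assumes the `models` package is importable, e.g. from the repository root). The layer expects an input with `stem_width * 2` channels, returns `planes * Bottleneck.expansion` (= `planes * 4`) channels, and halves the spatial size only when `stride=2`:

```python
import torch
from models.resnet import ResNestLayer, Bottleneck

layer = ResNestLayer(Bottleneck, 16, 6, stem_width=16,
                     norm_layer=None, is_first=False)
x = torch.randn(1, 32, 64, 64)   # 32 == stem_width * 2 input channels
y = layer(x)
print(y.shape)                   # torch.Size([1, 64, 64, 64]) == planes * 4 channels
```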
168 | 169 | class ResNet(nn.Module): 170 | 171 | # def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 172 | # model = ResNet(Bottleneck, [3, 4, 6, 3], 173 | # radix=2, groups=1, bottleneck_width=64, 174 | # deep_stem=True, stem_width=32, avg_down=True, 175 | # avd=True, avd_first=False, **kwargs) 176 | # 177 | # return model 178 | # pylint: disable=unused-variable 179 | def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64, 180 | num_classes=1000, dilated=False, dilation=1, 181 | deep_stem=False, stem_width=64, avg_down=False, 182 | rectified_conv=False, rectify_avg=False, 183 | avd=False, avd_first=False, 184 | final_drop=0.0, dropblock_prob=0, 185 | last_gamma=False, norm_layer=nn.BatchNorm2d): 186 | self.cardinality = groups 187 | self.bottleneck_width = bottleneck_width 188 | # ResNet-D params 189 | self.inplanes = stem_width*2 if deep_stem else 64 190 | self.avg_down = avg_down 191 | self.last_gamma = last_gamma 192 | # ResNeSt params 193 | self.radix = radix 194 | self.avd = avd 195 | self.avd_first = avd_first 196 | 197 | super(ResNet, self).__init__() 198 | self.rectified_conv = rectified_conv 199 | self.rectify_avg = rectify_avg 200 | 201 | conv_layer = nn.Conv2d 202 | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} 203 | if deep_stem: 204 | self.conv1 = nn.Sequential( 205 | conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs), 206 | norm_layer(stem_width), 207 | nn.ReLU(inplace=True), 208 | conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 209 | norm_layer(stem_width), 210 | nn.ReLU(inplace=True), 211 | conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 212 | ) 213 | else: 214 | self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, 215 | bias=False, **conv_kwargs) 216 | self.bn1 = norm_layer(self.inplanes) 217 | self.relu = nn.ReLU(inplace=True) 218 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 219 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) 220 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 221 | if dilated or dilation == 4: 222 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 223 | dilation=2, norm_layer=norm_layer, 224 | dropblock_prob=dropblock_prob) 225 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 226 | dilation=4, norm_layer=norm_layer, 227 | dropblock_prob=dropblock_prob) 228 | elif dilation==2: 229 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 230 | dilation=1, norm_layer=norm_layer, 231 | dropblock_prob=dropblock_prob) 232 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 233 | dilation=2, norm_layer=norm_layer, 234 | dropblock_prob=dropblock_prob) 235 | else: 236 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 237 | norm_layer=norm_layer, 238 | dropblock_prob=dropblock_prob) 239 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 240 | norm_layer=norm_layer, 241 | dropblock_prob=dropblock_prob) 242 | self.avgpool = GlobalAvgPool2d() 243 | self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None 244 | self.fc = nn.Linear(512 * block.expansion, num_classes) 245 | 246 | for m in self.modules(): 247 | if isinstance(m, nn.Conv2d): 248 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 249 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 250 | elif isinstance(m, norm_layer): 251 | m.weight.data.fill_(1) 252 | m.bias.data.zero_() 253 | 254 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, 255 | dropblock_prob=0.0, is_first=True): 256 | downsample = None 257 | if stride != 1 or self.inplanes != planes * block.expansion: 258 | down_layers = [] 259 | if self.avg_down: 260 | if dilation == 1: 261 | down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, 262 | ceil_mode=True, count_include_pad=False)) 263 | else: 264 | down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, 265 | ceil_mode=True, count_include_pad=False)) 266 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 267 | kernel_size=1, stride=1, bias=False)) 268 | else: 269 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 270 | kernel_size=1, stride=stride, bias=False)) 271 | down_layers.append(norm_layer(planes * block.expansion)) 272 | downsample = nn.Sequential(*down_layers) 273 | 274 | layers = [] 275 | if dilation == 1 or dilation == 2: 276 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 277 | radix=self.radix, cardinality=self.cardinality, 278 | bottleneck_width=self.bottleneck_width, 279 | avd=self.avd, avd_first=self.avd_first, 280 | dilation=1, is_first=is_first, rectified_conv=self.rectified_conv, 281 | rectify_avg=self.rectify_avg, 282 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 283 | last_gamma=self.last_gamma)) 284 | elif dilation == 4: 285 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 286 | radix=self.radix, cardinality=self.cardinality, 287 | bottleneck_width=self.bottleneck_width, 288 | avd=self.avd, avd_first=self.avd_first, 289 | dilation=2, is_first=is_first, rectified_conv=self.rectified_conv, 290 | rectify_avg=self.rectify_avg, 291 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 292 | last_gamma=self.last_gamma)) 293 | else: 294 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 295 | 296 | self.inplanes = planes * block.expansion 297 | for i in range(1, blocks): 298 | layers.append(block(self.inplanes, planes, 299 | radix=self.radix, cardinality=self.cardinality, 300 | bottleneck_width=self.bottleneck_width, 301 | avd=self.avd, avd_first=self.avd_first, 302 | dilation=dilation, rectified_conv=self.rectified_conv, 303 | rectify_avg=self.rectify_avg, 304 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 305 | last_gamma=self.last_gamma)) 306 | 307 | return nn.Sequential(*layers) 308 | 309 | def forward(self, x): 310 | x = self.conv1(x) 311 | x = self.bn1(x) 312 | x = self.relu(x) 313 | x = self.maxpool(x) 314 | 315 | x = self.layer1(x) 316 | x = self.layer2(x) 317 | x = self.layer3(x) 318 | x = self.layer4(x) 319 | 320 | x = self.avgpool(x) 321 | x = torch.flatten(x, 1) 322 | if self.drop: 323 | x = self.drop(x) 324 | x = self.fc(x) 325 | 326 | return x 327 | 328 | 329 | 330 | 331 | -------------------------------------------------------------------------------- /models/splat.py: -------------------------------------------------------------------------------- 1 | """Split-Attention""" 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Conv2d, Module, ReLU 7 | from torch.nn.modules.utils import _pair 8 | 9 | __all__ = ['SplAtConv2d'] 10 | 11 | 12 | 13 | class SplAtConv2d(Module): 14 | """Split-Attention Conv2d 15 | """ 16 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), 
padding=(0, 0), 17 | dilation=(1, 1), groups=1, bias=True, 18 | radix=2, reduction_factor=4, norm_layer=None, **kwargs): 19 | super(SplAtConv2d, self).__init__() 20 | padding = _pair(padding) 21 | inter_channels = max(in_channels*radix//reduction_factor, 32) 22 | self.radix = radix 23 | self.cardinality = groups 24 | self.channels = channels 25 | 26 | self.conv = Conv2d(in_channels=in_channels, out_channels=channels*radix, 27 | kernel_size=kernel_size, stride=stride, 28 | padding=padding, dilation=dilation, 29 | groups=groups*radix, bias=bias, **kwargs) 30 | self.use_bn = norm_layer is not None 31 | if self.use_bn: 32 | self.bn0 = norm_layer(channels*radix, affine=True) 33 | self.relu = ReLU(inplace=True) 34 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 35 | if self.use_bn: 36 | self.bn1 = norm_layer(inter_channels, affine=True) 37 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 38 | self.rsoftmax = rSoftMax(radix, groups) 39 | 40 | def forward(self, x): 41 | x = self.conv(x) 42 | if self.use_bn: 43 | x = self.bn0(x) 44 | x = self.relu(x) 45 | 46 | batch, rchannel = x.shape[:2] 47 | if self.radix > 1: 48 | if torch.__version__ < '1.5': 49 | splited = torch.split(x, int(rchannel//self.radix), dim=1) 50 | else: 51 | splited = torch.split(x, rchannel//self.radix, dim=1) 52 | gap = sum(splited) 53 | else: 54 | gap = x 55 | gap = F.adaptive_avg_pool2d(gap, 1) 56 | gap = self.fc1(gap) 57 | 58 | if self.use_bn: 59 | gap = self.bn1(gap) 60 | gap = self.relu(gap) 61 | 62 | atten = self.fc2(gap) 63 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 64 | 65 | if self.radix > 1: 66 | if torch.__version__ < '1.5': 67 | attens = torch.split(atten, int(rchannel//self.radix), dim=1) 68 | else: 69 | attens = torch.split(atten, rchannel//self.radix, dim=1) 70 | out = sum([att*split for (att, split) in zip(attens, splited)]) 71 | else: 72 | out = atten * x 73 | return out.contiguous() 74 | 75 | class rSoftMax(nn.Module): 76 | def __init__(self, radix, cardinality): 77 | super().__init__() 78 | self.radix = radix 79 | self.cardinality = cardinality 80 | 81 | def forward(self, x): 82 | batch = x.size(0) 83 | if self.radix > 1: 84 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 85 | x = F.softmax(x, dim=1) 86 | x = x.reshape(batch, -1) 87 | else: 88 | x = torch.sigmoid(x) 89 | return x 90 | -------------------------------------------------------------------------------- /option/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import os.path 4 | import logging 5 | 6 | 7 | def parse(opt_path, is_train=True): 8 | with open(opt_path, mode='r') as f: 9 | opt = yaml.load(f, Loader=yaml.FullLoader) 10 | opt['is_train'] = is_train 11 | 12 | # datasets 13 | for phase, dataset in opt['dataset'].items(): 14 | phase = phase.split('_')[0] 15 | dataset['phase'] = phase 16 | if dataset.get('dataroot', None) is not None: 17 | dataset['dataroot'] = os.path.expanduser(dataset['dataroot']) 18 | # path 19 | for key, path in opt['path'].items(): 20 | if path and key in opt['path'] and key != 'strict_load': 21 | opt['path'][key] = os.path.expanduser(path) 22 | opt['path']['root'] = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir)) 23 | 24 | if is_train: 25 | experiments_root = os.path.join(opt['path']['root'], 'experiments', opt['name']) 26 | opt['path']['experiments_root'] = experiments_root 27 | opt['path']['models'] = os.path.join(experiments_root, 
'models') 28 | opt['path']['training_state'] = os.path.join(experiments_root, 'training_state') 29 | opt['path']['log'] = experiments_root 30 | opt['path']['val_images'] = os.path.join(experiments_root, 'val_images') 31 | 32 | # change some options for debug mode 33 | if 'debug' in opt['name']: 34 | opt['train']['val_freq'] = 8 35 | opt['logger']['print_freq'] = 1 36 | opt['logger']['save_checkpoint_freq'] = 8 37 | else: # test 38 | results_root = os.path.join(opt['path']['root'], 'results', opt['name']) 39 | opt['path']['results_root'] = results_root 40 | opt['path']['log'] = results_root 41 | opt['path']['test_images'] = os.path.join(results_root,'test_images') 42 | 43 | return opt 44 | 45 | def check_resume(opt, resume_iter): 46 | '''Check resume states and pretrain_model paths''' 47 | logger = logging.getLogger('base') 48 | if opt['path']['resume_state']: 49 | if opt['path'].get('pretrain_model_G', None) is not None: 50 | logger.warning('pretrain_model path will be ignored when resuming training.') 51 | opt['path']['pretrain_model_G'] = os.path.join(opt['path']['models'], 52 | '{}_G.pth'.format(resume_iter)) 53 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 54 | 55 | def dict2str(opt, indent_l=1): 56 | '''dict to string for logger''' 57 | msg = '' 58 | for k, v in opt.items(): 59 | if isinstance(v, dict): 60 | msg += ' ' * (indent_l * 2) + k + ':[\n' 61 | msg += dict2str(v, indent_l + 1) 62 | msg += ' ' * (indent_l * 2) + ']\n' 63 | else: 64 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 65 | return msg 66 | 67 | class NoneDict(dict): 68 | def __missing__(self, key): 69 | return None 70 | 71 | 72 | # convert to NoneDict, which return None for missing key. 73 | def dict_to_nonedict(opt): 74 | if isinstance(opt, dict): 75 | new_opt = dict() 76 | for key, sub_opt in opt.items(): 77 | new_opt[key] = dict_to_nonedict(sub_opt) 78 | return NoneDict(**new_opt) 79 | elif isinstance(opt, list): 80 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 81 | else: 82 | return opt 83 | -------------------------------------------------------------------------------- /option/test/EMFF_Test_Dataset.yaml: -------------------------------------------------------------------------------- 1 | name: MultiFocusFusion_TEST 2 | use_tb_logger: false 3 | model: FusionModel # reconstruct Defocus network. 4 | 5 | dataset: 6 | test: 7 | name: test 8 | dataroot: ~/Documents/fusion-dataset/MFIF/input 9 | batch_size: 1 10 | workers: 1 11 | part_name: . 12 | datashell: MFF 13 | 14 | network_G: 15 | in_nc: 3 16 | block_num: 2 17 | init: xavier 18 | hidden_channels: 128 19 | K: 4 20 | 21 | path: 22 | pretrain_model_G: ~ 23 | strict_load: true 24 | resume_state: /home/lpw/Documents/multi-exposure/SelfSupervisedFusion/experiments/Selftrained/models/models.pth 25 | 26 | logger: 27 | print_freq: 10 28 | save_checkpoint_freq: 20 #!!float 5e3 29 | -------------------------------------------------------------------------------- /option/test/IVF_Test_Dataset.yaml: -------------------------------------------------------------------------------- 1 | name: VisibleInfrareFusion_TEST 2 | use_tb_logger: false 3 | model: FusionModel # reconstruct Defocus network. 
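# Added note (not part of the original config): for this RoadScene setting,
# enable `from data.visir_fusion_dataset import TestDataset` among the commented
# imports at the top of test.py, and point `resume_state` below to your own
# checkpoint, e.g. the pretrained model linked in the README.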
4 | 5 | dataset: 6 | test: 7 | name: test 8 | dataroot: ~/Documents/fusion-dataset/RoadScene 9 | batch_size: 1 10 | workers: 1 11 | infrare_name: cropinfrared 12 | visible_name: crop_LR_visible 13 | datashell: IVF 14 | 15 | network_G: 16 | in_nc: 3 17 | block_num: 2 18 | init: xavier 19 | hidden_channels: 128 20 | K: 4 21 | 22 | path: 23 | pretrain_model_G: ~ 24 | strict_load: true 25 | resume_state: /home/lpw/Documents/multi-exposure/SelfSupervisedFusion/experiments/Selftrained/models/models.pth 26 | 27 | logger: 28 | print_freq: 10 29 | save_checkpoint_freq: 20 #!!float 5e3 30 | -------------------------------------------------------------------------------- /option/test/MEF_Test_Dataset.yaml: -------------------------------------------------------------------------------- 1 | name: MultiExposureFusion_TEST 2 | use_tb_logger: false 3 | model: FusionModel # reconstruct Defocus network. 4 | 5 | dataset: 6 | test: 7 | name: test 8 | dataroot: ~/Documents/fusion-dataset/MEFB-dataset 9 | batch_size: 1 10 | img_size: 256 11 | workers: 1 12 | part_name: input 13 | 14 | 15 | network_G: 16 | in_nc: 3 17 | block_num: 2 18 | init: xavier 19 | hidden_channels: 128 20 | K: 4 21 | 22 | path: 23 | pretrain_model_G: ~ 24 | strict_load: true 25 | resume_state: /home/lpw/Documents/multi-exposure/SelfSupervisedFusion/experiments/Selftrained/models/models.pth 26 | 27 | logger: 28 | print_freq: 10 29 | save_checkpoint_freq: 20 #!!float 5e3 30 | -------------------------------------------------------------------------------- /option/test/MFF_Test_Dadaset.yaml: -------------------------------------------------------------------------------- 1 | name: MultiFocusFusion_TEST 2 | use_tb_logger: false 3 | model: FusionModel # reconstruct Defocus network. 4 | 5 | dataset: 6 | test: 7 | name: test 8 | dataroot: ~/Documents/fusion-dataset/ResultsMFIF 9 | batch_size: 1 10 | workers: 1 11 | part_name: . 12 | datashell: MFF 13 | 14 | network_G: 15 | in_nc: 3 16 | block_num: 2 17 | init: xavier 18 | hidden_channels: 128 19 | K: 4 20 | 21 | path: 22 | pretrain_model_G: ~ 23 | strict_load: true 24 | resume_state: /home/lpw/Documents/multi-exposure/SelfSupervisedFusion/experiments/Selftrained/models/models.pth 25 | 26 | logger: 27 | print_freq: 10 28 | save_checkpoint_freq: 20 #!!float 5e3 29 | -------------------------------------------------------------------------------- /option/test/SMEF_Test_Dataset.yaml: -------------------------------------------------------------------------------- 1 | name: MultiExposureFusionS_TEST 2 | use_tb_logger: false 3 | model: FusionModel # reconstruct Defocus network. 4 | 5 | dataset: 6 | test: 7 | name: test 8 | dataroot: ~/Documents/fusion-dataset/OVDataset 9 | batch_size: 1 10 | img_size: 256 11 | workers: 1 12 | part_name: Dataset_Part1 13 | 14 | 15 | network_G: 16 | in_nc: 3 17 | block_num: 2 18 | init: xavier 19 | hidden_channels: 128 20 | K: 4 21 | 22 | path: 23 | pretrain_model_G: ~ 24 | strict_load: true 25 | resume_state: /home/lpw/Documents/multi-exposure/SelfSupervisedFusion/experiments/Selftrained/models/models.pth 26 | 27 | logger: 28 | print_freq: 10 29 | save_checkpoint_freq: 20 #!!float 5e3 30 | -------------------------------------------------------------------------------- /option/test/TIVF_Test_Dataset.yaml: -------------------------------------------------------------------------------- 1 | name: VisibleInfrareFusionTNO_TEST 2 | use_tb_logger: false 3 | model: FusionModel # reconstruct Defocus network. 
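# Added note (not part of the original config): for this TNO setting, select
# `from data.visir_fusion_dataset import TestTNODataset as TestDataset` in
# test.py, and update `resume_state` below to your local checkpoint path.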
4 | 5 | dataset: 6 | test: 7 | name: test 8 | dataroot: ~/Documents/fusion-dataset/TNO_dataset 9 | batch_size: 1 10 | workers: 1 11 | datashell: IVF 12 | 13 | network_G: 14 | in_nc: 3 15 | block_num: 2 16 | init: xavier 17 | hidden_channels: 128 18 | K: 4 19 | 20 | path: 21 | pretrain_model_G: ~ 22 | strict_load: true 23 | resume_state: /home/lpw/Documents/multi-exposure/SelfSupervisedFusion/experiments/Selftrained/models/models.pth 24 | 25 | logger: 26 | print_freq: 10 27 | save_checkpoint_freq: 20 #!!float 5e3 28 | -------------------------------------------------------------------------------- /option/train/SelfTrained_SDataset.yaml: -------------------------------------------------------------------------------- 1 | name: Selftrained 2 | use_tb_logger: true 3 | model: UCMConv # reconstruct Camera response function network 4 | 5 | dataset: 6 | train: 7 | name: train 8 | dataroot: ~/Documents/fusion-dataset/coco-dataset/ 9 | filter: 10 | trainpairs: data/self_train_grid_data_ablation.txt 11 | extra_data: ~/Documents/fusion-dataset/DNIM/Image/ 12 | batch_size: 4 13 | image_size: 256 14 | max_iter: 40 15 | iter_size: 1 16 | workers: 4 17 | train_name: train2017 18 | 19 | val: 20 | name: val 21 | dataroot: ~/Documents/fusion-dataset/MEFB-dataset 22 | batch_size: 1 23 | workers: 1 24 | input_name: input 25 | 26 | 27 | network_G: 28 | in_nc: 3 29 | block_num: 2 30 | init: xavier 31 | hidden_channels: 128 32 | K: 4 33 | 34 | path: 35 | pretrain_model_G: ~ 36 | strict_load: true 37 | resume_state: ~ 38 | 39 | logger: 40 | print_freq: 10 41 | save_checkpoint_freq: 1 #!!float 5e3 42 | 43 | train: 44 | lr: !!float 1e-3 45 | beta1: 0.9 46 | beta2: 0.999 47 | max_grad_norm: 20 48 | max_grad_clip: 20 49 | niter: 500000 50 | epoch: 50 51 | 52 | lr_steps: [10, 20, 30, 40] 53 | lr_gamma: 0.5 54 | 55 | val_freq: 3 #!!float 5e3 56 | kernel_freq: 1 #!!float 5e3 57 | manual_seed: 1 58 | -------------------------------------------------------------------------------- /selftrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import build_code_arch, util 3 | from data.self_mixpretrain_dataset import TrainDataset 4 | from torch.utils.data import DataLoader 5 | from models.UCSharedModelProCommon import UCSharedNetPro 6 | from torch.optim import Adam, lr_scheduler, AdamW, Adamax 7 | from torch.utils.tensorboard import SummaryWriter 8 | from tqdm import tqdm 9 | from loss.mix_fp_loss import SelfTrainLoss 10 | import os 11 | import torch 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-opt', type=str, required=True, help='Multi Data Fusion: Path to option ymal file.') 16 | train_args = parser.parse_args() 17 | 18 | opt, resume_state = build_code_arch.build_resume_state(train_args) 19 | opt, logger, tb_logger = build_code_arch.build_logger(opt) 20 | 21 | for phase, dataset_opt in opt['dataset'].items(): 22 | if phase == 'train': 23 | train_dataset = TrainDataset(dataset_opt) 24 | train_loader = DataLoader( 25 | train_dataset, batch_size=dataset_opt['batch_size'], shuffle=True, 26 | num_workers=dataset_opt['workers'], pin_memory=True) 27 | logger.info('Number of train images: {:,d}'.format(len(train_dataset))) 28 | assert train_loader is not None 29 | 30 | 31 | model = UCSharedNetPro() 32 | 33 | 34 | optimizer = Adam(model.parameters(), betas=(opt['train']['beta1'], opt['train']['beta2']), 35 | lr=opt['train']['lr']) 36 | 37 | scheduler = lr_scheduler.MultiStepLR(optimizer=optimizer, 38 | 
milestones=opt['train']['lr_steps'], 39 | gamma=opt['train']['lr_gamma']) 40 | writer = SummaryWriter() 41 | model = model.cuda() 42 | model.train() 43 | 44 | # resume training 45 | if resume_state: 46 | logger.info('Resuming training from epoch: {}.'.format( 47 | resume_state['epoch'])) 48 | start_epoch = resume_state['epoch'] + 1 49 | optimizer.load_state_dict(resume_state['optimizers']) 50 | # scheduler.load_state_dict(resume_state['schedulers']) 51 | model.load_state_dict(resume_state['state_dict']) 52 | else: 53 | start_epoch = 0 54 | 55 | criterion = SelfTrainLoss() 56 | max_steps = len(train_loader) 57 | 58 | 59 | logger.info('Start training from epoch: {:d}'.format(start_epoch)) 60 | logger.info('# network parameters: {}'.format(sum(param.numel() for param in model.parameters()))) 61 | total_epochs = opt['train']['epoch'] 62 | 63 | 64 | for epoch in range(start_epoch, total_epochs + 1): 65 | criterion.is_train = True 66 | for index, train_data in tqdm(enumerate(train_loader)): 67 | # training 68 | # continue 69 | o_img, v_img, gt_img, train_type = train_data 70 | o_img = o_img.cuda() 71 | v_img = v_img.cuda() 72 | gt_img = gt_img.cuda() 73 | common_part, upper_part, lower_part, fusion_part = model(o_img, v_img) 74 | losses, iteres = criterion(img1 = o_img, img2 = v_img, 75 | gt_img=gt_img, common_part= common_part, 76 | upper_part=upper_part, lower_part = lower_part, 77 | fusion_part=fusion_part, b_input_type = train_type) 78 | grad_loss = losses["total_loss"] 79 | optimizer.zero_grad() 80 | grad_loss.backward() 81 | optimizer.step() 82 | current_step = epoch * max_steps + index 83 | # log 84 | message = ' '.format( 85 | epoch, current_step, scheduler.get_last_lr()[0]) 86 | for k, v in losses.items(): 87 | v = v.cpu().item() 88 | message += '{:s}: {:.4e} '.format(k, v) 89 | # tensorboard logger 90 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 91 | tb_logger.add_scalar(k, v, iteres[k]) 92 | logger.info(message) 93 | 94 | # update learning rate 95 | scheduler.step() 96 | 97 | # save models and training states 98 | if epoch % opt['logger']['save_checkpoint_freq'] == 0: 99 | logger.info('Saving models and training states.') 100 | save_filename = '{}_{}.pth'.format(epoch, 'models') 101 | save_path = os.path.join(opt['path']['models'], save_filename) 102 | state_dict = model.state_dict() 103 | save_checkpoint = {'state_dict': state_dict, 104 | 'optimizers': optimizer.state_dict(), 105 | 'schedulers': scheduler.state_dict(), 106 | 'epoch': epoch} 107 | torch.save(save_checkpoint, save_path) 108 | torch.cuda.empty_cache() 109 | 110 | logger.info('Saving the final model.') 111 | save_filename = 'latest.pth' 112 | save_path = os.path.join(opt['path']['models'], save_filename) 113 | save_checkpoint = {"state_dict": model.state_dict(), 114 | 'optimizers': optimizer.state_dict(), 115 | 'schedulers': scheduler.state_dict(), 116 | "epoch": opt['train']['epoch']} 117 | torch.save(save_checkpoint, save_path) 118 | logger.info('End of training.') 119 | tb_logger.close() 120 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import util 3 | from data.multi_exposure_dataset import TestDataset 4 | # from data.multi_focus_dataset import TestDataset 5 | # from data.multi_focus_dataset import TestMFFDataset as TestDataset 6 | # from data.visir_fusion_dataset import TestDataset 7 | # from data.visir_fusion_dataset import TestTNODataset as 
TestDataset 8 | from torch.utils.data import DataLoader 9 | from models.UCTestShareModelProCommon import UCTestSharedNetPro 10 | from tqdm import tqdm 11 | from torchvision.transforms import ToPILImage 12 | import os 13 | import torch 14 | import option.options as option 15 | import logging 16 | import torch.nn.functional as F 17 | 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('-opt', type=str, required=True, help='Multi Data Fusion: Path to option ymal file.') 22 | test_args = parser.parse_args() 23 | 24 | opt = option.parse(test_args.opt, is_train=False) 25 | util.mkdir_and_rename(opt['path']['results_root']) # rename results folder if exists 26 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'results_root' 27 | and 'pretrain_model' not in key and 'resume' not in key)) 28 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 29 | screen=True, tofile=True) 30 | 31 | logger = logging.getLogger('base') 32 | logger.info(option.dict2str(opt)) 33 | 34 | torch.backends.cudnn.deterministic = True 35 | # convert to NoneDict, which returns None for missing keys 36 | opt = option.dict_to_nonedict(opt) 37 | 38 | 39 | dataset_opt = opt['dataset']['test'] 40 | test_dataset = TestDataset(dataset_opt) 41 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, 42 | num_workers=dataset_opt['workers'], pin_memory=True) 43 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_dataset))) 44 | 45 | 46 | model = UCTestSharedNetPro() 47 | device_id = torch.cuda.current_device() 48 | resume_state = torch.load(opt['path']['resume_state'], 49 | map_location=lambda storage, loc: storage.cuda(device_id)) 50 | 51 | model.load_state_dict(resume_state['state_dict']) 52 | model = model.cuda() 53 | model.eval() 54 | torch.cuda.empty_cache() 55 | 56 | 57 | avg_psnr = 0.0 58 | avg_ssim = 0.0 59 | avg_mae = 0.0 60 | avg_lpips = 0.0 61 | idx = 0 62 | model.eval() 63 | for test_data in tqdm(test_loader): 64 | with torch.no_grad(): 65 | o_img, u_img, root_name = test_data 66 | 67 | padding_number = 16 68 | 69 | o_img = F.pad(o_img, (padding_number, padding_number, padding_number, padding_number), mode='reflect') 70 | u_img = F.pad(u_img, (padding_number, padding_number, padding_number, padding_number), mode='reflect') 71 | o_img = o_img.cuda() 72 | u_img = u_img.cuda() 73 | 74 | common_part, upper_part, lower_part, fusion_part = model(o_img, u_img) 75 | 76 | o_img = o_img[:, :, padding_number:-padding_number, padding_number:-padding_number] 77 | u_img = u_img[:, :, padding_number:-padding_number, padding_number:-padding_number] 78 | common_part = common_part[:, :, padding_number:-padding_number, padding_number:-padding_number] 79 | upper_part = upper_part[:, :, padding_number:-padding_number, padding_number:-padding_number] 80 | lower_part = lower_part[:, :, padding_number:-padding_number, padding_number:-padding_number] 81 | fusion_part = fusion_part[:, :, padding_number:-padding_number, padding_number:-padding_number] 82 | print("ou img", o_img.shape, u_img.shape, fusion_part.shape, root_name) 83 | 84 | recover = fusion_part 85 | # Save ground truth 86 | img_dir = opt['path']['test_images'] 87 | 88 | common_img = ToPILImage()(common_part.clamp(0,1)[0]) 89 | c_img_path = os.path.join(img_dir, "{:s}_common.png".format(root_name[0])) 90 | common_img.save(c_img_path) 91 | 92 | upper_img = ToPILImage()(upper_part.clamp(0,1)[0]) 93 | upper_img_path = os.path.join(img_dir, 
"{:s}_upper.png".format(root_name[0])) 94 | upper_img.save(upper_img_path) 95 | 96 | lower_img = ToPILImage()(lower_part.clamp(0,1)[0]) 97 | lower_img_path = os.path.join(img_dir, "{:s}_lower.png".format(root_name[0])) 98 | lower_img.save(lower_img_path) 99 | 100 | over_img = ToPILImage()(o_img[0])#.convert('L') 101 | o_img_path = os.path.join(img_dir, "{:s}_over.png".format(root_name[0])) 102 | over_img.save(o_img_path) 103 | 104 | under_img = ToPILImage()(u_img[0])#.convert('L') 105 | u_img_path = os.path.join(img_dir, "{:s}_under.png".format(root_name[0])) 106 | under_img.save(u_img_path) 107 | 108 | recover_img = ToPILImage()(recover.clamp(0,1)[0])#.convert('L') 109 | save_img_path = os.path.join(img_dir, "{:s}_recover.png".format(root_name[0])) 110 | recover_img.save(save_img_path) 111 | # calculate psnr 112 | idx += 1 113 | 114 | avg_ssim += util.calculate_ssim(o_img, recover) + util.calculate_ssim(u_img, recover) 115 | logger.info("current {} over ssim is {:.4e} under ssim is {: .4e}".format(root_name[0] , 116 | util.calculate_ssim(o_img, recover), 117 | util.calculate_ssim(u_img, recover) 118 | )) 119 | 120 | 121 | avg_ssim = avg_ssim / idx 122 | # log 123 | logger.info('# Test #ssim: {:e}.'.format(avg_ssim)) 124 | logger_test = logging.getLogger('test') # validation logger 125 | logger_test.info('Test ssim: {:e}.'.format(avg_ssim)) 126 | logger.info('End of testing.') 127 | -------------------------------------------------------------------------------- /utils/Readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/build_code_arch.py: -------------------------------------------------------------------------------- 1 | import option.options as option 2 | from torch.utils.tensorboard import SummaryWriter 3 | import socket 4 | from datetime import datetime 5 | import torch 6 | from utils import util 7 | import logging 8 | 9 | 10 | 11 | def build_resume_state(train_args): 12 | opt = option.parse(train_args.opt, is_train=True) 13 | 14 | # loading resume state if exists 15 | if opt['path'].get('resume_state', None): 16 | # distributed resuming: all load into default GPU 17 | device_id = torch.cuda.current_device() 18 | resume_state = torch.load(opt['path']['resume_state'], 19 | map_location=lambda storage, loc: storage.cuda(device_id)) 20 | option.check_resume(opt, resume_state['epoch']) # check resume options 21 | else: 22 | resume_state = None 23 | 24 | 25 | if resume_state is None: 26 | util.mkdir_and_rename(opt['path']['experiments_root']) # rename experiment folder if exists 27 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 28 | and 'pretrain_model' not in key and 'resume' not in key)) 29 | 30 | return opt, resume_state 31 | 32 | def build_logger(opt): 33 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 34 | screen=True, tofile=True) 35 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 36 | screen=True, tofile=True) 37 | 38 | logger = logging.getLogger('base') 39 | logger.info(option.dict2str(opt)) 40 | 41 | # tensorboard logger 42 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 43 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 44 | CURRENT_DATETIME_HOSTNAME = '/' + current_time + '_' + socket.gethostname() 45 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'] + CURRENT_DATETIME_HOSTNAME) 46 | 
47 | torch.backends.cudnn.deterministic = True 48 | # convert to NoneDict, which returns None for missing keys 49 | opt = option.dict_to_nonedict(opt) 50 | 51 | seed = opt['train']['manual_seed'] 52 | util.set_random_seed(seed) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | return opt, logger, tb_logger 56 | 57 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | import torch 7 | import math 8 | # from kornia.losses import ssim 9 | from torchvision.transforms import ToTensor 10 | from skimage.metrics import structural_similarity as compare_ssim 11 | 12 | 13 | def mkdir(path): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | 18 | def mkdirs(paths): 19 | if isinstance(paths, str): 20 | mkdir(paths) 21 | else: 22 | for path in paths: 23 | mkdir(path) 24 | 25 | 26 | def get_timestamp(): 27 | return datetime.now().strftime('%y%m%d-%H%M%S') 28 | 29 | 30 | def mkdir_and_rename(path): 31 | if os.path.exists(path): 32 | new_name = path + '_archived_' + get_timestamp() 33 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 34 | logger = logging.getLogger('base') 35 | logger.info('Path already exists. Rename it to [{:s}]'.format(new_name)) 36 | os.rename(path, new_name) 37 | os.makedirs(path) 38 | 39 | 40 | def calculate_psnr(img1, img2): 41 | # img1 and img2 have range [0, 255] 42 | img1 = ToTensor()(img1) 43 | img2 = ToTensor()(img2) 44 | img1 = img1.squeeze().permute(1, 2, 0).cpu().numpy() 45 | img2 = img2.squeeze().permute(1, 2, 0).cpu().numpy() 46 | img1 = img1.astype(np.float64) 47 | img2 = img2.astype(np.float64) 48 | mse = np.mean((img1 - img2)**2) 49 | if mse == 0: 50 | return float('inf') 51 | return 20 * math.log10(1.0 / math.sqrt(mse)) 52 | 53 | 54 | def calculate_ssim(img1, img2): 55 | # ssim_value = ssim(img1, img2, 11, 'mean') 56 | # return 1 - ssim_value.item() 57 | img1 = img1.squeeze().permute(1, 2, 0).cpu().numpy() 58 | img2 = img2.squeeze().permute(1, 2, 0).cpu().numpy() 59 | ssim_value = compare_ssim(img1, img2, data_range=1, multichannel=True) 60 | return ssim_value 61 | 62 | 63 | 64 | 65 | def calculate_mae(img1, img2): 66 | mae = torch.mean((img1 - img2).abs(), dim=[2, 3, 1]) 67 | return mae.squeeze().item() 68 | 69 | 70 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 71 | '''set up logger''' 72 | lg = logging.getLogger(logger_name) 73 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 74 | datefmt='%y-%m-%d %H:%M:%S') 75 | lg.setLevel(level) 76 | if tofile: 77 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 78 | fh = logging.FileHandler(log_file, mode='w') 79 | fh.setFormatter(formatter) 80 | lg.addHandler(fh) 81 | if screen: 82 | sh = logging.StreamHandler() 83 | sh.setFormatter(formatter) 84 | lg.addHandler(sh) 85 | 86 | 87 | def set_random_seed(seed): 88 | random.seed(seed) 89 | np.random.seed(seed) 90 | torch.manual_seed(seed) 91 | torch.cuda.manual_seed(seed) 92 | torch.cuda.manual_seed_all(seed) 93 | torch.random.manual_seed(seed) 94 | 95 | 96 | def squeeze2d(input, factor): 97 | if factor == 1: 98 | return input 99 | 100 | B, C, H, W = input.size() 101 | 102 | assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" 103 | 104 | x = input.view(B, C, H // factor, 
factor, W // factor, factor) 105 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 106 | x = x.view(B, C * factor * factor, H // factor, W // factor) 107 | 108 | return x --------------------------------------------------------------------------------
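For reference, `squeeze2d` in `utils/util.py` is a space-to-depth rearrangement: with factor `f` it maps a `(B, C, H, W)` tensor to `(B, C*f*f, H//f, W//f)`. A minimal usage sketch (illustrative only, assuming `utils.util` is importable from the repository root):

```python
import torch
from utils.util import squeeze2d

x = torch.randn(1, 3, 8, 8)
y = squeeze2d(x, 2)
print(y.shape)   # torch.Size([1, 12, 4, 4])
```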