├── .gitignore
├── README.md
├── config.py
├── dataloader_3channel.py
├── dataloader_4channel.py
├── demosaic
│   └── demosaic.py
├── folder_ensemble.py
├── images
│   ├── qualitative.png
│   └── qualitative2.png
├── loss.py
├── models
│   ├── __init__.py
│   ├── model_3channel.py
│   ├── model_4channel.py
│   ├── modules_3channel.py
│   ├── modules_4channel.py
│   └── utils.py
├── proc_img.py
├── pytorch_ssim
│   ├── __init__.py
│   └── __pycache__
│       └── __init__.cpython-37.pyc
├── train_3channel.py
├── train_4channel.py
├── training_log
│   └── multi_loss_log.txt
├── utils.py
├── validation_3channel.py
├── validation_4channel.py
├── validation_final.py
└── validation_final_fullres.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | runs/
 2 | weights/
 3 | saved_checkpoints/
 4 | best_weight/
 5 | __pycache__/
 6 | results/
 7 | train_tl.py
 8 | train_tl_gan.py
 9 | .DS_Store
10 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # AWNet
 2 | 
 3 | This is the official PyTorch implementation of AWNet. Our team, MacAI, received a Runner-Up award in the AIM 2020 Learned Image ISP challenge (ECCVW 2020). The proposed solution achieves an excellent MOS while remaining competitive in numerical results. See more details in our [paper](https://arxiv.org/abs/2008.09228).
 4 | 
 5 | ## Abstract
 6 | With the revolutionary improvements made to smartphone performance over the last decade, mobile photography has become one of the most common practices among smartphone users. However, due to the limited size of the camera sensors on phones, the photographed image is still visually distinct from one taken by a digital single-lens reflex (DSLR) camera. One way to narrow this performance gap is to redesign the camera image signal processor (ISP) to improve image quality. Owing to the rapid rise of deep learning, recent works resort to deep convolutional neural networks (CNNs) to develop a sophisticated data-driven ISP that directly maps the phone-captured image to the DSLR-captured one. In this paper, we introduce a novel network that utilizes the attention mechanism and wavelet transform, dubbed AWNet, to tackle this learnable image ISP problem. The added wavelet transform enables our method to restore favorable image details from RAW information and to achieve a larger receptive field while remaining highly efficient in terms of computational cost. The global context block is adopted in our method to learn the non-local color mapping for the generation of appealing RGB images. More importantly, this block alleviates the influence of the image misalignment present in the provided dataset. Experimental results demonstrate the advantages of our design in both qualitative and quantitative measurements.
 7 | 
 8 | ## Presentation Video
 9 | [![Watch the video](https://img.youtube.com/vi/HlrzVFMUwCQ/0.jpg)](https://youtu.be/HlrzVFMUwCQ)
10 | 
11 | ## Pretrained Models & Dataset
12 | 1. Download the [demosaiced_model](https://drive.google.com/file/d/1uhohG6cYkM_-W4dGLl8yGlo85UMF6KEK/view?usp=sharing) and the [Raw_model](https://drive.google.com/file/d/1jBwEm7_zbU55qOlGAVuOAQ8BIwx2g7Fw/view?usp=sharing) and place them in the folder ```./best_weight```.
13 | 2. Download the ZRR dataset from [here](https://competitions.codalab.org/competitions/24718).
14 | 
15 | 
16 | If you want to reproduce our result for the AIM 2020 challenge, please follow the steps in the ```old_version``` branch.
The model in the ```master``` branch has been modified to match the architecture described in our paper.
17 | 
18 | ## Training
19 | 1. Generate pseudo-demosaiced images for the 3-channel-input model:
20 | ```
21 | cd demosaic
22 | python demosaic.py --data <path-to-raw-images> --save <output-folder>
23 | ```
24 | Then, move the resulting folder of demosaiced images under your root dataset directory. Please make sure your dataset structure is the same as what we show in the Training/Validation Dataset Structure section below.
25 | 2. Change the configuration in ```config.py``` accordingly and run
26 | ```python train_3channel.py``` or ```python train_4channel.py```
27 | 
28 | ## Testing
29 | 1. Generate pseudo-demosaiced images for the 3-channel-input model:
30 | ```
31 | cd demosaic
32 | python demosaic.py --data <path-to-raw-images> --save <output-folder>
33 | ```
34 | Then, move the resulting folder of demosaiced images under your root dataset directory.
35 | Please make sure your dataset structure is the same as what we show in the Training/Validation Dataset Structure section below.
36 | 2. To reproduce our final results on the test server, run ```python validation_final.py```.
37 | 3. To reproduce our full-resolution results, run ```python validation_final_fullres.py```.
38 | 
39 | ## Qualitative Results
40 | Full resolution (see ```images/qualitative.png```):
41 | 
42 | 
43 | Comparison with other state-of-the-art methods (see ```images/qualitative2.png```):
44 | 
45 | 46 |
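## Training/Validation Dataset Structure
The layout below is a sketch inferred from the paths used in ```dataloader_3channel.py``` and ```dataloader_4channel.py``` (the ```train_vis```/```test_vis``` folders hold the pseudo-demosaiced images produced by ```demosaic/demosaic.py```; folder names must match exactly):
```
Dataset/
├── train/
│   ├── huawei_raw/   # raw Bayer .png files (input to the 4-channel model)
│   ├── train_vis/    # pseudo-demosaiced .png files (input to the 3-channel model)
│   └── canon/        # DSLR ground-truth .jpg files
└── test/
    ├── huawei_raw/
    ├── test_vis/
    └── canon/
```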
47 | 
48 | ## Acknowledgement
49 | We thank the authors of [MWCNN](https://github.com/lpj0/MWCNN.git), [GCNet](https://github.com/xvjiarui/GCNet.git), and [Pytorch_SSIM](https://github.com/Po-Hsun-Su/pytorch-ssim). Part of our code is built upon their modules.
50 | 
51 | ## Citation
52 | If our work helps your research, please consider citing our paper:
53 | ```
54 | @article{dai2020awnet,
55 |   title={AWNet: Attentive Wavelet Network for Image ISP},
56 |   author={Dai, Linhui and Liu, Xiaohong and Li, Chengqi and Chen, Jun},
57 |   journal={arXiv preprint arXiv:2008.09228},
58 |   year={2020}
59 | }
60 | ```
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | 
 4 | class trainConfig:
 5 |     learning_rate = [1e-4, 5e-5, 1e-5, 5e-6, 1e-6]
 6 |     print_loss = False
 7 | 
 8 |     batch_size = 3
 9 |     epoch = 50
10 |     pretrain = True
11 | 
12 |     data_dir = '/home/charliedai/aim2020/Dataset'
13 | 
14 |     checkpoints = './saved_checkpoints'
15 |     if not os.path.exists(checkpoints):
16 |         os.makedirs(checkpoints)
17 |     save_best = './best_weight'
18 |     if not os.path.exists(save_best):
19 |         os.makedirs(save_best)
--------------------------------------------------------------------------------
/dataloader_3channel.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import Dataset
 2 | from torchvision import transforms
 3 | import numpy as np
 4 | import imageio
 5 | import PIL.Image as Image
 6 | import torch
 7 | import os
 8 | import random
 9 | from utils import fun_ensemble
10 | 
11 | to_tensor = transforms.Compose([transforms.ToTensor()])
12 | 
13 | 
14 | def extract_bayer_channels(raw):
15 |     # Split the Bayer mosaic into its four sub-sampled color planes
16 | 
17 |     ch_B = raw[1::2, 1::2]
18 |     ch_Gb = raw[0::2, 1::2]
19 |     ch_R = raw[0::2, 0::2]
20 |     ch_Gr = raw[1::2, 0::2]
21 | 
22 |     RAW_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr))
23 |     RAW_norm = RAW_combined.astype(np.float32) / (4 * 255)
24 | 
25 |     return RAW_norm
26 | 
27 | 
28 | class LoadData(Dataset):
29 |     def __init__(self,
30 |                  dataset_dir,
31 |                  dataset_size,
32 |                  dslr_scale,
33 |                  test=False,
34 |                  if_rotate=True,
35 |                  if_filp=True,  # currently unused: the flips below are gated by `if_rotate`
36 |                  is_ensemble=False,
37 |                  is_rescale=False):
38 |         self.is_ensemble = is_ensemble
39 |         self.is_test = test
40 |         self.if_rotate = if_rotate
41 |         self.if_filp = if_filp
42 |         self.is_rescale = is_rescale
43 |         if self.is_test:
44 |             self.raw_dir = os.path.join(dataset_dir, 'test', 'test_vis')
45 |             self.dslr_dir = os.path.join(dataset_dir, 'test', 'canon')
46 |             self.dataset_size = dataset_size
47 |         else:
48 |             self.raw_dir = os.path.join(dataset_dir, 'train', 'train_vis')
49 |             self.dslr_dir = os.path.join(dataset_dir, 'train', 'canon')
50 | 
51 |         self.dataset_size = dataset_size
52 |         self.scale = dslr_scale  # target side length used by `self.rescale`
53 | 
54 |         self.tf1 = transforms.Compose([
55 |             transforms.RandomVerticalFlip(p=1),
56 |         ])
57 |         self.tf2 = transforms.Compose([
58 |             transforms.RandomHorizontalFlip(p=1),
59 |         ])
60 | 
61 |         self.rescale = transforms.Compose([
62 |             transforms.Resize((self.scale, self.scale)),
63 |         ])
64 | 
65 |         self.toTensor = transforms.Compose([transforms.ToTensor()])
66 |         self.rotate = transforms.Compose(
67 |             [transforms.RandomRotation(degrees=(-45, 45))])  # defined but never applied
68 | 
69 |     def __len__(self):
70 |         return self.dataset_size
71 | 
72 |     def __getitem__(self, idx):
73 |         raw_image = Image.open(os.path.join(self.raw_dir, str(idx) + ".png"))
74 |         dslr_image = Image.open(os.path.join(self.dslr_dir,
                                             str(idx) + ".jpg"))
 75 | 
 76 |         if not self.is_test:
 77 |             if self.if_rotate:
 78 |                 p = random.randint(0, 2)  # 0: vertical flip, 1: horizontal flip, 2: no augmentation
 79 |                 if p == 0:
 80 |                     raw_image = self.tf1(raw_image)
 81 |                     dslr_image = self.tf1(dslr_image)
 82 |                 elif p == 1:
 83 |                     raw_image = self.tf2(raw_image)
 84 |                     dslr_image = self.tf2(dslr_image)
 85 | 
 86 |             if self.is_rescale:
 87 |                 raw_image = self.rescale(raw_image)
 88 |                 dslr_image = self.rescale(dslr_image)
 89 | 
 90 |         if self.is_ensemble:
 91 |             raw_image = fun_ensemble(raw_image)
 92 |             raw_image = [self.toTensor(x) for x in raw_image]
 93 |         else:
 94 |             raw_image = self.toTensor(raw_image)
 95 | 
 96 |         dslr_image = self.toTensor(dslr_image)
 97 | 
 98 |         return raw_image, dslr_image, str(idx)
 99 | 
100 | 
101 | class LoadData_real(Dataset):
102 |     def __init__(self, dataset_dir, is_ensemble=False):
103 |         self.is_ensemble = is_ensemble
104 | 
105 |         self.raw_dir = dataset_dir
106 | 
107 |         # self.dataset_size = 670
108 |         self.dataset_size = 42  # hard-coded test-set size
109 | 
110 |         self.toTensor = transforms.Compose([transforms.ToTensor()])
111 | 
112 |     def __len__(self):
113 |         return self.dataset_size
114 | 
115 |     def __getitem__(self, idx):
116 |         raw_image = Image.open(
117 |             os.path.join(self.raw_dir,
118 |                          str(idx + 1) + ".png"))  # these files are indexed from 1
119 | 
120 |         if self.is_ensemble:
121 |             raw_image = fun_ensemble(raw_image)
122 |             raw_image = [self.toTensor(x) for x in raw_image]
123 |         else:
124 |             raw_image = self.toTensor(raw_image)
125 | 
126 |         return raw_image, str(idx + 1)
--------------------------------------------------------------------------------
/dataloader_4channel.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import Dataset
 2 | from torchvision import transforms
 3 | import numpy as np
 4 | import imageio
 5 | import PIL.Image as Image
 6 | import torch
 7 | import os
 8 | 
 9 | to_tensor = transforms.Compose([transforms.ToTensor()])
10 | 
11 | 
12 | def extract_bayer_channels(raw):
13 |     # Split the Bayer mosaic into its four sub-sampled color planes
14 | 
15 |     ch_B = raw[1::2, 1::2]
16 |     ch_Gb = raw[0::2, 1::2]
17 |     ch_R = raw[0::2, 0::2]
18 |     ch_Gr = raw[1::2, 0::2]
19 | 
20 |     RAW_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr))
21 |     RAW_norm = RAW_combined.astype(np.float32) / (4 * 255)
22 | 
23 |     return RAW_norm
24 | 
25 | 
26 | class LoadData(Dataset):
27 |     def __init__(self, dataset_dir, dataset_size, dslr_scale, test=False):
28 | 
29 |         if test:
30 |             self.raw_dir = os.path.join(dataset_dir, 'test', 'huawei_raw')
31 |             self.dslr_dir = os.path.join(dataset_dir, 'test', 'canon')
32 |             self.dataset_size = dataset_size
33 |         else:
34 |             self.raw_dir = os.path.join(dataset_dir, 'train', 'huawei_raw')
35 |             self.dslr_dir = os.path.join(dataset_dir, 'train', 'canon')
36 | 
37 |         self.dataset_size = dataset_size
38 |         self.scale = dslr_scale  # integer downscale divisor for the DSLR target
39 |         self.test = test
40 | 
41 |     def __len__(self):
42 |         return self.dataset_size
43 | 
44 |     def __getitem__(self, idx):
45 | 
46 |         raw_image = np.asarray(
47 |             imageio.imread(os.path.join(self.raw_dir,
48 |                                         str(idx) + '.png')))
49 |         raw_image = extract_bayer_channels(raw_image)
50 |         raw_image = torch.from_numpy(raw_image.transpose((2, 0, 1)))
51 | 
52 |         dslr_image = imageio.imread(
53 |             os.path.join(self.dslr_dir,
54 |                          str(idx) + ".jpg"))
55 |         dslr_image = np.asarray(dslr_image)
56 |         dslr_img_shape = dslr_image.shape
57 |         dslr_image = np.float32(
58 |             np.array(
59 |                 Image.fromarray(dslr_image).resize(
60 |                     (dslr_img_shape[1] // self.scale,
61 |                      dslr_img_shape[0] // self.scale)))) / 255.0  # PIL's resize takes (width, height)
62 |         dslr_image = torch.from_numpy(dslr_image.transpose((2, 0, 1)))
63 |         return raw_image, dslr_image, str(idx)
64 | 
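# --- Usage sketch (an assumption; the train/validation scripts that drive these
# --- loaders are not included in this dump). Note the differing semantics of
# --- `dslr_scale`: dataloader_3channel feeds it to transforms.Resize as an
# --- absolute side length, while this file uses it as a downscale divisor.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    N_TRAIN = 46839  # size of your training split (this value is an assumption)
    train_set = LoadData('/path/to/Dataset', dataset_size=N_TRAIN, dslr_scale=2)
    loader = DataLoader(train_set, batch_size=3, shuffle=True, num_workers=4)
    for raw, dslr, idx in loader:
        # raw: [B, 4, H/2, W/2] packed Bayer planes; dslr: [B, 3, h, w] target
        break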
--------------------------------------------------------------------------------
/demosaic/demosaic.py:
--------------------------------------------------------------------------------
 1 | import imageio
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | import argparse
 6 | 
 7 | 
 8 | def extract_bayer_channels(raw):
 9 | 
10 |     # Split the Bayer mosaic into its four color planes, upsample each back to
11 |     # full resolution, and average the two green planes into a single channel.
12 | 
13 |     ch_B = raw[1::2, 1::2]
14 |     ch_Gb = raw[0::2, 1::2]
15 |     ch_R = raw[0::2, 0::2]
16 |     ch_Gr = raw[1::2, 0::2]
17 | 
18 |     ch_B = cv2.resize(ch_B, (ch_B.shape[1] * 2, ch_B.shape[0] * 2))
19 |     ch_R = cv2.resize(ch_R, (ch_R.shape[1] * 2, ch_R.shape[0] * 2))
20 |     ch_Gb = cv2.resize(ch_Gb, (ch_Gb.shape[1] * 2, ch_Gb.shape[0] * 2))
21 |     ch_Gr = cv2.resize(ch_Gr, (ch_Gr.shape[1] * 2, ch_Gr.shape[0] * 2))
22 | 
23 |     ch_G = ch_Gb / 2 + ch_Gr / 2
24 |     RAW_combined = np.dstack((ch_B, ch_G, ch_R))
25 |     RAW_norm = RAW_combined.astype(np.float32) / (3 * 255)
26 | 
27 |     return RAW_norm
28 | 
29 | 
30 | def process_image(path, save, file):
31 |     raw_image = np.asarray(imageio.imread(path))
32 |     demosaic = extract_bayer_channels(raw_image)
33 |     save_path = './{}/{}'.format(save, file)
34 |     print(save_path)
35 |     cv2.imwrite(save_path, demosaic * 255)
36 | 
37 | 
38 | def batch_process(root, save):
39 |     if not os.path.exists(save):
40 |         os.makedirs(save)
41 |     for file in os.listdir(root):
42 |         path = os.path.join(root, file)
43 |         print(path)
44 |         process_image(path, save, file)
45 | 
46 | 
47 | if __name__ == '__main__':
48 | 
49 |     parser = argparse.ArgumentParser()
50 |     parser.add_argument(
51 |         '--data',
52 |         type=str,
53 |         default='/home/charliedai/aim2020/Dataset/train/huawei_raw/',
54 |         help='data directory of your raw images')
55 |     parser.add_argument(
56 |         '--save', type=str, default='./AIM2020_ISP_fullres_test_raw_pseudo_demosaicing', help='folder in which to save the demosaiced images')
57 |     args = parser.parse_args()
58 |     batch_process(args.data, args.save)
--------------------------------------------------------------------------------
/folder_ensemble.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import cv2
 3 | 
 4 | # Average the 3-channel and 4-channel model outputs pixel-wise to form the final ensemble.
 5 | path1 = './result_fullres_4channel/'
 6 | path2 = './result_fullres_3channel/'
 7 | save_path = './final_result_fullres'
 8 | if not os.path.exists(save_path):
 9 |     os.makedirs(save_path)
10 | for file in os.listdir(path1):
11 |     image_path1 = os.path.join(path1, file)
12 |     image_path2 = os.path.join(path2, file)
13 |     img1 = cv2.imread(image_path1)
14 |     img2 = cv2.imread(image_path2)
15 |     out = img1 / 2 + img2 / 2
16 |     save_path1 = os.path.join(save_path, file)
17 |     cv2.imwrite(save_path1, out)
--------------------------------------------------------------------------------
/images/qualitative.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Charlie0215/AWNet-Attentive-Wavelet-Network-for-Image-ISP/a1088bcef3a98fa6ba9e8febb2a40dd9b5bf3e20/images/qualitative.png
--------------------------------------------------------------------------------
/images/qualitative2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Charlie0215/AWNet-Attentive-Wavelet-Network-for-Image-ISP/a1088bcef3a98fa6ba9e8febb2a40dd9b5bf3e20/images/qualitative2.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch import nn
 3 | from torchvision.models.vgg import vgg16
 4 | import torch.nn.functional as F
 5 | import pytorch_ssim
 6 | 
 7 | 
 8 | class Loss(nn.Module):
 9 |     def __init__(self):
10 |         super(Loss, self).__init__()
11 |         vgg = vgg16(pretrained=True)
12 |         self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
13 |         for param in self.loss_network.parameters():
14 |             param.requires_grad = False
15 |         self.mse_loss = nn.MSELoss()
16 |         self.ssim_loss = pytorch_ssim.SSIM()
17 | 
18 |     def forward(self, y, target, fake_label=None):
19 |         total_loss, losses = self.image_restoration(y, target)
20 |         if fake_label is not None:  # explicit None check; truth-testing a tensor is ambiguous
21 |             total_loss += fake_label
22 |         return total_loss, losses
23 | 
24 |     def image_restoration(self, pred, target):
25 |         perceptual_loss = self.perceptual_loss(pred, target)
26 |         l1 = F.l1_loss(pred, target)
27 |         ssim_loss = 1 - self.ssim_loss(pred, target)
28 |         del pred, target
29 |         Loss = perceptual_loss + l1 + ssim_loss
30 | 
31 |         return Loss, (perceptual_loss, l1, ssim_loss)
32 | 
33 |     def perceptual_loss(self, out_images, target_images):
34 | 
35 |         loss = self.mse_loss(
36 |             self.loss_network(out_images),
37 |             self.loss_network(target_images)) / 3
38 |         return loss
39 | 
40 |     def CharbonnierLoss(self, y, target, eps=1e-6):
41 |         diff = y - target
42 |         loss = torch.mean(torch.sqrt(diff * diff + eps))
43 |         return loss
44 | 
45 |     def tv_loss(self, x, TVLoss_weight=1):
46 |         def _tensor_size(t):
47 |             return t.size()[1] * t.size()[2] * t.size()[3]
48 | 
49 |         batch_size = x.size()[0]
50 |         h_x = x.size()[2]
51 |         w_x = x.size()[3]
52 |         count_h = _tensor_size(x[:, :, 1:, :])
53 |         count_w = _tensor_size(x[:, :, :, 1:])
54 |         h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
55 |         w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
56 |         return TVLoss_weight * 2 * (
57 |             h_tv / count_h + w_tv / count_w) / batch_size
58 | 
59 | 
60 | class ms_Loss(Loss):
61 |     def __init__(self):
62 |         super(ms_Loss, self).__init__()
63 | 
64 |     def forward(self, y, target, texture_img=None):
65 |         loss = 0
66 |         total_l1 = 0
67 |         total_perceptual = 0
68 |         total_ssim = 0
69 |         # scale 1
70 |         if texture_img is not None:  # `if texture_img:` would raise on a multi-element tensor
71 |             l1 = self.CharbonnierLoss(texture_img, target) * 0.25
72 |             total_l1 += l1
73 |             loss += l1
74 |         for i in range(len(y)):
75 |             if i == 0:
76 |                 perceptual_loss = self.perceptual_loss(y[i], target)
77 |                 ssim_loss = 1 - self.ssim_loss(y[i], target)
78 |                 l1 = self.CharbonnierLoss(y[i], target)
79 |                 loss += 0.05 * ssim_loss + l1 + 0.25 * perceptual_loss
80 |                 total_l1 += l1
81 |                 total_perceptual += perceptual_loss
82 |                 total_ssim += ssim_loss
83 |             elif i == 1 or i == 2:
84 |                 h, w = y[i].size(2), y[i].size(3)
85 |                 target = F.interpolate(target, size=(h, w))
86 |                 perceptual_loss = self.perceptual_loss(y[i], target)
87 |                 l1 = self.CharbonnierLoss(y[i], target)
88 |                 total_l1 += l1
89 |                 loss += perceptual_loss + l1 * 0.25
90 |                 total_perceptual += perceptual_loss
91 |             else:
92 |                 h, w = y[i].size(2), y[i].size(3)
93 |                 target = F.interpolate(target, size=(h, w))
94 |                 l1 = self.CharbonnierLoss(y[i], target)
95 |                 total_l1 += l1
96 |                 loss += l1
97 | 
98 |         return loss, (total_perceptual, total_l1, total_ssim)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Charlie0215/AWNet-Attentive-Wavelet-Network-for-Image-ISP/a1088bcef3a98fa6ba9e8febb2a40dd9b5bf3e20/models/__init__.py
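# --- Usage sketch (an assumption; the repo's train_3channel.py is not shown in
# --- this dump). It wires AWNet's tuple of multi-scale predictions into ms_Loss:
# --- y[0] is the full-resolution output and y[1:] are coarser scales, with the
# --- target re-interpolated to each scale inside the loss.
import torch
from loss import ms_Loss
from models.model_3channel import AWNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = AWNet(3, 3).to(device)
criterion = ms_Loss().to(device)

raw = torch.rand(1, 3, 224, 224, device=device)     # pseudo-demosaiced input patch
target = torch.rand(1, 3, 224, 224, device=device)  # DSLR ground-truth patch
outputs, _latent = net(raw)                         # (out, x2_out, x3_out, x4_out, x5_out)
loss, (perceptual, l1, ssim) = criterion(outputs, target)
loss.backward()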
-------------------------------------------------------------------------------- /models/model_3channel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.utils import DWT, IWT 5 | from models.modules_3channel import shortcutblock, GCIWTResUp, GCWTResDown, GCRDB, ContextBlock2d, SE_net, PSPModule 6 | import functools 7 | 8 | 9 | class AWNet(nn.Module): 10 | def __init__(self, in_channels, out_channels, block=[2,2,2,3,4]): 11 | super().__init__() 12 | 13 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1) 14 | 15 | #layer1 16 | _layer_1_dw = [] 17 | for i in range(block[0]): 18 | _layer_1_dw.append(GCRDB(64, ContextBlock2d)) 19 | _layer_1_dw.append(GCWTResDown(64, ContextBlock2d, norm_layer=None)) 20 | self.layer1 = nn.Sequential(*_layer_1_dw) 21 | 22 | #layer 2 23 | _layer_2_dw = [] 24 | for i in range(block[1]): 25 | _layer_2_dw.append(GCRDB(128, ContextBlock2d)) 26 | _layer_2_dw.append(GCWTResDown(128, ContextBlock2d, norm_layer=None)) 27 | self.layer2 = nn.Sequential(*_layer_2_dw) 28 | 29 | #layer 3 30 | _layer_3_dw = [] 31 | for i in range(block[2]): 32 | _layer_3_dw.append(GCRDB(256, ContextBlock2d)) 33 | _layer_3_dw.append(GCWTResDown(256, ContextBlock2d, norm_layer=None)) 34 | self.layer3 = nn.Sequential(*_layer_3_dw) 35 | 36 | #layer 4 37 | _layer_4_dw = [] 38 | for i in range(block[3]): 39 | _layer_4_dw.append(GCRDB(512, ContextBlock2d)) 40 | _layer_4_dw.append(GCWTResDown(512, ContextBlock2d, norm_layer=None)) 41 | self.layer4 = nn.Sequential(*_layer_4_dw) 42 | 43 | #layer 5 44 | _layer_5_dw = [] 45 | for i in range(block[4]): 46 | _layer_5_dw.append(GCRDB(1024, ContextBlock2d)) 47 | self.layer5 = nn.Sequential(*_layer_5_dw) 48 | 49 | #upsample4 50 | self.layer4_up = GCIWTResUp(1024, ContextBlock2d) 51 | 52 | #upsample3 53 | self.layer3_up = GCIWTResUp(512, ContextBlock2d) 54 | 55 | #upsample2 56 | self.layer2_up = GCIWTResUp(256, ContextBlock2d) 57 | 58 | #upsample1 59 | self.layer1_up = GCIWTResUp(128, ContextBlock2d) 60 | 61 | self.sc_x1 = shortcutblock(64) 62 | self.sc_x2 = shortcutblock(128) 63 | self.sc_x3 = shortcutblock(256) 64 | self.sc_x4 = shortcutblock(512) 65 | 66 | self.scale_5 = nn.Conv2d(1024, out_channels, kernel_size=3, padding=1) 67 | self.scale_4 = nn.Conv2d(512, out_channels, kernel_size=3, padding=1) 68 | self.scale_3 = nn.Conv2d(256, out_channels, kernel_size=3, padding=1) 69 | self.scale_2 = nn.Conv2d(128, out_channels, kernel_size=3, padding=1) 70 | 71 | self.final_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1) 72 | 73 | self.se1 = SE_net(64, 64) 74 | self.se2 = SE_net(128, 128) 75 | self.se3 = SE_net(256, 256) 76 | self.se4 = SE_net(512, 512) 77 | self.se5 = SE_net(1024, 1024) 78 | 79 | self.enhance = PSPModule(features=64, out_features=64, sizes=(1, 2, 3, 6)) 80 | 81 | def forward(self, x, target=None, teacher_latent=None): 82 | 83 | x1 = self.conv1(x) 84 | 85 | x2, x2_dwt = self.layer1(self.se1(x1)) 86 | x3, x3_dwt = self.layer2(self.se2(x2)) 87 | x4, x4_dwt = self.layer3(self.se3(x3)) 88 | x5, x5_dwt = self.layer4(self.se4(x4)) 89 | x5_latent = self.layer5(self.se5(x5)) 90 | 91 | x5_out = self.scale_5(x5_latent) 92 | x5_out = F.sigmoid(x5_out) 93 | x4_up = self.layer4_up(x5_latent, x5_dwt) + self.sc_x4(x4) 94 | x4_out = self.scale_4(x4_up) 95 | x4_out = F.sigmoid(x4_out) 96 | x3_up = self.layer3_up(x4_up, x4_dwt) + self.sc_x3(x3) 97 | x3_out = self.scale_3(x3_up) 98 | x3_out = 
F.sigmoid(x3_out) 99 | x2_up = self.layer2_up(x3_up, x3_dwt) + self.sc_x2(x2) 100 | x2_out = self.scale_2(x2_up) 101 | x2_out = F.sigmoid(x2_out) 102 | x1_up = self.layer1_up(x2_up, x2_dwt) + self.sc_x1(x1) 103 | x1_up = self.enhance(x1_up) 104 | out = self.final_conv(x1_up) 105 | out = F.sigmoid(out) 106 | 107 | return (out, x2_out, x3_out, x4_out, x5_out) , x5_latent -------------------------------------------------------------------------------- /models/model_4channel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.utils import DWT, IWT 5 | from models.modules_4channel import shortcutblock, GCIWTResUp, GCWTResDown, GCRDB, ContextBlock2d, SE_net, PSPModule, \ 6 | last_upsample 7 | import functools 8 | 9 | 10 | class AWNet(nn.Module): 11 | def __init__(self, in_channels, out_channels, block=[2, 2, 2, 3, 4]): 12 | super().__init__() 13 | 14 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1) 15 | # layer1 16 | _layer_1_dw = [] 17 | for i in range(block[0]): 18 | _layer_1_dw.append(GCRDB(64, ContextBlock2d)) 19 | _layer_1_dw.append(GCWTResDown(64, ContextBlock2d, norm_layer=None)) 20 | self.layer1 = nn.Sequential(*_layer_1_dw) 21 | 22 | # layer 2 23 | _layer_2_dw = [] 24 | for i in range(block[1]): 25 | _layer_2_dw.append(GCRDB(128, ContextBlock2d)) 26 | _layer_2_dw.append(GCWTResDown(128, ContextBlock2d, norm_layer=None)) 27 | self.layer2 = nn.Sequential(*_layer_2_dw) 28 | 29 | # layer 3 30 | _layer_3_dw = [] 31 | for i in range(block[2]): 32 | _layer_3_dw.append(GCRDB(256, ContextBlock2d)) 33 | _layer_3_dw.append(GCWTResDown(256, ContextBlock2d, norm_layer=None)) 34 | self.layer3 = nn.Sequential(*_layer_3_dw) 35 | 36 | # layer 4 37 | _layer_4_dw = [] 38 | for i in range(block[3]): 39 | _layer_4_dw.append(GCRDB(512, ContextBlock2d)) 40 | _layer_4_dw.append(GCWTResDown(512, ContextBlock2d, norm_layer=None)) 41 | self.layer4 = nn.Sequential(*_layer_4_dw) 42 | 43 | # layer 5 44 | _layer_5_dw = [] 45 | for i in range(block[4]): 46 | _layer_5_dw.append(GCRDB(1024, ContextBlock2d)) 47 | self.layer5 = nn.Sequential(*_layer_5_dw) 48 | 49 | # upsample4 50 | self.layer4_up = GCIWTResUp(2048, ContextBlock2d) 51 | 52 | # upsample3 53 | self.layer3_up = GCIWTResUp(1024, ContextBlock2d) 54 | 55 | # upsample2 56 | self.layer2_up = GCIWTResUp(512, ContextBlock2d) 57 | 58 | # upsample1 59 | self.layer1_up = GCIWTResUp(256, ContextBlock2d) 60 | 61 | self.sc_x1 = shortcutblock(64, 64) 62 | self.sc_x2 = shortcutblock(128, 128) 63 | self.sc_x3 = shortcutblock(256, 256) 64 | self.sc_x4 = shortcutblock(512, 512) 65 | 66 | self.scale_5 = nn.Conv2d(1024, out_channels, kernel_size=3, padding=1) 67 | self.scale_4 = nn.Conv2d(512, out_channels, kernel_size=3, padding=1) 68 | self.scale_3 = nn.Conv2d(256, out_channels, kernel_size=3, padding=1) 69 | self.scale_2 = nn.Conv2d(128, out_channels, kernel_size=3, padding=1) 70 | self.scale_1 = nn.Conv2d(64, out_channels, kernel_size=3, padding=1) 71 | 72 | self.se1 = SE_net(64, 64) 73 | self.se2 = SE_net(128, 128) 74 | self.se3 = SE_net(256, 256) 75 | self.se4 = SE_net(512, 512) 76 | self.se5 = SE_net(1024, 1024) 77 | 78 | self.last = last_upsample() 79 | 80 | def forward(self, x, target=None, teacher_latent=None): 81 | 82 | x1 = self.conv1(x) 83 | 84 | x2, x2_dwt = self.layer1(self.se1(x1)) 85 | x3, x3_dwt = self.layer2(self.se2(x2)) 86 | x4, x4_dwt = self.layer3(self.se3(x3)) 87 | x5, x5_dwt = 
self.layer4(self.se4(x4)) 88 | x5_latent = self.layer5(self.se5(x5)) 89 | 90 | x5_out = self.scale_5(x5_latent) 91 | x5_out = F.sigmoid(x5_out) 92 | x4_up = self.layer4_up(x5_latent, x5_dwt) + self.sc_x4(x4) 93 | x4_out = self.scale_4(x4_up) 94 | x4_out = F.sigmoid(x4_out) 95 | x3_up = self.layer3_up(x4_up, x4_dwt) + self.sc_x3(x3) 96 | x3_out = self.scale_3(x3_up) 97 | x3_out = F.sigmoid(x3_out) 98 | x2_up = self.layer2_up(x3_up, x3_dwt) + self.sc_x2(x2) 99 | x2_out = self.scale_2(x2_up) 100 | x2_out = F.sigmoid(x2_out) 101 | x1_up = self.layer1_up(x2_up, x2_dwt) + self.sc_x1(x1) 102 | x1_out = self.scale_1(x1_up) 103 | x1_out = F.sigmoid(x1_out) 104 | out = self.last(x1_up) 105 | return (out, x1_out, x2_out, x3_out, x4_out, x5_out), x5_latent -------------------------------------------------------------------------------- /models/modules_3channel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.utils import DWT, IWT 5 | 6 | 7 | class GCRDB(nn.Module): 8 | def __init__(self, in_channels, att_block, num_dense_layer=6, growth_rate=16): 9 | super(GCRDB, self).__init__() 10 | _in_channels = in_channels 11 | modules = [] 12 | 13 | for i in range(num_dense_layer): 14 | modules.append(MakeDense(_in_channels, growth_rate)) 15 | _in_channels += growth_rate 16 | 17 | self.residual_dense_layers = nn.Sequential(*modules) 18 | self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0) 19 | self.final_att = att_block(inplanes=in_channels, planes=in_channels) 20 | 21 | def forward(self, x): 22 | out_rdb = self.residual_dense_layers(x) 23 | out_rdb = self.conv_1x1(out_rdb) 24 | out_rdb = self.final_att(out_rdb) 25 | out = out_rdb + x 26 | return out 27 | 28 | 29 | class MakeDense(nn.Module): 30 | def __init__(self, in_channels, growth_rate, kernel_size=3): 31 | super(MakeDense, self).__init__() 32 | self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) 33 | self.norm_layer = nn.BatchNorm2d(growth_rate) 34 | 35 | def forward(self, x): 36 | out = F.relu(self.conv(x)) 37 | out = self.norm_layer(out) 38 | out = torch.cat((x, out), 1) 39 | return out 40 | 41 | 42 | class SE_net(nn.Module): 43 | def __init__(self, in_channels, out_channels, reduction=4, attention=True): 44 | super().__init__() 45 | self.attention = attention 46 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 47 | self.conv_in = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) 48 | self.conv_mid = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, padding=0) 49 | self.conv_out = nn.Conv2d(in_channels // reduction, out_channels, kernel_size=1, padding=0) 50 | 51 | self.x_red = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 52 | 53 | def forward(self, x): 54 | 55 | if self.attention is True: 56 | y = self.avg_pool(x) 57 | y = F.relu(self.conv_in(y)) 58 | y = F.relu(self.conv_mid(y)) 59 | y = torch.sigmoid(self.conv_out(y)) 60 | x = self.x_red(x) 61 | return x * y 62 | else: 63 | return x 64 | 65 | 66 | class ContextBlock2d(nn.Module): 67 | 68 | def __init__(self, inplanes=9, planes=32, pool='att', fusions=['channel_add'], ratio=4): 69 | super(ContextBlock2d, self).__init__() 70 | assert pool in ['avg', 'att'] 71 | assert all([f in ['channel_add', 'channel_mul'] for f in fusions]) 72 | assert len(fusions) > 0, 'at least one fusion should be used' 73 | self.inplanes = inplanes 74 | self.planes = planes 75 | 
self.pool = pool 76 | self.fusions = fusions 77 | if 'att' in pool: 78 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) # context Modeling 79 | self.softmax = nn.Softmax(dim=2) 80 | else: 81 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 82 | if 'channel_add' in fusions: 83 | self.channel_add_conv = nn.Sequential( 84 | nn.Conv2d(self.inplanes, self.planes // ratio, kernel_size=1), 85 | nn.LayerNorm([self.planes // ratio, 1, 1]), 86 | nn.PReLU(), 87 | nn.Conv2d(self.planes // ratio, self.inplanes, kernel_size=1) 88 | ) 89 | else: 90 | self.channel_add_conv = None 91 | if 'channel_mul' in fusions: 92 | self.channel_mul_conv = nn.Sequential( 93 | nn.Conv2d(self.inplanes, self.planes // ratio, kernel_size=1), 94 | nn.LayerNorm([self.planes // ratio, 1, 1]), 95 | nn.PReLU(), 96 | nn.Conv2d(self.planes // ratio, self.inplanes, kernel_size=1) 97 | ) 98 | else: 99 | self.channel_mul_conv = None 100 | 101 | def spatial_pool(self, x): 102 | batch, channel, height, width = x.size() 103 | if self.pool == 'att': 104 | input_x = x 105 | # [N, C, H * W] 106 | input_x = input_x.view(batch, channel, height * width) 107 | # [N, 1, C, H * W] 108 | input_x = input_x.unsqueeze(1) 109 | # [N, 1, H, W] 110 | context_mask = self.conv_mask(x) 111 | # [N, 1, H * W] 112 | context_mask = context_mask.view(batch, 1, height * width) 113 | # [N, 1, H * W] 114 | context_mask = self.softmax(context_mask) 115 | # [N, 1, H * W, 1] 116 | context_mask = context_mask.unsqueeze(3) 117 | # [N, 1, C, 1] 118 | context = torch.matmul(input_x, context_mask) 119 | # [N, C, 1, 1] 120 | context = context.view(batch, channel, 1, 1) 121 | else: 122 | # [N, C, 1, 1] 123 | context = self.avg_pool(x) 124 | 125 | return context 126 | 127 | def forward(self, x): 128 | # [N, C, 1, 1] 129 | context = self.spatial_pool(x) 130 | 131 | if self.channel_mul_conv is not None: 132 | # [N, C, 1, 1] 133 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 134 | out = x * channel_mul_term 135 | else: 136 | out = x 137 | if self.channel_add_conv is not None: 138 | # [N, C, 1, 1] 139 | channel_add_term = self.channel_add_conv(context) 140 | out = out + channel_add_term 141 | 142 | return out 143 | 144 | 145 | class GCWTResDown(nn.Module): 146 | def __init__(self, in_channels, att_block, norm_layer=nn.BatchNorm2d): 147 | super().__init__() 148 | self.dwt = DWT() 149 | if norm_layer: 150 | self.stem = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1), 151 | norm_layer(in_channels), 152 | nn.PReLU(), 153 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1), 154 | norm_layer(in_channels), 155 | nn.PReLU()) 156 | else: 157 | self.stem = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1), 158 | nn.PReLU(), 159 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1), 160 | nn.PReLU()) 161 | self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) 162 | self.conv_down = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2) 163 | #self.att = att_block(in_channels * 2, in_channels * 2) 164 | 165 | def forward(self, x): 166 | stem = self.stem(x) 167 | xLL, dwt = self.dwt(x) 168 | res = self.conv1x1(xLL) 169 | out = torch.cat([stem, res], dim=1) 170 | #out = self.att(out) 171 | return out, dwt 172 | 173 | 174 | 175 | class GCIWTResUp(nn.Module): 176 | 177 | def __init__(self, in_channels, att_block, norm_layer=None): 178 | super().__init__() 179 | if norm_layer: 180 | self.stem = nn.Sequential( 181 | 
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 182 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1), 183 | norm_layer(in_channels // 2), 184 | nn.PReLU(), 185 | nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=3, padding=1), 186 | norm_layer(in_channels // 2), 187 | nn.PReLU(), 188 | ) 189 | else: 190 | self.stem = nn.Sequential( 191 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 192 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1), 193 | nn.PReLU(), 194 | nn.Conv2d(in_channels // 2, in_channels // 4, kernel_size=3, padding=1), 195 | nn.PReLU(), 196 | ) 197 | 198 | self.pre_conv = nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=1, padding=0) 199 | self.prelu = nn.PReLU() 200 | self.conv1x1 = nn.Conv2d(in_channels // 2, in_channels // 4, kernel_size=1, padding=0) 201 | #self.att = att_block(in_channels // 2, in_channels // 8) 202 | self.iwt = IWT() 203 | 204 | def forward(self, x, x_dwt): 205 | stem = self.stem(x) 206 | x_dwt = self.prelu(self.pre_conv(x_dwt)) 207 | x_iwt = self.iwt(x_dwt) 208 | x_iwt = self.conv1x1(x_iwt) 209 | out = torch.cat([stem, x_iwt], dim=1) 210 | # out = stem + x_iwt 211 | #out = self.att(out) 212 | return out 213 | 214 | 215 | class GCIWTResUp_1(nn.Module): 216 | 217 | def __init__(self, in_channels, dwt_channels, att_block, norm_layer=None): 218 | super().__init__() 219 | if norm_layer: 220 | self.stem = nn.Sequential( 221 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 222 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1), 223 | norm_layer(in_channels // 2), 224 | nn.PReLU(), 225 | nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=3, padding=1), 226 | norm_layer(in_channels // 2), 227 | nn.PReLU(), 228 | ) 229 | else: 230 | self.stem = nn.Sequential( 231 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 232 | nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1), 233 | nn.PReLU(), 234 | nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=3, padding=1), 235 | nn.PReLU(), 236 | ) 237 | 238 | self.pre_conv = nn.Conv2d(dwt_channels, dwt_channels, kernel_size=1, padding=0) 239 | self.prelu = nn.PReLU() 240 | self.conv1x1 = nn.Conv2d(3, 3, kernel_size=1, padding=0) 241 | self.att = att_block(35, 16) 242 | self.iwt = IWT() 243 | 244 | def forward(self, x, x_dwt): 245 | stem = self.stem(x) 246 | x_dwt = self.prelu(self.pre_conv(x_dwt)) 247 | x_iwt = self.iwt(x_dwt) 248 | x_iwt = self.conv1x1(x_iwt) 249 | out = torch.cat((stem, x_iwt), dim=1) 250 | out = self.att(out) 251 | return out 252 | 253 | 254 | class shortcutblock(nn.Module): 255 | def __init__(self, in_channels): 256 | super().__init__() 257 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) 258 | self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) 259 | self.se = SE_net(in_channels, in_channels) 260 | self.relu = nn.ReLU() 261 | 262 | def forward(self, x): 263 | return self.se(self.relu(self.conv2(self.relu(self.conv1(x))))) 264 | 265 | 266 | class PSPModule(nn.Module): 267 | def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)): 268 | super().__init__() 269 | self.stages = [] 270 | self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes]) 271 | self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1) 272 | self.relu = nn.ReLU() 273 | 274 | def _make_stage(self, features, size): 275 | prior = 
nn.AdaptiveAvgPool2d(output_size=(size, size))
276 |         conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
277 |         return nn.Sequential(prior, conv)
278 | 
279 |     def forward(self, feats):
280 |         h, w = feats.size(2), feats.size(3)
281 |         priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats]
282 |         bottle = self.bottleneck(torch.cat(priors, 1))
283 |         return self.relu(bottle)
284 | 
285 | 
286 | if __name__ == '__main__':
287 |     x = torch.randn(1, 64, 448, 448)
288 |     net = GCRDB(64, ContextBlock2d)
289 |     net2 = GCWTResDown(64, ContextBlock2d)
290 |     net3 = GCIWTResUp(128, ContextBlock2d)
291 |     y = net(x)
292 |     y, y_dwt = net2(y)   # GCWTResDown returns (features, stacked DWT sub-bands)
293 |     y = net3(y, y_dwt)   # GCIWTResUp takes both
294 |     print(y.shape)
295 | 
--------------------------------------------------------------------------------
/models/modules_4channel.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from models.utils import DWT, IWT
 5 | 
 6 | 
 7 | class GCRDB(nn.Module):
 8 |     def __init__(self, in_channels, att_block, num_dense_layer=6, growth_rate=16):
 9 |         super(GCRDB, self).__init__()
10 |         _in_channels = in_channels
11 |         modules = []
12 |         for i in range(num_dense_layer):
13 |             modules.append(MakeDense(_in_channels, growth_rate))
14 |             _in_channels += growth_rate
15 |         self.residual_dense_layers = nn.Sequential(*modules)
16 |         self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0)
17 |         self.final_att = att_block(inplanes=in_channels, planes=in_channels)
18 | 
19 |     def forward(self, x):
20 |         out_rdb = self.residual_dense_layers(x)
21 |         out_rdb = self.conv_1x1(out_rdb)
22 |         out_rdb = self.final_att(out_rdb)
23 |         out = out_rdb + x
24 |         return out
25 | 
26 | 
27 | class MakeDense(nn.Module):
28 |     def __init__(self, in_channels, growth_rate, kernel_size=3):
29 |         super(MakeDense, self).__init__()
30 |         self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
31 |         self.norm_layer = nn.BatchNorm2d(growth_rate)
32 | 
33 |     def forward(self, x):
34 |         out = F.relu(self.conv(x))
35 |         out = self.norm_layer(out)
36 |         out = torch.cat((x, out), 1)
37 |         return out
38 | 
39 | 
40 | class SE_net(nn.Module):
41 |     def __init__(self, in_channels, out_channels, reduction=4, attention=True):
42 |         super().__init__()
43 |         self.attention = attention
44 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
45 |         self.conv_in = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)
46 |         self.conv_mid = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, padding=0)
47 |         self.conv_out = nn.Conv2d(in_channels // reduction, out_channels, kernel_size=1, padding=0)
48 | 
49 |         self.x_red = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
50 | 
51 |     def forward(self, x):
52 | 
53 |         if self.attention is True:
54 |             y = self.avg_pool(x)
55 |             y = F.relu(self.conv_in(y))
56 |             y = F.relu(self.conv_mid(y))
57 |             y = torch.sigmoid(self.conv_out(y))
58 |             x = self.x_red(x)
59 |             return x * y
60 |         else:
61 |             return x
62 | 
63 | 
64 | class ContextBlock2d(nn.Module):
65 | 
66 |     def __init__(self, inplanes=9, planes=32, pool='att', fusions=['channel_add'], ratio=4):
67 |         super(ContextBlock2d, self).__init__()
68 |         assert pool in ['avg', 'att']
69 |         assert all([f in ['channel_add', 'channel_mul'] for f in fusions])
70 |         assert len(fusions) > 0, 'at least one fusion should be used'
71 |         self.inplanes = inplanes
72 |         self.planes = planes
73 |         self.pool = pool
74 |         self.fusions = fusions
75 | if 'att' in pool: 76 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) # context Modeling 77 | self.softmax = nn.Softmax(dim=2) 78 | else: 79 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 80 | if 'channel_add' in fusions: 81 | self.channel_add_conv = nn.Sequential( 82 | nn.Conv2d(self.inplanes, self.planes // ratio, kernel_size=1), 83 | nn.LayerNorm([self.planes // ratio, 1, 1]), 84 | nn.PReLU(), 85 | nn.Conv2d(self.planes // ratio, self.inplanes, kernel_size=1) 86 | ) 87 | else: 88 | self.channel_add_conv = None 89 | if 'channel_mul' in fusions: 90 | self.channel_mul_conv = nn.Sequential( 91 | nn.Conv2d(self.inplanes, self.planes // ratio, kernel_size=1), 92 | nn.LayerNorm([self.planes // ratio, 1, 1]), 93 | nn.PReLU(), 94 | nn.Conv2d(self.planes // ratio, self.inplanes, kernel_size=1) 95 | ) 96 | else: 97 | self.channel_mul_conv = None 98 | 99 | def spatial_pool(self, x): 100 | batch, channel, height, width = x.size() 101 | if self.pool == 'att': 102 | input_x = x 103 | # [N, C, H * W] 104 | input_x = input_x.view(batch, channel, height * width) 105 | # [N, 1, C, H * W] 106 | input_x = input_x.unsqueeze(1) 107 | # [N, 1, H, W] 108 | context_mask = self.conv_mask(x) 109 | # [N, 1, H * W] 110 | context_mask = context_mask.view(batch, 1, height * width) 111 | # [N, 1, H * W] 112 | context_mask = self.softmax(context_mask) 113 | # [N, 1, H * W, 1] 114 | context_mask = context_mask.unsqueeze(3) 115 | # [N, 1, C, 1] 116 | context = torch.matmul(input_x, context_mask) 117 | # [N, C, 1, 1] 118 | context = context.view(batch, channel, 1, 1) 119 | else: 120 | # [N, C, 1, 1] 121 | context = self.avg_pool(x) 122 | 123 | return context 124 | 125 | def forward(self, x): 126 | # [N, C, 1, 1] 127 | context = self.spatial_pool(x) 128 | 129 | if self.channel_mul_conv is not None: 130 | # [N, C, 1, 1] 131 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 132 | out = x * channel_mul_term 133 | else: 134 | out = x 135 | if self.channel_add_conv is not None: 136 | # [N, C, 1, 1] 137 | channel_add_term = self.channel_add_conv(context) 138 | out = out + channel_add_term 139 | 140 | return out 141 | 142 | 143 | class GCWTResDown(nn.Module): 144 | def __init__(self, in_channels, att_block, norm_layer=nn.BatchNorm2d): 145 | super().__init__() 146 | self.dwt = DWT() 147 | if norm_layer: 148 | self.stem = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1), 149 | norm_layer(in_channels), 150 | nn.PReLU(), 151 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1), 152 | norm_layer(in_channels), 153 | nn.PReLU()) 154 | else: 155 | self.stem = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1), 156 | nn.PReLU(), 157 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1), 158 | nn.PReLU()) 159 | self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) 160 | self.conv_down = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2) 161 | #self.att = att_block(in_channels * 2, in_channels * 2) 162 | 163 | def forward(self, x): 164 | stem = self.stem(x) 165 | xLL, dwt = self.dwt(x) 166 | res = self.conv1x1(xLL) 167 | out = torch.cat([stem, res], dim=1) 168 | #out = self.att(out) 169 | return out, dwt 170 | 171 | 172 | class GCIWTResUp(nn.Module): 173 | 174 | def __init__(self, in_channels, att_block, norm_layer=None): 175 | super().__init__() 176 | if norm_layer: 177 | self.stem = nn.Sequential( 178 | nn.PixelShuffle(2), 179 | nn.Conv2d(in_channels // 4, 
in_channels // 4, kernel_size=3, padding=1), 180 | norm_layer(in_channels // 4), 181 | nn.PReLU(), 182 | nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1), 183 | norm_layer(in_channels // 4), 184 | nn.PReLU(), 185 | ) 186 | else: 187 | self.stem = nn.Sequential( 188 | nn.PixelShuffle(2), 189 | nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1), 190 | nn.PReLU(), 191 | nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1), 192 | nn.PReLU(), 193 | ) 194 | self.pre_conv_stem = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1, padding=0) 195 | self.pre_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) 196 | # self.prelu = nn.PReLU() 197 | self.post_conv = nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=1, padding=0) 198 | self.iwt = IWT() 199 | self.last_conv = nn.Conv2d(in_channels // 2, in_channels // 4, kernel_size=1, padding=0) 200 | # self.se = SE_net(in_channels // 2, in_channels // 4) 201 | 202 | def forward(self, x, x_dwt): 203 | x = self.pre_conv_stem(x) 204 | stem = self.stem(x) 205 | x_dwt = self.pre_conv(x_dwt) 206 | x_iwt = self.iwt(x_dwt) 207 | x_iwt = self.post_conv(x_iwt) 208 | out = torch.cat((stem, x_iwt), dim=1) 209 | out = self.last_conv(out) 210 | return out 211 | 212 | 213 | class shortcutblock(nn.Module): 214 | def __init__(self, in_channels, out_channels): 215 | super().__init__() 216 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 217 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 218 | self.se = SE_net(out_channels, out_channels) 219 | self.relu = nn.ReLU() 220 | 221 | def forward(self, x): 222 | return self.se(self.relu(self.conv2(self.relu(self.conv1(x))))) 223 | 224 | 225 | class PSPModule(nn.Module): 226 | def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)): 227 | super().__init__() 228 | self.stages = [] 229 | self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes]) 230 | self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1) 231 | self.relu = nn.ReLU() 232 | 233 | def _make_stage(self, features, size): 234 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 235 | conv = nn.Conv2d(features, features, kernel_size=1, bias=False) 236 | return nn.Sequential(prior, conv) 237 | 238 | def forward(self, feats): 239 | h, w = feats.size(2), feats.size(3) 240 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats] 241 | bottle = self.bottleneck(torch.cat(priors, 1)) 242 | return self.relu(bottle) 243 | 244 | 245 | class last_upsample(nn.Module): 246 | def __init__(self): 247 | super().__init__() 248 | self.up = nn.PixelShuffle(2) 249 | self.pre_enhance = nn.Conv2d(16, 16, kernel_size=3, padding=1) 250 | self.enhance = PSPModule(16, 16) 251 | self.post_conv = nn.Conv2d(32, 32, kernel_size=3, padding=1) 252 | self.se = SE_net(32, 32) 253 | self.final = nn.Conv2d(32, 3, kernel_size=3, padding=1) 254 | 255 | def forward(self, x): 256 | x = self.up(x) 257 | x = self.pre_enhance(x) 258 | enhanced = self.enhance(x) 259 | x = torch.cat((enhanced, x), dim=1) 260 | out = self.se(self.post_conv(x)) 261 | out = self.final(out) 262 | return F.sigmoid(out) 263 | 264 | 265 | if __name__ == '__main__': 266 | x = torch.randn(1, 256, 112, 112) 267 | net2 = GCWTResDown(256, ContextBlock2d) 268 | net3 = GCIWTResUp(1024, ContextBlock2d) 269 | y = net2(x) 270 | print(y[0].shape, y[1].shape) 271 | y = net3(y[0], 
y[1]) 272 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def dwt_init(x): 6 | 7 | x01 = x[:, :, 0::2, :] / 2 8 | x02 = x[:, :, 1::2, :] / 2 9 | x1 = x01[:, :, :, 0::2] 10 | x2 = x02[:, :, :, 0::2] 11 | x3 = x01[:, :, :, 1::2] 12 | x4 = x02[:, :, :, 1::2] 13 | x_LL = x1 + x2 + x3 + x4 14 | x_HL = -x1 - x2 + x3 + x4 15 | x_LH = -x1 + x2 - x3 + x4 16 | x_HH = x1 - x2 - x3 + x4 17 | 18 | return x_LL, torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 19 | 20 | def iwt_init(x): 21 | r = 2 22 | in_batch, in_channel, in_height, in_width = x.size() 23 | out_batch, out_channel, out_height, out_width = in_batch, int( 24 | in_channel / (r**2)), r * in_height, r * in_width 25 | x1 = x[:, 0:out_channel, :, :] / 2 26 | x2 = x[:, out_channel:out_channel * 2, :, :] / 2 27 | x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 28 | x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 29 | h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().to(x.device) 30 | 31 | h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 32 | h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 33 | h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 34 | h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 35 | 36 | return h 37 | 38 | class DWT(nn.Module): 39 | def __init__(self): 40 | super(DWT, self).__init__() 41 | self.requires_grad = False 42 | 43 | def forward(self, x): 44 | return dwt_init(x) 45 | 46 | class IWT(nn.Module): 47 | def __init__(self): 48 | super(IWT, self).__init__() 49 | self.requires_grad = False 50 | 51 | def forward(self, x): 52 | return iwt_init(x) -------------------------------------------------------------------------------- /proc_img.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | 5 | def add_boarder(img, width=3): 6 | img = cv2.copyMakeBorder(img, width, width, width, width, cv2.BORDER_CONSTANT, value=(0,255,0)) 7 | return img 8 | 9 | def crop(img, point, dist, scale=2, boarder_width=3): 10 | # 11 | cropped = img[point[0]:point[0]+dist, point[1]:point[1]+dist, :] 12 | rescaled = rescale(cropped, scale) 13 | cropped_boarder = add_boarder(cropped, boarder_width) 14 | cropped_rescale_boarder = add_boarder(rescaled, boarder_width) 15 | print('####', cropped_rescale_boarder.shape) 16 | return cropped_boarder, cropped_rescale_boarder 17 | 18 | def rescale(img ,scale): 19 | return cv2.resize(img, (int(img.shape[1]*scale), int(img.shape[0]*scale))) 20 | 21 | def add_back(img, 22 | cropped_boarder, 23 | cropped_rescale_boarder, 24 | point, 25 | crop_size, 26 | boarder_width, 27 | scale 28 | ): 29 | '''Added cropped area''' 30 | img[point[0]-boarder_width:point[0]+crop_size+boarder_width, point[1]-boarder_width:point[1]+crop_size+boarder_width, :] = cropped_boarder 31 | '''Added rescale area to corner''' 32 | img[img.shape[0]-int(crop_size*scale)-boarder_width*2:img.shape[0],img.shape[1]-int(crop_size*scale)-boarder_width*2:img.shape[1],:] = cropped_rescale_boarder 33 | return img 34 | 35 | def ddn_real_2(): 36 | boarder_width = 3 37 | crop_size = 200 38 | scale = 2 39 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ddn_real_png/2/2_rainy.jpg') 40 | img_shape = s_img.shape 41 | print(img_shape) 42 | folder = './experiment_images_on_latex/experiments/ddn_real_png/2/' 43 | save_folder = './experiment_images_on_latex/experiments/ddn_real_png/2/refine' 44 | if 
not os.path.exists(save_folder): 45 | os.makedirs(save_folder) 46 | filenames = os.listdir(folder) 47 | for filename in filenames: 48 | if filename.endswith('.jpg') or filename.endswith('.png'): 49 | img_name = os.path.join(folder, filename) 50 | print(img_name) 51 | img = cv2.imread(img_name) 52 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 53 | point = (img.shape[0]//11, img.shape[1]//3) #h, w 54 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 55 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 56 | save_path = os.path.join(save_folder, filename) 57 | cv2.imwrite(save_path, result) 58 | 59 | def ddn_real_19(): 60 | boarder_width = 3 61 | crop_size = 100 62 | scale = 1.5 63 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ddn_real_png/19/19_rainy.jpg') 64 | img_shape = s_img.shape 65 | print(img_shape) 66 | folder = './experiment_images_on_latex/experiments/ddn_real_png/19/' 67 | save_folder = './experiment_images_on_latex/experiments/ddn_real_png/19/refine' 68 | if not os.path.exists(save_folder): 69 | os.makedirs(save_folder) 70 | filenames = os.listdir(folder) 71 | for filename in filenames: 72 | if filename.endswith('.jpg') or filename.endswith('.png'): 73 | img_name = os.path.join(folder, filename) 74 | print(img_name) 75 | img = cv2.imread(img_name) 76 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 77 | point = (10, img.shape[1]//4*3) #h, w 78 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 79 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 80 | save_path = os.path.join(save_folder, filename) 81 | cv2.imwrite(save_path, result) 82 | 83 | def our_real_2(): 84 | boarder_width = 3 85 | crop_size = 200 86 | scale = 2 87 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ours_real_png/2/2_rainy.jpg') 88 | img_shape = s_img.shape 89 | print(img_shape) 90 | folder = './experiment_images_on_latex/experiments/ours_real_png/2/' 91 | save_folder = './experiment_images_on_latex/experiments/ours_real_png/2/refine' 92 | if not os.path.exists(save_folder): 93 | os.makedirs(save_folder) 94 | filenames = os.listdir(folder) 95 | for filename in filenames: 96 | if filename.endswith('.jpg') or filename.endswith('.png'): 97 | img_name = os.path.join(folder, filename) 98 | print(img_name) 99 | img = cv2.imread(img_name) 100 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 101 | point = (4, 4) #h, w 102 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 103 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 104 | save_path = os.path.join(save_folder, filename) 105 | cv2.imwrite(save_path, result) 106 | 107 | def our_real_1218(): 108 | boarder_width = 3 109 | crop_size = 200 110 | scale = 2 111 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ours_real_png/rain_01218/rain_01218_rainy.jpg') 112 | img_shape = s_img.shape 113 | print(img_shape) 114 | folder = './experiment_images_on_latex/experiments/ours_real_png/rain_01218/' 115 | save_folder = './experiment_images_on_latex/experiments/ours_real_png/rain_01218/refine' 116 | if not os.path.exists(save_folder): 117 | os.makedirs(save_folder) 118 | filenames = os.listdir(folder) 119 | for filename in filenames: 120 | if filename.endswith('.jpg') or filename.endswith('.png'): 121 | img_name = os.path.join(folder, filename) 
122 | print(img_name) 123 | img = cv2.imread(img_name) 124 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 125 | point = (img.shape[0]//4, img.shape[1]//4) #h, w 126 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 127 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 128 | save_path = os.path.join(save_folder, filename) 129 | cv2.imwrite(save_path, result) 130 | 131 | def ddn_901(): 132 | boarder_width = 3 133 | crop_size = 100 134 | scale = 2 135 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ddn_test/901/901.jpg') 136 | img_shape = s_img.shape 137 | print(img_shape) 138 | folder = './experiment_images_on_latex/experiments/ddn_test/901/' 139 | save_folder = './experiment_images_on_latex/experiments/ddn_test/901/refine' 140 | if not os.path.exists(save_folder): 141 | os.makedirs(save_folder) 142 | filenames = os.listdir(folder) 143 | for filename in filenames: 144 | if filename.endswith('.jpg') or filename.endswith('.png'): 145 | img_name = os.path.join(folder, filename) 146 | print(img_name) 147 | img = cv2.imread(img_name) 148 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 149 | point = (img.shape[0]//2, img.shape[1]//4) #h, w 150 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 151 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 152 | save_path = os.path.join(save_folder, filename) 153 | cv2.imwrite(save_path, result) 154 | 155 | def ddn_905(): 156 | boarder_width = 3 157 | crop_size = 100 158 | scale = 2 159 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ddn_test/905/905.jpg') 160 | img_shape = s_img.shape 161 | print(img_shape) 162 | folder = './experiment_images_on_latex/experiments/ddn_test/905/' 163 | save_folder = './experiment_images_on_latex/experiments/ddn_test/905/refine' 164 | if not os.path.exists(save_folder): 165 | os.makedirs(save_folder) 166 | filenames = os.listdir(folder) 167 | for filename in filenames: 168 | if filename.endswith('.jpg') or filename.endswith('.png'): 169 | img_name = os.path.join(folder, filename) 170 | print(img_name) 171 | img = cv2.imread(img_name) 172 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 173 | point = (img.shape[0]//3, 3) #h, w 174 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 175 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 176 | save_path = os.path.join(save_folder, filename) 177 | cv2.imwrite(save_path, result) 178 | 179 | 180 | def ours_563(): 181 | boarder_width = 3 182 | crop_size = 100 183 | scale = 2 184 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ours_test/0563/0563.jpg') 185 | img_shape = s_img.shape 186 | print(img_shape) 187 | folder = './experiment_images_on_latex/experiments/ours_test/0563/' 188 | save_folder = './experiment_images_on_latex/experiments/ours_test/0563/refine' 189 | if not os.path.exists(save_folder): 190 | os.makedirs(save_folder) 191 | filenames = os.listdir(folder) 192 | for filename in filenames: 193 | if filename.endswith('.jpg') or filename.endswith('.png'): 194 | img_name = os.path.join(folder, filename) 195 | print(img_name) 196 | img = cv2.imread(img_name) 197 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 198 | point = (200, 100) #h, w 199 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 200 | 
result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 201 | save_path = os.path.join(save_folder, filename) 202 | cv2.imwrite(save_path, result) 203 | 204 | 205 | def ours_2800(): 206 | boarder_width = 3 207 | crop_size = 100 208 | scale = 2 209 | s_img = cv2.imread('./experiment_images_on_latex/experiments/ours_test/2800/2800.jpg') 210 | img_shape = s_img.shape 211 | print(img_shape) 212 | folder = './experiment_images_on_latex/experiments/ours_test/2800/' 213 | save_folder = './experiment_images_on_latex/experiments/ours_test/2800/refine' 214 | if not os.path.exists(save_folder): 215 | os.makedirs(save_folder) 216 | filenames = os.listdir(folder) 217 | for filename in filenames: 218 | if filename.endswith('.jpg') or filename.endswith('.png'): 219 | img_name = os.path.join(folder, filename) 220 | print(img_name) 221 | img = cv2.imread(img_name) 222 | img = cv2.resize(img, (img_shape[1], img_shape[0])) 223 | point = (350, 370) #h, w 224 | cropped_boarder, cropped_rescale_boarder = crop(img, point, crop_size, scale=scale) 225 | result = add_back(img, cropped_boarder, cropped_rescale_boarder, point, crop_size, boarder_width, scale) 226 | save_path = os.path.join(save_folder, filename) 227 | cv2.imwrite(save_path, result) 228 | 229 | if __name__ == '__main__': 230 | #ddn_real_2() 231 | #ddn_real_19() 232 | #our_real_2() 233 | #our_real_1218() 234 | #ddn_901() 235 | #ddn_905() 236 | #ours_563() 237 | ours_2800() 238 | -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | 12 | def create_window(window_size, channel=1): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 16 | return window 17 | 18 | 19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
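    # When val_range is None, L is inferred from the data: uint8-style tensors (max > 128) get L = 255;
    # tanh-style outputs (min < -0.5) get min_val = -1, hence L = 2; otherwise L = 1.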
21 |     if val_range is None:
22 |         if torch.max(img1) > 128:
23 |             max_val = 255
24 |         else:
25 |             max_val = 1
26 | 
27 |         if torch.min(img1) < -0.5:
28 |             min_val = -1
29 |         else:
30 |             min_val = 0
31 |         L = max_val - min_val
32 |     else:
33 |         L = val_range
34 | 
35 |     padd = 0
36 |     (_, channel, height, width) = img1.size()
37 |     if window is None:
38 |         real_size = min(window_size, height, width)
39 |         window = create_window(real_size, channel=channel).to(img1.device)
40 | 
41 |     mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
42 |     mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
43 | 
44 |     mu1_sq = mu1.pow(2)
45 |     mu2_sq = mu2.pow(2)
46 |     mu1_mu2 = mu1 * mu2
47 | 
48 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
49 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
50 |     sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
51 | 
52 |     C1 = (0.01 * L) ** 2
53 |     C2 = (0.03 * L) ** 2
54 | 
55 |     v1 = 2.0 * sigma12 + C2
56 |     v2 = sigma1_sq + sigma2_sq + C2
57 |     cs = torch.mean(v1 / v2)  # contrast sensitivity
58 | 
59 |     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
60 | 
61 |     if size_average:
62 |         ret = ssim_map.mean()
63 |     else:
64 |         ret = ssim_map.mean(1).mean(1).mean(1)
65 | 
66 |     if full:
67 |         return ret, cs
68 |     return ret
69 | 
70 | 
71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
72 |     device = img1.device
73 |     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74 |     levels = weights.size()[0]
75 |     mssim = []
76 |     mcs = []
77 |     for _ in range(levels):
78 |         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
79 |         mssim.append(sim)
80 |         mcs.append(cs)
81 | 
82 |         img1 = F.avg_pool2d(img1, (2, 2))
83 |         img2 = F.avg_pool2d(img2, (2, 2))
84 | 
85 |     mssim = torch.stack(mssim)
86 |     mcs = torch.stack(mcs)
87 | 
88 |     # Normalize (to avoid NaNs when training unstable models; not compliant with the original definition)
89 |     if normalize:
90 |         mssim = (mssim + 1) / 2
91 |         mcs = (mcs + 1) / 2
92 | 
93 |     pow1 = mcs ** weights
94 |     pow2 = mssim ** weights
95 |     # From the Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/: product of the cs terms at the coarser scales times the SSIM term at the finest scale
96 |     output = torch.prod(pow1[:-1]) * pow2[-1]
97 |     return output
98 | 
99 | 
100 | # Classes to re-use window
101 | class SSIM(torch.nn.Module):
102 |     def __init__(self, window_size=11, size_average=True, val_range=None):
103 |         super(SSIM, self).__init__()
104 |         self.window_size = window_size
105 |         self.size_average = size_average
106 |         self.val_range = val_range
107 | 
108 |         # Assume 1 channel for SSIM
109 |         self.channel = 1
110 |         self.window = create_window(window_size)
111 | 
112 |     def forward(self, img1, img2):
113 |         (_, channel, _, _) = img1.size()
114 | 
115 |         if channel == self.channel and self.window.dtype == img1.dtype:
116 |             window = self.window
117 |         else:
118 |             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
119 |             self.window = window
120 |             self.channel = channel
121 | 
122 |         return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
123 | 
124 | class MSSSIM(torch.nn.Module):
125 |     def __init__(self, window_size=11, size_average=True, channel=3):
126 |         super(MSSSIM, self).__init__()
127 |         self.window_size = window_size
128 |         self.size_average = size_average
129 |         self.channel = channel
130 | 
131 |     def forward(self, img1, img2):
132 |         # TODO: store
window between calls if possible 133 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 134 | -------------------------------------------------------------------------------- /pytorch_ssim/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Charlie0215/AWNet-Attentive-Wavelet-Network-for-Image-ISP/a1088bcef3a98fa6ba9e8febb2a40dd9b5bf3e20/pytorch_ssim/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /train_3channel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | import time 5 | import numpy as np 6 | from models.model_3channel import AWNet 7 | from loss import ms_Loss 8 | from dataloader_3channel import LoadData 9 | from config import trainConfig 10 | from utils import validation, print_log, to_psnr, adjust_learning_rate_step 11 | 12 | np.random.seed(0) 13 | torch.manual_seed(0) 14 | 15 | # Dataset size 16 | TRAIN_SIZE = 46839 17 | TEST_SIZE = 1204 18 | torch.backends.cudnn.benchmark = False 19 | torch.backends.cudnn.deterministic = True 20 | 21 | 22 | def train(): 23 | device_ids = [0] 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | print("CUDA visible devices: " + str(torch.cuda.device_count())) 26 | print("CUDA Device Name: " + str(torch.cuda.get_device_name(device))) 27 | 28 | # Initialize loss and model 29 | loss = ms_Loss().to(device) 30 | net = AWNet(3, 3, block=[3, 3, 3, 4, 4]).to(device) 31 | net = nn.DataParallel(net, device_ids=device_ids) 32 | new_lr = trainConfig.learning_rate[0] 33 | 34 | # Reload 35 | if trainConfig.pretrain == True: 36 | net.load_state_dict( 37 | torch.load( 38 | '{}/best_3channel.pkl'.format(trainConfig.save_best), 39 | map_location=device)["model_state"]) 40 | print('weight loaded.') 41 | else: 42 | print('no weight loaded.') 43 | pytorch_total_params = sum( 44 | p.numel() for p in net.parameters() if p.requires_grad) 45 | print("Total_params: {}".format(pytorch_total_params)) 46 | 47 | # optimizer and scheduler 48 | optimizer = torch.optim.Adam( 49 | net.parameters(), lr=new_lr, betas=(0.9, 0.999)) 50 | 51 | # Dataloaders 52 | train_dataset = LoadData( 53 | trainConfig.data_dir, TRAIN_SIZE, dslr_scale=224, test=False) 54 | train_loader = DataLoader( 55 | dataset=train_dataset, 56 | batch_size=trainConfig.batch_size, 57 | shuffle=True, 58 | num_workers=32, 59 | pin_memory=True, 60 | drop_last=True) 61 | 62 | test_dataset = LoadData( 63 | trainConfig.data_dir, TEST_SIZE, dslr_scale=224, test=True) 64 | test_loader = DataLoader( 65 | dataset=test_dataset, 66 | batch_size=8, 67 | shuffle=False, 68 | num_workers=18, 69 | pin_memory=True, 70 | drop_last=False) 71 | 72 | print('Train loader length: {}'.format(len(train_loader))) 73 | 74 | pre_psnr, pre_ssim = validation(net, test_loader, device, save_tag=True) 75 | print('previous PSNR: {:.4f}, previous ssim: {:.4f}'.format( 76 | pre_psnr, pre_ssim)) 77 | iteration = 0 78 | for epoch in range(trainConfig.epoch): 79 | psnr_list = [] 80 | start_time = time.time() 81 | if epoch > 0: 82 | new_lr = adjust_learning_rate_step( 83 | optimizer, epoch, trainConfig.epoch, trainConfig.learning_rate) 84 | for batch_id, data in enumerate(train_loader): 85 | x, target, _ = data 86 | x = x.to(device) 87 | target = target.to(device) 88 | pred, _ = net(x) 89 | 90 | 
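            # Standard update: clear stale gradients, evaluate the multi-scale loss on the prediction pyramid, backpropagate, and step Adam.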
optimizer.zero_grad() 91 | 92 | total_loss, losses = loss(pred, target) 93 | total_loss.backward() 94 | optimizer.step() 95 | 96 | iteration += 1 97 | if trainConfig.print_loss: 98 | print("epoch:{}/{} | Loss: {:.4f} ".format( 99 | epoch, trainConfig.epoch, total_loss.item())) 100 | if not (batch_id % 1000): 101 | print('Epoch:{0}, Iteration:{1}'.format(epoch, batch_id)) 102 | 103 | psnr_list.extend(to_psnr(pred[0], target)) 104 | 105 | train_psnr = sum(psnr_list) / len(psnr_list) 106 | state = { 107 | "model_state": net.state_dict(), 108 | "lr": new_lr, 109 | } 110 | print('saved checkpoint') 111 | torch.save(state, '{}/three_channel_epoch_{}.pkl'.format( 112 | trainConfig.checkpoints, epoch)) 113 | 114 | one_epoch_time = time.time() - start_time 115 | print('time: {}, train psnr: {}'.format(one_epoch_time, train_psnr)) 116 | val_psnr, val_ssim = validation( 117 | net, test_loader, device, save_tag=True) 118 | print_log(epoch + 1, trainConfig.epoch, one_epoch_time, train_psnr, 119 | val_psnr, val_ssim, 'multi_loss') 120 | 121 | if val_psnr >= pre_psnr: 122 | state = { 123 | "model_state": net.state_dict(), 124 | "lr": new_lr, 125 | } 126 | 127 | print('saved best weight') 128 | torch.save(state, '{}/best_3channel.pkl'.format( 129 | trainConfig.save_best)) 130 | pre_psnr = val_psnr 131 | 132 | 133 | if __name__ == '__main__': 134 | train() 135 | -------------------------------------------------------------------------------- /train_4channel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | import time 5 | import numpy as np 6 | from models.model_4channel import AWNet 7 | from loss import ms_Loss 8 | from dataloader_4channel import LoadData 9 | from config import trainConfig 10 | from utils import validation, print_log, to_psnr, adjust_learning_rate_step 11 | 12 | np.random.seed(0) 13 | torch.manual_seed(0) 14 | 15 | # Dataset size 16 | TRAIN_SIZE = 46839 17 | TEST_SIZE = 1204 18 | torch.backends.cudnn.benchmark = False 19 | torch.backends.cudnn.deterministic = True 20 | 21 | 22 | def train(): 23 | device_ids = [0] 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | print("CUDA visible devices: " + str(torch.cuda.device_count())) 26 | print("CUDA Device Name: " + str(torch.cuda.get_device_name(device))) 27 | 28 | # Initialize loss and model 29 | loss = ms_Loss().to(device) 30 | net = AWNet(4, 3, block=[3, 3, 3, 4, 4]).to(device) 31 | net = nn.DataParallel(net, device_ids=device_ids) 32 | new_lr = trainConfig.learning_rate[0] 33 | 34 | # Reload 35 | if trainConfig.pretrain == True: 36 | net.load_state_dict( 37 | torch.load( 38 | '{}/best_4channel.pkl'.format(trainConfig.save_best), 39 | map_location=device)["model_state"]) 40 | print('weight loaded.') 41 | else: 42 | print('no weight loaded.') 43 | pytorch_total_params = sum( 44 | p.numel() for p in net.parameters() if p.requires_grad) 45 | print("Total_params: {}".format(pytorch_total_params)) 46 | 47 | # optimizer and scheduler 48 | optimizer = torch.optim.Adam( 49 | net.parameters(), lr=new_lr, betas=(0.9, 0.999)) 50 | 51 | # Dataloaders 52 | train_dataset = LoadData( 53 | trainConfig.data_dir, TRAIN_SIZE, dslr_scale=1, test=False) 54 | train_loader = DataLoader( 55 | dataset=train_dataset, 56 | batch_size=trainConfig.batch_size, 57 | shuffle=True, 58 | num_workers=32, 59 | pin_memory=True, 60 | drop_last=True) 61 | 62 | test_dataset = LoadData( 63 | trainConfig.data_dir, TEST_SIZE, 
dslr_scale=1, test=True) 64 | test_loader = DataLoader( 65 | dataset=test_dataset, 66 | batch_size=8, 67 | shuffle=False, 68 | num_workers=18, 69 | pin_memory=True, 70 | drop_last=False) 71 | 72 | print('Train loader length: {}'.format(len(train_loader))) 73 | 74 | pre_psnr, pre_ssim = validation(net, test_loader, device, save_tag=True) 75 | print('previous PSNR: {:.4f}, previous ssim: {:.4f}'.format( 76 | pre_psnr, pre_ssim)) 77 | iteration = 0 78 | for epoch in range(trainConfig.epoch): 79 | psnr_list = [] 80 | start_time = time.time() 81 | if epoch > 0: 82 | new_lr = adjust_learning_rate_step( 83 | optimizer, epoch, trainConfig.epoch, trainConfig.learning_rate) 84 | for batch_id, data in enumerate(train_loader): 85 | x, target, _ = data 86 | x = x.to(device) 87 | target = target.to(device) 88 | pred, _ = net(x) 89 | 90 | optimizer.zero_grad() 91 | 92 | total_loss, losses = loss(pred, target) 93 | total_loss.backward() 94 | optimizer.step() 95 | 96 | iteration += 1 97 | if trainConfig.print_loss: 98 | print("epoch:{}/{} | Loss: {:.4f} ".format( 99 | epoch, trainConfig.epoch, total_loss.item())) 100 | if not (batch_id % 1000): 101 | print('Epoch:{0}, Iteration:{1}'.format(epoch, batch_id)) 102 | 103 | psnr_list.extend(to_psnr(pred[0], target)) 104 | 105 | train_psnr = sum(psnr_list) / len(psnr_list) 106 | state = { 107 | "model_state": net.state_dict(), 108 | "lr": new_lr, 109 | } 110 | print('saved checkpoint') 111 | torch.save(state, '{}/four_channel_epoch_{}.pkl'.format( 112 | trainConfig.checkpoints, epoch)) 113 | 114 | one_epoch_time = time.time() - start_time 115 | print('time: {}, train psnr: {}'.format(one_epoch_time, train_psnr)) 116 | val_psnr, val_ssim = validation( 117 | net, test_loader, device, save_tag=True) 118 | print_log(epoch + 1, trainConfig.epoch, one_epoch_time, train_psnr, 119 | val_psnr, val_ssim, 'multi_loss') 120 | 121 | if val_psnr >= pre_psnr: 122 | state = { 123 | "model_state": net.state_dict(), 124 | "lr": new_lr, 125 | } 126 | 127 | print('saved best weight') 128 | torch.save(state, '{}/best_4channel.pkl'.format( 129 | trainConfig.save_best)) 130 | pre_psnr = val_psnr 131 | 132 | 133 | if __name__ == '__main__': 134 | train() 135 | -------------------------------------------------------------------------------- /training_log/multi_loss_log.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Charlie0215/AWNet-Attentive-Wavelet-Network-for-Image-ISP/a1088bcef3a98fa6ba9e8febb2a40dd9b5bf3e20/training_log/multi_loss_log.txt -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | from skimage import measure 4 | import os 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | from math import log10 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision.utils as utils 12 | import torch.nn as nn 13 | from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize 14 | import torchvision.utils as vutils 15 | 16 | 17 | def display_transform(): 18 | return Compose([ToPILImage(), ToTensor()]) 19 | 20 | 21 | def get_colors(): 22 | ''' 23 | Dictionary of color map 24 | ''' 25 | return np.asarray([0, 128, 255]) 26 | 27 | 28 | class time_calculator(): 29 | def __init__(self): 30 | self.accumulate = 0 31 | 32 | def calculator(self, start_time, end_time, length=0): 33 | if 
length == 0:
34 |             total_time = int(end_time - start_time) + self.accumulate
35 |             self.accumulate = total_time
36 |         else:
37 |             total_time = int(end_time - start_time) * length
38 |         total_time_min = (total_time % 3600) // 60
39 |         total_time_sec = total_time % 60
40 |         total_time_hr = total_time // 3600  # 3600 seconds per hour
41 |         return total_time_hr, total_time_min, total_time_sec
42 | 
43 | 
44 | def writer_add_image(dir, writer, image, iter):
45 |     '''
46 |     tensorboard image writer
47 |     '''
48 |     x = vutils.make_grid(image, nrow=4)
49 |     writer.add_image(dir, x, iter)
50 | 
51 | 
52 | def save_image(target, preds, img_name, root):
53 |     '''
54 |     : target, preds: ground-truth and predicted image batches to be saved
55 |     : img_name: image name
56 |     '''
57 |     target = torch.split(target, 1, dim=0)
58 |     preds = torch.split(preds, 1, dim=0)
59 |     batch_num = len(preds)
60 | 
61 |     for ind in range(batch_num):
62 |         vutils.save_image(
63 |             target[ind],
64 |             root + '{}'.format(img_name[ind].split('.png')[0] + '_target.png'))
65 |         vutils.save_image(
66 |             preds[ind],
67 |             root + '{}'.format(img_name[ind].split('.png')[0] + '_pred.png'))
68 | 
69 | 
70 | def save_validation_image(preds, img_name, save_folder):
71 |     '''
72 |     : preds: predicted images to be saved
73 |     : img_name: image name
74 |     '''
75 |     preds = torch.split(preds, 1, dim=0)
76 |     batch_num = len(preds)
77 | 
78 |     for ind in range(batch_num):
79 |         print('saving {}'.format(img_name[ind]))
80 |         vutils.save_image(
81 |             preds[ind],
82 |             save_folder + '{}'.format(img_name[ind].split('.png')[0] + '.png'))
83 | 
84 | 
85 | def fun_ensemble(img):
86 |     imgs = [
87 |         img,  # 0
88 |         img.rotate(90),  # 1
89 |         img.rotate(180),  # 2
90 |         img.rotate(270),  # 3
91 |         img.transpose(Image.FLIP_TOP_BOTTOM),  # 4
92 |         img.rotate(90).transpose(Image.FLIP_TOP_BOTTOM),  # 5
93 |         img.rotate(180).transpose(Image.FLIP_TOP_BOTTOM),  # 6
94 |         img.rotate(270).transpose(Image.FLIP_TOP_BOTTOM)  # 7
95 |     ]
96 |     return imgs
97 | 
98 | 
99 | def fun_ensemble_numpy(img):
100 |     imgs = [
101 |         img,  # 0
102 |         np.rot90(img, k=1, axes=(0, 1)),  # 1
103 |         np.rot90(img, k=2, axes=(0, 1)),  # 2
104 |         np.rot90(img, k=3, axes=(0, 1)),  # 3
105 |         np.flipud(img),  # 4
106 |         np.flipud(np.rot90(img, k=1, axes=(0, 1))),  # 5
107 |         np.flipud(np.rot90(img, k=2, axes=(0, 1))),  # 6
108 |         np.flipud(np.rot90(img, k=3, axes=(0, 1))),  # 7
109 |     ]
110 |     return imgs
111 | 
112 | 
113 | def fun_ensemble_back(imgs):
114 |     imgs = [  # invert each of the eight flips/rotations in (N, C, H, W) tensor space
115 |         imgs[0], imgs[1].transpose(2, 3).flip(3), imgs[2].flip(2).flip(3),
116 |         imgs[3].transpose(2, 3).flip(2), imgs[4].flip(2), imgs[5].transpose(
117 |             2, 3), imgs[6].flip(3), imgs[7].transpose(2, 3).flip(2).flip(3)
118 |     ]
119 | 
120 |     img = sum(imgs) / len(imgs)  # average the re-aligned predictions
121 | 
122 |     return img
123 | 
124 | 
125 | def validation(net,
126 |                val_data_loader,
127 |                device,
128 |                texture_net=None,
129 |                save_tag=False,
130 |                mode='student',
131 |                is_validation=False,
132 |                is_ensemble=False):
133 |     psnr_list = []
134 |     ssim_list = []
135 |     save_folder = os.path.join(
136 |         './results',
137 |         'result_' + datetime.now().strftime("%Y%m%d_%H%M%S") + '/')
138 |     net.eval()
139 |     if not os.path.exists(save_folder):
140 |         os.makedirs(save_folder)
141 | 
142 |     for batch_id, val_data in enumerate(val_data_loader):
143 |         with torch.no_grad():
144 |             x, target, image_name = val_data
145 |             target = target.to(device, non_blocking=True)
146 |             if mode == 'orig_size':
147 |                 if isinstance(x, list):
148 |                     x = [i.to(device) for i in x]
149 |                     y = [net(i)[1] for i in x]
150 |                     y = fun_ensemble_back(y)
151 |                 else:
152 |                     x = x.to(device, non_blocking=True)
153 |                     _, y, _ = net(x)
154 |                 psnr_list.extend(to_psnr(y, target))
155 |
ssim_list.extend(to_ssim_skimage(y, target)) 156 | elif mode == 'student' or mode == 'teacher': 157 | if texture_net: 158 | x = x.to(device, non_blocking=True) 159 | x = texture_net(x) 160 | y, _ = net(x) 161 | psnr_list.extend(to_psnr(y[0], target)) 162 | ssim_list.extend(to_ssim_skimage(y[0], target)) 163 | else: 164 | if isinstance(x, list): 165 | x = [i.to(device) for i in x] 166 | y = [net(i)[0][0] for i in x] 167 | y = fun_ensemble_back(y) 168 | psnr_list.extend(to_psnr(y, target)) 169 | ssim_list.extend(to_ssim_skimage(y, target)) 170 | else: 171 | x = x.to(device, non_blocking=True) 172 | y, _ = net(x) 173 | psnr_list.extend(to_psnr(y[0], target)) 174 | ssim_list.extend(to_ssim_skimage(y[0], target)) 175 | 176 | elif mode == 'texture': 177 | x = x.to(device, non_blocking=True) 178 | y = net(x) 179 | psnr_list.extend(to_psnr(y, target)) 180 | ssim_list.extend(to_ssim_skimage(y, target)) 181 | 182 | # Save image 183 | if save_tag: 184 | if mode == 'orig_size': 185 | save_image(target, y, image_name, save_folder) 186 | elif (mode == 'student' 187 | or mode == 'teacher') and is_validation == False: 188 | save_image(target, y[0], image_name, save_folder) 189 | elif (mode == 'student' 190 | or mode == 'teacher') and is_validation == True: 191 | if is_ensemble: 192 | save_validation_image(y, image_name, './validation') 193 | else: 194 | save_validation_image(y[0], image_name, './validation') 195 | elif mode == 'texture': 196 | save_image(target, y, image_name, save_folder) 197 | 198 | avr_psnr = sum(psnr_list) / len(psnr_list) 199 | avr_ssim = sum(ssim_list) / len(ssim_list) 200 | 201 | return avr_psnr, avr_ssim 202 | 203 | 204 | def to_psnr(dehaze, gt): 205 | mse = F.mse_loss(dehaze, gt, reduction='none') 206 | mse_split = torch.split(mse, 1, dim=0) 207 | mse_list = [ 208 | torch.mean(torch.squeeze(mse_split[ind])).item() 209 | for ind in range(len(mse_split)) 210 | ] 211 | 212 | # ToTensor scales input images to [0.0, 1.0] 213 | intensity_max = 1.0 214 | psnr_list = [10.0 * log10(intensity_max / mse) for mse in mse_list] 215 | return psnr_list 216 | 217 | 218 | def to_ssim_skimage(dehaze, gt): 219 | dehaze_list = torch.split(dehaze, 1, dim=0) 220 | gt_list = torch.split(gt, 1, dim=0) 221 | 222 | dehaze_list_np = [ 223 | dehaze_list[ind].permute(0, 2, 3, 1).data.cpu().numpy().squeeze() 224 | for ind in range(len(dehaze_list)) 225 | ] 226 | gt_list_np = [ 227 | gt_list[ind].permute(0, 2, 3, 1).data.cpu().numpy().squeeze() 228 | for ind in range(len(dehaze_list)) 229 | ] 230 | ssim_list = [ 231 | measure.compare_ssim( 232 | dehaze_list_np[ind], 233 | gt_list_np[ind], 234 | data_range=1, 235 | multichannel=True, 236 | sigma=1.5, 237 | gaussian_weights=True, 238 | use_sample_covariance=False) for ind in range(len(dehaze_list)) 239 | ] 240 | 241 | return ssim_list 242 | 243 | 244 | def print_log(epoch, num_epochs, one_epoch_time, train_psnr, val_psnr, 245 | val_ssim, category): 246 | print( 247 | '({0:.0f}s) Epoch [{1}/{2}], Train_PSNR:{3:.2f}, Val_PSNR:{4:.2f}, Val_SSIM:{5:.4f}' 248 | .format(one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, 249 | val_ssim)) 250 | # write training log 251 | with open('./training_log/{}_log.txt'.format(category), 'a') as f: 252 | print( 253 | 'Date: {0}s, Time_Cost: {1:.0f}s, Epoch: [{2}/{3}], Train_PSNR: {4:.2f}, Val_Image_PSNR: {5:.2f}, Val_Image_SSIM: {6:.4f}' 254 | .format( 255 | time.strftime("%Y-%m-%d %H:%M:%S", 256 | time.localtime()), one_epoch_time, epoch, 257 | num_epochs, train_psnr, val_psnr, val_ssim), 258 | file=f) 259 | 260 | 261 | def 
adjust_learning_rate(optimizer, scheduler, epoch, learning_rate, writer):
262 |     """Advance the scheduler every epoch and, every second epoch, copy its LR into the optimizer; log the LR to TensorBoard."""
263 | 
264 |     if epoch > 0:  # and epoch < 16:
265 |         if epoch % 2 == 0:
266 |             learning_rate = scheduler.get_lr()[0]
267 |             for param_group in optimizer.param_groups:
268 |                 param_group['lr'] = learning_rate
269 |                 print('Learning rate set to {}.'.format(param_group['lr']))
270 |         scheduler.step()
271 |     writer.add_scalars('lr/train_lr_group', {
272 |         'lr': learning_rate,
273 |     }, epoch)
274 |     return learning_rate
275 | 
276 | 
277 | def poly_learning_decay(optimizer, iter, total_epoch, loader_length, writer):
278 |     """Linearly decay the LR toward zero over the total number of training iterations (polynomial decay with power 1)."""
279 |     max_iteration = total_epoch * loader_length
280 |     learning_rate = optimizer.param_groups[0]['lr']
281 |     learning_rate = learning_rate * (1 - iter / max_iteration)
282 |     for param_group in optimizer.param_groups:
283 |         param_group['lr'] = learning_rate
284 | 
285 |     writer.add_scalars('lr/train_lr_group', {
286 |         'lr': learning_rate,
287 |     }, iter)
288 |     return learning_rate
289 | 
290 | 
291 | def set_requires_grad(nets, requires_grad=False):
292 |     if not isinstance(nets, list):
293 |         nets = [nets]
294 |     for net in nets:
295 |         if net is not None:
296 |             for param in net.parameters():
297 |                 param.requires_grad = requires_grad
298 | 
299 | 
300 | def adjust_learning_rate_step(optimizer, epoch, num_epochs, learning_rate):
301 |     """Piecewise-constant schedule: split training into len(learning_rate) equal stages and apply that stage's LR."""
302 |     step = num_epochs // len(learning_rate)
303 | 
304 |     for param_group in optimizer.param_groups:
305 |         param_group['lr'] = learning_rate[epoch // step]
306 |         print('Learning rate set to {}.'.format(param_group['lr']))
307 |     return learning_rate[epoch // step]
308 | 
--------------------------------------------------------------------------------
/validation_3channel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from torch.utils.data import DataLoader
7 | 
8 | from models.model_3channel import AWNet
9 | from config import trainConfig
10 | import numpy as np
11 | import imageio
12 | import PIL.Image as Image
13 | import time
14 | import os
15 | from utils import validation, fun_ensemble_back, save_validation_image, fun_ensemble, fun_ensemble_numpy
16 | 
17 | 
18 | ENSEMBLE = False
19 | 
20 | class wrapped_3_channel(nn.Module):  # mirrors DataParallel's 'module.' prefix so saved checkpoints load without a GPU
21 |     def __init__(self):
22 |         super().__init__()
23 |         self.module = AWNet(3, 3, block=[3, 3, 3, 4, 4])
24 |     def forward(self, x):
25 |         return self.module(x)
26 | 
27 | class LoadData_real(Dataset):
28 | 
29 |     def __init__(self, dataset_dir, is_ensemble=False):
30 |         self.is_ensemble = is_ensemble
31 | 
32 |         self.raw_dir = os.path.join(dataset_dir, 'AIM2020_ISP_fullres_test_raw_pseudo_demosaicing')
33 | 
34 |         self.dataset_size = 42
35 | 
36 |         self.toTensor = transforms.Compose([transforms.ToTensor()])
37 | 
38 |     def __len__(self):
39 |         return self.dataset_size
40 | 
41 |     def __getitem__(self, idx):
42 |         idx = idx + 1
43 |         raw_image = Image.open(os.path.join(self.raw_dir, str(idx) + ".png"))
44 | 
45 |         if self.is_ensemble:
46 |             raw_image = fun_ensemble(raw_image)
47 |             raw_image = [self.toTensor(x) for x in raw_image]
48 | 
49 |         else:
50 |             raw_image = self.toTensor(raw_image)
51 | 
52 |         return raw_image, str(idx)
53 | 
54 | def extract_bayer_channels(raw):
55 |     # Pack the RGGB Bayer mosaic into four half-resolution channels
56
| 57 | ch_B = raw[1::2, 1::2] 58 | ch_Gb = raw[0::2, 1::2] 59 | ch_R = raw[0::2, 0::2] 60 | ch_Gr = raw[1::2, 0::2] 61 | 62 | RAW_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr)) 63 | RAW_norm = RAW_combined.astype(np.float32) / (4 * 255) 64 | 65 | return RAW_norm 66 | 67 | def test(): 68 | device_ids = [0] 69 | print('using device: {}'.format(device_ids)) 70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 71 | net = wrapped_3_channel() 72 | # Reload 73 | 74 | net.load_state_dict(torch.load('{}/weight_3channel_best.pkl'.format(trainConfig.save_best), map_location='cpu')["model_state"]) 75 | print('weight loaded.') 76 | 77 | test_dataset = LoadData_real(trainConfig.data_dir, is_ensemble=ENSEMBLE) 78 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0, 79 | pin_memory=False, drop_last=False) 80 | 81 | net.eval() 82 | save_folder = './result_fullres_3channel/' 83 | if not os.path.exists(save_folder): 84 | os.makedirs(save_folder) 85 | 86 | for batch_id, val_data in enumerate(test_loader): 87 | 88 | with torch.no_grad(): 89 | x, image_name = val_data 90 | if isinstance(x, list): 91 | y = [net(i)[0][0] for i in x] 92 | y = fun_ensemble_back(y) 93 | else: 94 | y, _ = net(x) 95 | if ENSEMBLE: 96 | save_validation_image(y, image_name, save_folder) 97 | else: 98 | save_validation_image(y[0], image_name, save_folder) 99 | 100 | if __name__ == '__main__': 101 | test() 102 | -------------------------------------------------------------------------------- /validation_4channel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.model_4channel import AWNet 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | 9 | from config import trainConfig 10 | import numpy as np 11 | import imageio 12 | import PIL.Image as Image 13 | import time 14 | import os 15 | from utils import validation, fun_ensemble_back, save_validation_image, fun_ensemble, fun_ensemble_numpy 16 | 17 | ENSEMBLE = False 18 | 19 | class wrapped_4_channel(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.module = AWNet(4, 3, block=[3, 3, 3, 4, 4]) 23 | 24 | def forward(self, x): 25 | return self.module(x) 26 | 27 | class LoadData_real(Dataset): 28 | 29 | def __init__(self, dataset_dir, is_ensemble=False): 30 | self.is_ensemble = is_ensemble 31 | 32 | self.raw_dir = os.path.join(dataset_dir, 'AIM2020_ISP_fullres_test_raw') 33 | 34 | self.dataset_size = 42 35 | 36 | self.toTensor = transforms.Compose([transforms.ToTensor()]) 37 | 38 | def __len__(self): 39 | return self.dataset_size 40 | 41 | def __getitem__(self, idx): 42 | idx = idx + 1 43 | raw_image = np.asarray(imageio.imread(os.path.join(self.raw_dir, str(idx) + '.png'))) 44 | raw_image = extract_bayer_channels(raw_image) 45 | 46 | if self.is_ensemble: 47 | 48 | raw_image = fun_ensemble_numpy(raw_image) 49 | raw_image = [torch.from_numpy(x.transpose((2, 0, 1)).copy()) for x in raw_image] 50 | 51 | else: 52 | raw_image = torch.from_numpy(raw_image.transpose((2, 0, 1)).copy()) 53 | 54 | return raw_image, str(idx) 55 | 56 | def extract_bayer_channels(raw): 57 | # Reshape the input bayer image 58 | 59 | ch_B = raw[1::2, 1::2] 60 | ch_Gb = raw[0::2, 1::2] 61 | ch_R = raw[0::2, 0::2] 62 | ch_Gr = raw[1::2, 0::2] 63 | 64 | RAW_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr)) 65 | RAW_norm = RAW_combined.astype(np.float32) / (4 * 255) 66 | 67 | 
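    # A 2H x 2W Bayer mosaic becomes an H x W x 4 stack (B, Gb, R, Gr); dividing by 4 * 255 = 1020
    # presumably assumes 10-bit sensor values (0-1023) and maps them to roughly [0, 1].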
return RAW_norm 68 | 69 | def test(): 70 | net1 = wrapped_4_channel() 71 | 72 | net1.load_state_dict( 73 | torch.load('{}/weight_4channel_best.pkl'.format(trainConfig.save_best), map_location="cpu")["model_state"]) 74 | print('weight loaded.') 75 | 76 | test_dataset = LoadData_real(trainConfig.data_dir, is_ensemble=ENSEMBLE) 77 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0, 78 | pin_memory=False, drop_last=False) 79 | 80 | net1.eval() 81 | save_folder = './result_fullres_4channel/' 82 | if not os.path.exists(save_folder): 83 | os.makedirs(save_folder) 84 | 85 | for batch_id, val_data in enumerate(test_loader): 86 | 87 | with torch.no_grad(): 88 | raw_image, image_name = val_data 89 | if isinstance(raw_image, list): 90 | print('ensemble') 91 | y1 = [net1(i)[0][0] for i in raw_image] 92 | y1 = fun_ensemble_back(y1) 93 | print(y1.shape) 94 | 95 | else: 96 | y1, _ = net1(raw_image) 97 | y = y1[0] 98 | if ENSEMBLE: 99 | save_validation_image(y, image_name, save_folder) 100 | else: 101 | save_validation_image(y, image_name, save_folder) 102 | 103 | 104 | if __name__ == '__main__': 105 | test() 106 | -------------------------------------------------------------------------------- /validation_final.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.model_4channel import AWNet as gen1 4 | from models.model_3channel import AWNet as gen2 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | from config import trainConfig 9 | import numpy as np 10 | import imageio 11 | import PIL.Image as Image 12 | import time 13 | import os 14 | from utils import fun_ensemble_back, save_validation_image, fun_ensemble, fun_ensemble_numpy 15 | 16 | ENSEMBLE = True 17 | 18 | 19 | def extract_bayer_channels(raw): 20 | # Reshape the input bayer image 21 | 22 | ch_B = raw[1::2, 1::2] 23 | ch_Gb = raw[0::2, 1::2] 24 | ch_R = raw[0::2, 0::2] 25 | ch_Gr = raw[1::2, 0::2] 26 | 27 | RAW_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr)) 28 | RAW_norm = RAW_combined.astype(np.float32) / (4 * 255) 29 | 30 | return RAW_norm 31 | 32 | 33 | class LoadData_real(Dataset): 34 | def __init__(self, dataset_dir, is_ensemble=False): 35 | self.is_ensemble = is_ensemble 36 | self.raw_dir1 = os.path.join(dataset_dir, 'AIM2020_ISP_test_raw') 37 | self.raw_dir2 = os.path.join(dataset_dir, 'AIM2020_ISP_test_pseudo_demosaicing') 38 | self.dataset_size = 1342 39 | 40 | self.toTensor = transforms.Compose([transforms.ToTensor()]) 41 | 42 | def __len__(self): 43 | return self.dataset_size 44 | 45 | def __getitem__(self, idx): 46 | raw_image1 = np.asarray( 47 | imageio.imread(os.path.join(self.raw_dir1, 48 | str(idx) + '.png'))) 49 | raw_image1 = extract_bayer_channels(raw_image1) 50 | 51 | raw_image2 = Image.open(os.path.join(self.raw_dir2, str(idx) + ".png")) 52 | 53 | if self.is_ensemble: 54 | 55 | raw_image1 = fun_ensemble_numpy(raw_image1) 56 | raw_image1 = [ 57 | torch.from_numpy(x.transpose((2, 0, 1)).copy()) 58 | for x in raw_image1 59 | ] 60 | 61 | raw_image2 = fun_ensemble(raw_image2) 62 | raw_image2 = [self.toTensor(x) for x in raw_image2] 63 | 64 | else: 65 | raw_image1 = torch.from_numpy( 66 | raw_image1.transpose((2, 0, 1)).copy()) 67 | raw_image2 = self.toTensor(raw_image2) 68 | 69 | return raw_image1, raw_image2, str(idx) 70 | 71 | 72 | def test(): 73 | device_ids = [0] 74 | print('using device: {}'.format(device_ids)) 75 | 
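    # Two generators are ensembled: net1 takes the packed 4-channel Bayer input, net2 the 3-channel
    # pseudo-demosaiced input; their outputs are averaged below.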
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 76 | 77 | net1 = gen1(4, 3, block=[3, 3, 3, 4, 4]).to(device) 78 | net2 = gen2(3, 3, block=[3, 3, 3, 4, 4]).to(device) 79 | net1 = nn.DataParallel(net1, device_ids=device_ids) 80 | net2 = nn.DataParallel(net2, device_ids=device_ids) 81 | 82 | # Reload 83 | net1.load_state_dict( 84 | torch.load( 85 | '{}/weight_4channel_best.pkl'.format(trainConfig.save_best), 86 | map_location='cuda:0')["model_state"]) 87 | net2.load_state_dict( 88 | torch.load( 89 | '{}/weight_3channel_best.pkl'.format(trainConfig.save_best), 90 | map_location='cuda:0')["model_state"]) 91 | print('weight loaded.') 92 | 93 | test_dataset = LoadData_real(trainConfig.data_dir, is_ensemble=ENSEMBLE) 94 | test_loader = DataLoader( 95 | dataset=test_dataset, 96 | batch_size=4, 97 | shuffle=False, 98 | num_workers=12, 99 | pin_memory=True, 100 | drop_last=False) 101 | 102 | save_folder = './final_result/' 103 | if not os.path.exists(save_folder): 104 | os.makedirs(save_folder) 105 | 106 | for batch_id, val_data in enumerate(test_loader): 107 | 108 | with torch.no_grad(): 109 | raw_image1, raw_image2, image_name = val_data 110 | if isinstance(raw_image1, list): 111 | print('ensemble') 112 | raw_image1 = [i.to(device) for i in raw_image1] 113 | y1 = [net1(i)[0][0] for i in raw_image1] 114 | y1 = fun_ensemble_back(y1) 115 | 116 | raw_image2 = [i.to(device) for i in raw_image2] 117 | y2 = [net2(i)[0][0] for i in raw_image2] 118 | y2 = fun_ensemble_back(y2) 119 | y = (y1 + y2) / 2 120 | else: 121 | raw_image1 = raw_image1.to(device, non_blocking=True) 122 | raw_image2 = raw_image2.to(device, non_blocking=True) 123 | y1, _ = net1(raw_image1) 124 | y2, _ = net2(raw_image2) 125 | y = torch.zeros_like(y1[0]) 126 | y = (y1[0] + y2[0]) / 2 127 | if ENSEMBLE: 128 | save_validation_image(y, image_name, save_folder) 129 | else: 130 | save_validation_image(y, image_name, save_folder) 131 | 132 | 133 | if __name__ == '__main__': 134 | test() 135 | -------------------------------------------------------------------------------- /validation_final_fullres.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.model_4channel import AWNet as gen1 4 | from models.model_3channel import AWNet as gen2 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | from config import trainConfig 9 | import numpy as np 10 | import imageio 11 | import PIL.Image as Image 12 | import time 13 | import os 14 | from utils import fun_ensemble_back, save_validation_image, fun_ensemble, to_psnr, to_ssim_skimage, fun_ensemble_numpy 15 | 16 | ENSEMBLE = False 17 | 18 | 19 | class wrapped_3_channel(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.module = gen2(3, 3, block=[3, 3, 3, 4, 4]) 23 | 24 | def forward(self, x): 25 | return self.module(x) 26 | 27 | 28 | class wrapped_4_channel(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | self.module = gen1(4, 3, block=[3, 3, 3, 4, 4]) 32 | 33 | def forward(self, x): 34 | return self.module(x) 35 | 36 | 37 | def extract_bayer_channels(raw): 38 | # Reshape the input bayer image 39 | 40 | ch_B = raw[1::2, 1::2] 41 | ch_Gb = raw[0::2, 1::2] 42 | ch_R = raw[0::2, 0::2] 43 | ch_Gr = raw[1::2, 0::2] 44 | 45 | RAW_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr)) 46 | RAW_norm = RAW_combined.astype(np.float32) / (4 * 255) 47 | 48 | return RAW_norm 49 | 50 | 51 | class 
LoadData_real(Dataset): 52 | def __init__(self, dataset_dir, is_ensemble=False): 53 | self.is_ensemble = is_ensemble 54 | 55 | 56 | self.raw_dir1 = os.path.join(dataset_dir, 57 | 'AIM2020_ISP_fullres_test_raw') 58 | self.raw_dir2 = os.path.join(dataset_dir, 59 | 'AIM2020_ISP_fullres_test_raw_pseudo_demosaicing') 60 | self.dataset_size = 42 61 | 62 | self.toTensor = transforms.Compose([transforms.ToTensor()]) 63 | 64 | def __len__(self): 65 | return self.dataset_size 66 | 67 | def __getitem__(self, idx): 68 | idx = idx + 1 69 | raw_image1 = np.asarray( 70 | imageio.imread(os.path.join(self.raw_dir1, 71 | str(idx) + '.png'))) 72 | raw_image1 = extract_bayer_channels(raw_image1) 73 | 74 | raw_image2 = Image.open(os.path.join(self.raw_dir2, str(idx) + ".png")) 75 | 76 | if self.is_ensemble: 77 | 78 | raw_image1 = fun_ensemble_numpy(raw_image1) 79 | raw_image1 = [ 80 | torch.from_numpy(x.transpose((2, 0, 1)).copy()) 81 | for x in raw_image1 82 | ] 83 | 84 | raw_image2 = fun_ensemble(raw_image2) 85 | raw_image2 = [self.toTensor(x) for x in raw_image2] 86 | 87 | else: 88 | raw_image1 = torch.from_numpy( 89 | raw_image1.transpose((2, 0, 1)).copy()) 90 | raw_image2 = self.toTensor(raw_image2) 91 | 92 | return raw_image1, raw_image2, str(idx) 93 | 94 | 95 | def test(): 96 | net1 = wrapped_4_channel() 97 | net2 = wrapped_3_channel() 98 | 99 | # Reload 100 | net1.load_state_dict( 101 | torch.load( 102 | '{}/weight_4channel_best.pkl'.format(trainConfig.save_best), 103 | map_location="cpu")["model_state"]) 104 | net2.load_state_dict( 105 | torch.load( 106 | '{}/weight_3channel_best.pkl'.format(trainConfig.save_best), 107 | map_location="cpu")["model_state"]) 108 | print('weight loaded.') 109 | 110 | test_dataset = LoadData_real(trainConfig.data_dir, is_ensemble=ENSEMBLE) 111 | test_loader = DataLoader( 112 | dataset=test_dataset, 113 | batch_size=1, 114 | shuffle=False, 115 | num_workers=0, 116 | pin_memory=False, 117 | drop_last=False) 118 | 119 | net1.eval() 120 | net2.eval() 121 | save_folder1 = './final_result_fullres/' 122 | if not os.path.exists(save_folder1): 123 | os.makedirs(save_folder1) 124 | 125 | for batch_id, val_data in enumerate(test_loader): 126 | 127 | with torch.no_grad(): 128 | raw_image1, raw_image2, image_name = val_data 129 | if isinstance(raw_image1, list): 130 | print('ensemble') 131 | y1 = [net1(i)[0][0] for i in raw_image1] 132 | y1 = fun_ensemble_back(y1) 133 | 134 | y2 = [net2(i)[0][0] for i in raw_image2] 135 | y2 = fun_ensemble_back(y2) 136 | y = (y1 + y2) / 2 137 | else: 138 | y1, _ = net1(raw_image1) 139 | y2, _ = net2(raw_image2) 140 | y = torch.zeros_like(y1[0]) 141 | y = (y1[0] + y2[0]) / 2 142 | if ENSEMBLE: 143 | save_validation_image(y, image_name, save_folder1) 144 | else: 145 | save_validation_image(y, image_name, save_folder1) 146 | 147 | 148 | if __name__ == '__main__': 149 | test() 150 | --------------------------------------------------------------------------------