├── Figures
│   ├── README.md
│   ├── 3RD.PNG
│   ├── HMA.PNG
│   ├── Figure_1.png
│   ├── Figure_2.png
│   ├── Figure_3.png
│   ├── Figure_4.png
│   ├── Figure_5.png
│   ├── Figure_6.png
│   ├── Final_Results.PNG
│   └── Burst_Results_List.PNG
├── Single
│   ├── README.md
│   ├── utils.py
│   ├── test.py
│   ├── train.py
│   ├── ssim.py
│   ├── dataset.py
│   └── models.py
├── NTIRE2020_Demoireing_Challenge_Factsheet__C3Net_.pdf
├── Burst
│   ├── README.md
│   ├── utils.py
│   ├── train.py
│   ├── test.py
│   ├── ssim.py
│   ├── models.py
│   └── dataset.py
└── README.md
/Figures/README.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/Figures/3RD.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/3RD.PNG
--------------------------------------------------------------------------------
/Figures/HMA.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/HMA.PNG
--------------------------------------------------------------------------------
/Figures/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_1.png
--------------------------------------------------------------------------------
/Figures/Figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_2.png
--------------------------------------------------------------------------------
/Figures/Figure_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_3.png
--------------------------------------------------------------------------------
/Figures/Figure_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_4.png
--------------------------------------------------------------------------------
/Figures/Figure_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_5.png
--------------------------------------------------------------------------------
/Figures/Figure_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_6.png
--------------------------------------------------------------------------------
/Figures/Final_Results.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Final_Results.PNG
--------------------------------------------------------------------------------
/Figures/Burst_Results_List.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Burst_Results_List.PNG
--------------------------------------------------------------------------------
/Single/README.md:
--------------------------------------------------------------------------------
# Track 1: Single Image, C3Net
[Reference](https://competitions.codalab.org/competitions/22223#learn_the_details)
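
A typical way to run this track (the option values shown are simply the defaults defined in train.py and test.py):
~~~
python train.py --batchSize 1 --patch 128
python test.py --logdir . --inputdir DemoireingTestInputSingle --gpu 0
~~~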
--------------------------------------------------------------------------------
/NTIRE2020_Demoireing_Challenge_Factsheet__C3Net_.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/NTIRE2020_Demoireing_Challenge_Factsheet__C3Net_.pdf
--------------------------------------------------------------------------------
/Burst/README.md:
--------------------------------------------------------------------------------
# Track 2: Burst, C3Net-Burst
For Track 2: Burst, we made the following changes to C3Net from Track 1: Single Image:

1. Pre-processed the input images to fill their chroma-key padding (see the sketch below)

2. Adjusted the number of channels

3. Used global max pooling
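
The chroma-key pre-processing in step 1 mirrors what Burst/test.py does before inference: pixels of a side frame that are darker than a fixed threshold are treated as padding and are filled in from the co-located pixels of the center frame. A minimal sketch of that compositing step (the helper name is ours; the threshold of 50 matches test.py):
~~~
import cv2

def fill_padding_from_center(side, center, thresh=50):
    # masks of valid (bright) and padded (dark) pixels in the side frame
    _, keep = cv2.threshold(side, thresh, 255, cv2.THRESH_BINARY)
    _, fill = cv2.threshold(side, thresh, 255, cv2.THRESH_BINARY_INV)
    valid = cv2.bitwise_and(side, keep)    # side frame where it is valid
    patch = cv2.bitwise_and(center, fill)  # center frame where side is padded
    return cv2.bitwise_or(valid, patch)
~~~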
--------------------------------------------------------------------------------
/Burst/utils.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import numpy as np
# compare_psnr moved to skimage.metrics.peak_signal_noise_ratio in scikit-image 0.16
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025, 0.025)
        nn.init.constant_(m.bias.data, 0.0)

def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range)
    return (PSNR/Img.shape[0])

def data_augmentation(image, mode):
    out = np.transpose(image, (1, 2, 0))
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(out)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(out)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(out, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(out, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(out, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(out, k=3)
        out = np.flipud(out)
    return np.transpose(out, (2, 0, 1))
--------------------------------------------------------------------------------
/Single/utils.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import numpy as np
# compare_psnr moved to skimage.metrics.peak_signal_noise_ratio in scikit-image 0.16
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025, 0.025)
        nn.init.constant_(m.bias.data, 0.0)

def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range)
    return (PSNR/Img.shape[0])

def data_augmentation(image, mode):
    out = np.transpose(image, (1, 2, 0))
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(out)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(out)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(out, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(out, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(out, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(out, k=3)
        out = np.flipud(out)
    return np.transpose(out, (2, 0, 1))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# C3Net
This is a PyTorch implementation of the [New Trends in Image Restoration and Enhancement workshop and challenges on image and video restoration and enhancement (NTIRE 2020 with CVPR 2020)](https://data.vision.ee.ethz.ch/cvl/ntire20/) paper, [C3Net: Demoireing Network Attentive in Channel, Color and Concatenation](http://openaccess.thecvf.com/content_CVPRW_2020/html/w31/Kim_C3Net_Demoireing_Network_Attentive_in_Channel_Color_and_Concatenation_CVPRW_2020_paper.html).

If you find our project useful in your research, please consider citing:
~~~
@InProceedings{Kim_2020_CVPR_Workshops,
author = {Kim, Sangmin and Nam, Hyungjoon and Kim, Jisu and Jeong, Jechang},
title = {C3Net: Demoireing Network Attentive in Channel, Color and Concatenation},
booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2020}
}
~~~

# Dependencies
Python 3.6.9
PyTorch 1.4.0

# Data
[Reference](https://competitions.codalab.org/competitions/22223#participate-get_data)

You have to sign in to CodaLab and register for the **NTIRE 2020 Demoireing Challenge** before you can download the data.

# Proposed algorithm
![C3Net (Track 1: Single Image)](Figures/Figure_1.png)
![AVC_Block](Figures/Figure_2.png)
![AttBlock](Figures/Figure_3.png)
![ResBlock](Figures/Figure_4.png)
![C3Net-Burst (Track 2: Burst)](Figures/Figure_5.png)
![AVC_Block-Burst](Figures/Figure_6.png)

# Training
Use the following command to run our training code:
~~~
python train.py
~~~
There are other options you can set.
Please refer to train.py.

# Test
Use the following command to run our test code:
~~~
python test.py
~~~
There are other options you can set.
Please refer to test.py.

# Performance (PSNR/SSIM)
To fit a heavier model in memory, we read the input data with numpy instead of hdf5.
[Hyung-Joon](https://github.com/Hyung-Joon) and [jisukim](https://github.com/jisus189) helped with this.
**Our best records can be reproduced with [the code](https://github.com/Hyung-Joon/Demoire-Burst-single-master)**, which swaps h5 for numpy and reduces GPU memory usage.
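
As a minimal sketch of the idea (the class name and the one-patch-per-.npy layout are illustrative here, not the exact code linked above): save each training patch as its own .npy file once, then load patches lazily, so no HDF5 archive has to be reopened for every item.
~~~
import glob
import numpy as np
import torch
import torch.utils.data as udata

class DatasetNumpy(udata.Dataset):
    def __init__(self, input_dir, gt_dir):
        # one .npy file per patch, written beforehand with np.save
        self.inputs = sorted(glob.glob(input_dir + '/*.npy'))
        self.gts = sorted(glob.glob(gt_dir + '/*.npy'))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        # each DataLoader worker reads only the single patch it needs
        data = np.load(self.inputs[index])
        gt = np.load(self.gts[index])
        return torch.Tensor(data), torch.Tensor(gt)
~~~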

|Validation Server |PSNR |SSIM |Rank |
|:-----------------------------------------------------------------------------------|:-------|:-------|:-------|
|[Track 1: Single Image](https://competitions.codalab.org/competitions/22223#results)|41.30 |0.99 |9th |
|[Track 2: Burst](https://competitions.codalab.org/competitions/22224#results) |40.55 |0.99 |5th |

![Burst_Results_List](Figures/Burst_Results_List.PNG)

[Testing Server Reference](https://arxiv.org/pdf/2005.03155.pdf)
|Testing Server |PSNR |SSIM |Rank |
|:--------------------|:-------|:-------|:------|
|Track 1: Single Image|41.11 |0.99 |4th |
|Track 2: Burst |40.33 |0.99 |5th |

![Final_Results](Figures/Final_Results.PNG)

![Honorable_Mention_Award](Figures/HMA.PNG)

# Contact
If you have any questions about the **Demoireing** model and the CVPR 2020 challenge paper, feel free to ask me.
If you have any questions about the **Deblurring** model, visit [here](https://github.com/Hyung-Joon/Deblur-mobile-RCAN-Master) and feel free to ask Hyung-Joon at <013107nam@gmail.com>.
If you have any questions about using the **heavier C3Net**, visit [here](https://github.com/Hyung-Joon/Demoire-Burst-single-master) and feel free to ask jisukim.

# Acknowledgement
Thanks to [SaoYan](https://github.com/SaoYan/DnCNN-PyTorch), who provided the implementation of DnCNN.
Thanks to [yun_yang](https://github.com/jt827859032/DRRN-pytorch), who provided the implementation of DRRN.
Thanks to [BumjunPark](https://github.com/BumjunPark/DHDN), who provided the implementation of DHDN.

The hint for the color loss came from [Jorge Pessoa](https://github.com/jorge-pessoa/pytorch-colors).
The hint for concatenation and residual learning came from [RDN (informal implementation)](https://github.com/lingtengqiu/RDN-pytorch).
The hint for the U-net block came from [DIDN (formal implementation)](https://github.com/SonghyunYu/DIDN).

C3Net started from [RUN](https://github.com/bmycheez/RUN).

# More Details
Also, we won 3rd Place in the [**NTIRE 2020 Challenge on Image and Video Deblurring**](https://arxiv.org/pdf/2005.01244.pdf) thanks to [Hyung-Joon](https://github.com/Hyung-Joon) and [jisukim](https://github.com/jisus189).
The code is available [here](https://github.com/Hyung-Joon/Deblur-mobile-RCAN-Master).

![3rd_Place](Figures/3RD.PNG)
--------------------------------------------------------------------------------
/Single/test.py:
--------------------------------------------------------------------------------
import cv2
import os
import argparse
import glob
import time
from torch.autograd import Variable
from models import Net  # the model definition in this repo lives in models.py
from utils import *


parser = argparse.ArgumentParser(description="DnCNN_Test")
parser.add_argument("--num", type=int, default=3, help="index suffix of the input image filenames (*_<num>.png)")
parser.add_argument("--logdir", type=str, default=".", help='path of log files')
parser.add_argument("--gpu", type=str, default='0', help='id of the GPU to run on')
parser.add_argument("--inputdir", type=str, default='DemoireingTestInputSingle', help='directory of the test input images')
opt = parser.parse_args()


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu


def normalize(data):
    return data/255.


def self_ensemble(out, mode, forward):
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        if forward == 1:
            out = np.rot90(out)
        else:
            out = np.rot90(out, k=3)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        if forward == 1:
            out = np.rot90(out)
            out = np.flipud(out)
        else:
            out = np.flipud(out)
            out = np.rot90(out, k=3)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(out, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        if forward == 1:
            out = np.rot90(out, k=2)
            out = np.flipud(out)
        else:
            out = np.flipud(out)
            out = np.rot90(out, k=2)
    elif mode == 6:
        # rotate 270 degrees
        if forward == 1:
            out = np.rot90(out, k=3)
        else:
            out = np.rot90(out)
    elif mode == 7:
        # rotate 270 degrees and flip
        if forward == 1:
            out = np.rot90(out, k=3)
            out = np.flipud(out)
        else:
            out = np.flipud(out)
            out = np.rot90(out)
    return out


def self_ensemble_v2(out, mode, forward):
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # flip left and right
        out = np.fliplr(out)
    elif mode == 3:
        # flip both ways
        out = np.flipud(out)
        out = np.fliplr(out)
    return out


# Note: every self_ensemble_v2 mode is an involution (applying it twice restores
# the original array), so main() can call the same function, with the unused
# `forward` flag, to undo a transform before averaging the ensemble outputs:
#
#     x = np.arange(12).reshape(3, 4)
#     assert (self_ensemble_v2(self_ensemble_v2(x, 3, 1), 3, 0) == x).all()


def main():
    # Build model
    print('Loading model ...\n')
    model = Net().cuda()
    # device_ids = [0]
    # model = nn.DataParallel(net, device_ids=device_ids).cuda()
    a = torch.load(glob.glob(os.path.join(opt.logdir, '*.pth'))[0])
    print(glob.glob(os.path.join(opt.logdir, '*.pth'))[0])
    ok = input("Right model? ")
    if ok == 'n':
        return
    model.load_state_dict(a)
    DHDN_flag = 4
    ensemble_flag = 4
    model.eval()
    # load data info
    print('Loading data info ...\n')
    files_source = glob.glob(os.path.join('D:/', opt.inputdir, '*_%d.png'
                                          % opt.num))
    files_source.sort()
    # process data
    runtime = 0
    c = 0
    for f in files_source:
        # image
        start = time.time()
        final = np.zeros(cv2.imread(f).shape)
        for mode in range(ensemble_flag):
            Img = cv2.imread(f)
            hh, ww, cc = Img.shape
            Img = self_ensemble_v2(Img, mode, 1)
            Img = np.swapaxes(Img, 0, 2)
            Img = np.swapaxes(Img, 1, 2)
            Img = np.float32(normalize(Img))
            a = Img.shape[1]
            b = Img.shape[2]
            if a % DHDN_flag != 0 or b % DHDN_flag != 0:
                # pad the spatial size up to a multiple of DHDN_flag for the U-net-style blocks
                h = DHDN_flag - (a % DHDN_flag)
                w = DHDN_flag - (b % DHDN_flag)
                Img = np.pad(Img, [(0, 0), (h//2, h-h//2), (w//2, w-w//2)], mode='edge')
            Img = np.expand_dims(Img, 0)
            ISource = torch.Tensor(Img)
            INoisy = Variable(ISource.cuda())
            with torch.no_grad():  # this can save much memory
                Out = torch.clamp(model(INoisy), 0., 1.)
            if a % DHDN_flag != 0 or b % DHDN_flag != 0:
                h = DHDN_flag - (a % DHDN_flag)
                w = DHDN_flag - (b % DHDN_flag)
                # crop the padding back off; a and b are the unpadded height and width
                Out = Out[:, :, h//2:h//2+a, w//2:w//2+b]
            name = str(c)
            if len(str(c)) != 6:
                # zero-pad the output index to six digits
                for i in range(6 - len(str(c))):
                    name = '0' + name
            out = Out.squeeze(0).permute(1, 2, 0) * 255
            out = out.cpu().numpy()
            out = self_ensemble_v2(out, mode, 0)
            final += out
        cv2.imwrite(name + "_gt.png", final/ensemble_flag)
        mytime = time.time() - start
        runtime += mytime
        print("%s" % f)
        c += 1
    runtime /= len(files_source)
    print("\nRuntime on test data %.2f" % runtime)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/Single/train.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchsummary import *
from dataset import prepare_data, Dataset
from utils import *
from datetime import datetime
from ssim import *
from models import *

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

parser = argparse.ArgumentParser(description="DnCNN")
# note: argparse treats any non-empty string as True for type=bool options
parser.add_argument("--preprocess", type=bool, default=False, help='run prepare_data or not')
parser.add_argument("--batchSize", type=int, default=1, help="Training batch size")
parser.add_argument("--patch", type=int, default=128, help="Training patch size")
parser.add_argument("--epochs", type=int, default=300, help="Number of training epochs")
parser.add_argument("--start_epochs", type=int, default=27, help="Epoch to resume training from")
parser.add_argument("--start_iters", type=int, default=5998, help="Iteration to resume training from")
parser.add_argument("--resume", type=str, default="net_38.0006.pth", help="Checkpoint to resume training from")
parser.add_argument("--step", type=int, default=30, help="When to decay learning rate; should be less than epochs")
parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
parser.add_argument("--decay", type=int, default=10, help="Learning rate decay factor")
parser.add_argument("--outf", type=str, default="./checkpoint", help='path of log files')
parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)')
opt = parser.parse_args()


def main():
    # Load dataset
    print('Loading dataset ...\n')
    dataset_train = Dataset(train=True)
    dataset_val = Dataset(train=False)
    loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=True)
    loader_val = DataLoader(dataset=dataset_val, num_workers=4, batch_size=1, shuffle=False)
    # print(opt.batchSize)
    print("# of training samples: %d\n" % int(len(dataset_train)))
    # Build model
    # net = DnCNN(channels=1, num_of_layers=opt.num_of_layers)
    model = Net().cuda()
    # s = MSSSIM()
    criterion = nn.L1Loss().cuda()
    # vgg = Vgg16(requires_grad=False).cuda()
    # vgg = VGG('54').cuda()
    # Move to GPU
    # model = nn.DataParallel(net, device_ids=device_ids).cuda()
    # '''
    if opt.resume:
        model.load_state_dict(torch.load(opt.resume))
    # '''
    summary(model, (3, 128, 128))
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    for epoch in range(opt.start_epochs, opt.epochs):
        current_lr = opt.lr * ((1 / opt.decay) ** ((epoch - opt.start_epochs) // opt.step))
        # set learning rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('learning rate %f' % current_lr)
        # train
        for i, (imgn_train, img_train) in enumerate(loader_train, 0):
            if i < opt.start_iters and epoch == opt.start_epochs:
                continue
            # training step
            model.train()
            model.zero_grad()
            optimizer.zero_grad()
            img_train, imgn_train = Variable(img_train.cuda()), Variable(imgn_train.cuda())
            out_train = model(imgn_train)
            # feat_x = vgg(imgn_train)
            # feat_y = vgg(out_train)
            # perceptual_loss = criterion(feat_y.relu2_2, feat_x.relu2_2)
            # perceptual_loss = vgg(out_train, img_train)
            loss = color_loss(out_train, img_train) + criterion(out_train, img_train)
            # + 1e-4 * ((1 - s(out_train, img_train)) / 2.)
            loss /= 2
            loss.backward()
            optimizer.step()
            # '''
            # if you are using older version of PyTorch, you may need to change loss.item() to loss.data[0]
            if i % int(len(loader_train)//5) == 0:
                # periodic validation (about five times per epoch)
                model.eval()
                # validate
                psnr_val = 0
                for _, (imgn_val, img_val) in enumerate(loader_val, 0):
                    with torch.no_grad():
                        img_val, imgn_val = Variable(img_val.cuda()), Variable(imgn_val.cuda())
                        out_val = torch.clamp(model(imgn_val), 0., 1.)
                        psnr_val += batch_PSNR(out_val, img_val, 1.)
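                        # batch_PSNR (utils.py) computes per-image PSNR with
                        # data_range=1. -- valid because both tensors lie in [0, 1]
                        # after the clamp -- and averages it over the batch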
94 | psnr_val /= len(dataset_val) 95 | now = datetime.now() 96 | print("[epoch %d][%d/%d] loss: %.6f PSNR_val: %.4f" % 97 | (epoch+1, i+1, len(loader_train), loss.item(), psnr_val), end='') 98 | print(' ', now.year, now.month, now.day, now.hour, now.minute, now.second) 99 | if psnr_val > 38: 100 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth')) 101 | # ''' 102 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth')) 103 | 104 | 105 | if __name__ == "__main__": 106 | if opt.preprocess: 107 | if opt.mode == 'S': 108 | prepare_data(data_path='data', patch_size=opt.patch, stride=opt.patch, aug_times=1) 109 | if opt.mode == 'B': 110 | prepare_data(data_path='data', patch_size=50, stride=10, aug_times=2) 111 | main() 112 | -------------------------------------------------------------------------------- /Burst/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.optim as optim 3 | from torch.autograd import Variable 4 | from torch.utils.data import DataLoader 5 | from models import * 6 | from dataset import * 7 | from utils import * 8 | from datetime import datetime 9 | from ssim import * 10 | 11 | 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 14 | 15 | parser = argparse.ArgumentParser(description="DnCNN") 16 | parser.add_argument("--preprocess", type=bool, default=False, help='run prepare_data or not') 17 | parser.add_argument("--batchSize", type=int, default=1, help="Training batch size") 18 | parser.add_argument("--patch", type=int, default=128, help="Number of total layers") 19 | parser.add_argument("--epochs", type=int, default=300, help="Number of training epochs") 20 | parser.add_argument("--start_epochs", type=int, default=60, help="Number of training epochs") 21 | parser.add_argument("--start_iters", type=int, default=0, help="Number of training epochs") 22 | parser.add_argument("--resume", type=str, default="/home/user/depthMap/ksm/CVPR/demoire/logs/48_40.50146.pth", 23 | help="Number of training epochs") 24 | parser.add_argument("--step", type=int, default=30, help="When to decay learning rate; should be less than epochs") 25 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate") 26 | parser.add_argument("--decay", type=int, default=10, help="Initial learning rate") 27 | parser.add_argument("--outf", type=str, default="/home/user/depthMap/ksm/CVPR/demoire/checkpoint", 28 | help='path of log files') 29 | parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)') 30 | opt = parser.parse_args() 31 | 32 | 33 | def main(): 34 | # Load dataset 35 | print('Loading dataset ...\n') 36 | dataset_train = DatasetBurst(train=True) 37 | dataset_val = DatasetBurst(train=False) 38 | loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=False) 39 | loader_val = DataLoader(dataset=dataset_val, num_workers=4, batch_size=1, shuffle=False) 40 | # print(opt.batchSize) 41 | print("# of training samples: %d\n" % int(len(dataset_train))) 42 | # Build model 43 | # net = DnCNN(channels=1, num_of_layers=opt.num_of_layers) 44 | model = Net().cuda() 45 | # s = MSSSIM() 46 | criterion = nn.L1Loss().cuda() 47 | burst = BurstLoss().cuda() 48 | # vgg = Vgg16(requires_grad=False).cuda() 49 | # vgg = VGG('54').cuda() 50 | # Move to GPU 51 | # model = nn.DataParallel(net, 
device_ids=device_ids).cuda() 52 | # ''' 53 | if opt.resume: 54 | model.load_state_dict(torch.load(opt.resume)) 55 | # test.main(model) 56 | # return 57 | # ''' 58 | # summary(model, (3, 128, 128)) 59 | # Optimizer 60 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 61 | psnr_max = 0 62 | loss_min = 1 63 | for epoch in range(opt.start_epochs, opt.epochs): 64 | # current_lr = opt.lr * ((1 / opt.decay) ** ((epoch - opt.start_epochs) // opt.step)) 65 | current_lr = opt.lr * ((1 / opt.decay) ** (epoch // opt.step)) 66 | # set learning rate 67 | for param_group in optimizer.param_groups: 68 | param_group["lr"] = current_lr 69 | print('learning rate %f' % current_lr) 70 | # train 71 | for i, (imgn_train, img_train) in enumerate(loader_train, 0): 72 | if i < opt.start_iters: 73 | continue 74 | # training step 75 | model.train() 76 | model.zero_grad() 77 | optimizer.zero_grad() 78 | img_train, imgn_train = Variable(img_train.cuda()), Variable(imgn_train.cuda()) 79 | out_train = model(imgn_train) 80 | # feat_x = vgg(imgn_train) 81 | # feat_y = vgg(out_train) 82 | # perceptual_loss = criterion(feat_y.relu2_2, feat_x.relu2_2) 83 | # perceptual_loss = vgg(out_train, img_train) 84 | loss_color = color_loss(out_train, img_train) 85 | loss_content = criterion(out_train, img_train) 86 | loss_burst = burst(out_train, img_train) 87 | m = [5, 5, 0] 88 | loss = torch.div(m[0] * loss_color.cuda() + m[1] * loss_content.cuda() + m[2] * loss_burst.cuda(), 10) 89 | loss.backward() 90 | optimizer.step() 91 | # ''' 92 | # if you are using older version of PyTorch, you may need to change loss.item() to loss.data[0] 93 | if i % int(len(loader_train)//5) == 0: 94 | # the end of each epoch 95 | model.eval() 96 | # validate 97 | psnr_val = 0 98 | for _, (imgn_val, img_val) in enumerate(loader_val, 0): 99 | with torch.no_grad(): 100 | img_val, imgn_val = Variable(img_val.cuda()), Variable(imgn_val.cuda()) 101 | out_val = torch.clamp(model(imgn_val), 0., 1.) 102 | psnr_val += batch_PSNR(out_val, img_val, 1.) 
103 | psnr_val /= len(dataset_val) 104 | now = datetime.now() 105 | print("[epoch %d][%d/%d] loss: %.6f PSNR_val: %.4f" % 106 | (epoch+1, i+1, len(loader_train), loss.item(), psnr_val), end=' ') 107 | print(now.year, now.month, now.day, now.hour, now.minute, now.second) 108 | if psnr_val > psnr_max or loss < loss_min: 109 | psnr_max = psnr_val 110 | loss_min = loss 111 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth')) 112 | # ''' 113 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth')) 114 | 115 | 116 | if __name__ == "__main__": 117 | if opt.preprocess: 118 | if opt.mode == 'S': 119 | prepare_data(data_path='data', patch_size=opt.patch, stride=opt.patch, aug_times=1) 120 | if opt.mode == 'B': 121 | prepare_data(data_path='data', patch_size=50, stride=10, aug_times=2) 122 | main() 123 | -------------------------------------------------------------------------------- /Burst/test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import argparse 4 | import glob 5 | import time 6 | from torch.autograd import Variable 7 | from models import * 8 | from utils import * 9 | 10 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 12 | 13 | parser = argparse.ArgumentParser(description="DnCNN_Test") 14 | parser.add_argument("--logdir", type=str, default="/home/user/depthMap/ksm/CVPR/demoire/logs", help='path of log files') 15 | opt = parser.parse_args() 16 | 17 | 18 | def normalize(data): 19 | return data/255. 20 | 21 | 22 | def self_ensemble(out, mode, forward): 23 | if mode == 0: 24 | # original 25 | out = out 26 | elif mode == 1: 27 | # flip up and down 28 | out = np.flipud(out) 29 | elif mode == 2: 30 | # rotate counterwise 90 degree 31 | if forward == 1: 32 | out = np.rot90(out) 33 | else: 34 | out = np.rot90(out, k=3) 35 | elif mode == 3: 36 | # rotate 90 degree and flip up and down 37 | if forward == 1: 38 | out = np.rot90(out) 39 | out = np.flipud(out) 40 | else: 41 | out = np.flipud(out) 42 | out = np.rot90(out, k=3) 43 | elif mode == 4: 44 | # rotate 180 degree 45 | out = np.rot90(out, k=2) 46 | elif mode == 5: 47 | # rotate 180 degree and flip 48 | if forward == 1: 49 | out = np.rot90(out, k=2) 50 | out = np.flipud(out) 51 | else: 52 | out = np.flipud(out) 53 | out = np.rot90(out, k=2) 54 | elif mode == 6: 55 | if forward == 1: 56 | out = np.rot90(out, k=3) 57 | else: 58 | out = np.rot90(out) 59 | elif mode == 7: 60 | # rotate 270 degree and flip 61 | if forward == 1: 62 | out = np.rot90(out, k=3) 63 | out = np.flipud(out) 64 | else: 65 | out = np.flipud(out) 66 | out = np.rot90(out) 67 | return out 68 | 69 | 70 | def main(model=0): 71 | # Build model 72 | 73 | print('Loading model ...\n') 74 | model = Net().cuda() 75 | # device_ids = [0] 76 | # model = nn.DataParallel(net, device_ids=device_ids).cuda() 77 | a = torch.load(glob.glob(os.path.join(opt.logdir, '*.pth'))[0]) 78 | print(glob.glob(os.path.join(opt.logdir, '*.pth'))[0]) 79 | ok = input("Right model? 
") 80 | if ok == 'n': 81 | return 82 | model.load_state_dict(a) 83 | 84 | # DHDN_flag = 4 85 | frame = 7 86 | ensemble_flag = 1 87 | model.eval() 88 | # load data info 89 | print('Loading data info ...\n') 90 | files_source = glob.glob(os.path.join('/home/user/depthMap/ksm/CVPR/demoire', 'ValidationInput', '*.png')) 91 | files_source.sort() 92 | # process data 93 | psnr_test = 0 94 | c = 0 95 | for f in range(len(files_source)//frame): 96 | # image 97 | start = time.time() 98 | ISource = [] 99 | # final = np.zeros(cv2.imread(f).shape) 100 | origin = cv2.imread(files_source[f * frame + 3]) 101 | for mode in range(ensemble_flag): 102 | for im in range(frame): 103 | data = cv2.imread(files_source[f * frame + im]) 104 | if im != 3: 105 | _, bin2 = cv2.threshold(data, 50, 255, cv2.THRESH_BINARY) 106 | _, bin3 = cv2.threshold(data, 50, 255, cv2.THRESH_BINARY_INV) 107 | final2 = cv2.bitwise_and(data, bin2, mask=None) 108 | final3 = cv2.bitwise_and(origin, bin3, mask=None) 109 | data = cv2.bitwise_or(final3, final2, mask=None) 110 | data = np.float32(normalize(data)) 111 | data = np.transpose(data, (2, 0, 1)) 112 | data = torch.Tensor(data).unsqueeze(0) 113 | ISource.append(data) 114 | """ 115 | data = cv2.imread(files_source[f * frame + 3]) 116 | data = np.float32(normalize(data)) 117 | data = np.transpose(data, (2, 0, 1)) 118 | data = torch.Tensor(data).unsqueeze(0) 119 | ISource.append(data) 120 | """ 121 | ISource = torch.cat(ISource, 0) 122 | """ 123 | hh, ww, cc = Img.shape 124 | for ch in range(cc): 125 | pl = Img[:, :, ch] 126 | Img[:, :, ch] = self_ensemble(pl, mode, 1) 127 | Img = np.swapaxes(Img, 0, 2) 128 | Img = np.swapaxes(Img, 1, 2) 129 | Img = np.float32(normalize(Img)) 130 | a = Img.shape[1] 131 | b = Img.shape[2] 132 | if a % DHDN_flag != 0 or b % DHDN_flag != 0: 133 | h = DHDN_flag - (a % DHDN_flag) 134 | w = DHDN_flag - (b % DHDN_flag) 135 | Img = np.pad(Img, [(0, 0), (h//2, h-h//2), (w//2, w-w//2)], mode='edge') 136 | Img = np.expand_dims(Img, 0) 137 | ISource = torch.Tensor(Img) 138 | """ 139 | INoisy = Variable(ISource.unsqueeze(0).cuda()) 140 | # print(INoisy.size()) 141 | with torch.no_grad(): # this can save much memory 142 | Out = torch.clamp(model(INoisy), 0., 1.) 
143 | """ 144 | if a % DHDN_flag != 0 or b % DHDN_flag != 0: 145 | h = DHDN_flag - (a % DHDN_flag) 146 | w = DHDN_flag - (b % DHDN_flag) 147 | Out = Out[:, :, h//2:Img.shape[0]-(h-h//2+1), w//2:Img.shape[1]-(w-w//2+1)] 148 | """ 149 | c = f 150 | name = str(c) 151 | if str(c) != 6: 152 | for i in range(6 - len(str(c))): 153 | name = '0' + name 154 | out = Out.squeeze(0).permute(1, 2, 0) * 255 155 | out = out.cpu().numpy() 156 | """ 157 | for ch in range(cc): 158 | out[:, :, ch] = self_ensemble(out[:, :, ch], mode, 0) 159 | final += out 160 | """ 161 | cv2.imwrite("/home/user/depthMap/ksm/CVPR/demoire/" + name + "_gt.png", out/ensemble_flag) 162 | mytime = time.time() - start 163 | psnr_test += mytime 164 | print("%s" % f) 165 | c += 1 166 | psnr_test /= (len(files_source)//frame) 167 | print("\nRuntime on test data %.2f" % psnr_test) 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /Single/ssim.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | from torch.autograd import Variable 7 | from kornia.color import rgb_to_yuv 8 | 9 | 10 | def gaussian(window_size, sigma): 11 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 12 | return gauss/gauss.sum() 13 | 14 | 15 | def create_window(window_size, channel=1): 16 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 18 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 19 | return window 20 | 21 | 22 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 23 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
24 | if val_range is None: 25 | if torch.max(img1) > 128: 26 | max_val = 255 27 | else: 28 | max_val = 1 29 | 30 | if torch.min(img1) < -0.5: 31 | min_val = -1 32 | else: 33 | min_val = 0 34 | L = max_val - min_val 35 | else: 36 | L = val_range 37 | 38 | padd = 0 39 | (_, channel, height, width) = img1.size() 40 | if window is None: 41 | real_size = min(window_size, height, width) 42 | window = create_window(real_size, channel=channel).to(img1.device) 43 | 44 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 54 | 55 | C1 = (0.01 * L) ** 2 56 | C2 = (0.03 * L) ** 2 57 | 58 | v1 = 2.0 * sigma12 + C2 59 | v2 = sigma1_sq + sigma2_sq + C2 60 | cs = torch.mean(v1 / v2) # contrast sensitivity 61 | 62 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 63 | 64 | if size_average: 65 | ret = ssim_map.mean() 66 | else: 67 | ret = ssim_map.mean(1).mean(1).mean(1) 68 | 69 | if full: 70 | return ret, cs 71 | return ret 72 | 73 | 74 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 75 | device = img1.device 76 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 77 | levels = weights.size()[0] 78 | mssim = [] 79 | mcs = [] 80 | for _ in range(levels): 81 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 82 | mssim.append(sim) 83 | mcs.append(cs) 84 | 85 | img1 = F.avg_pool2d(img1, (2, 2)) 86 | img2 = F.avg_pool2d(img2, (2, 2)) 87 | 88 | mssim = torch.stack(mssim) 89 | mcs = torch.stack(mcs) 90 | 91 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 92 | if normalize: 93 | mssim = (mssim + 1) / 2 94 | mcs = (mcs + 1) / 2 95 | 96 | pow1 = mcs ** weights 97 | pow2 = mssim ** weights 98 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 99 | output = torch.prod(pow1[:-1] * pow2[-1]) 100 | return output 101 | 102 | 103 | # Classes to re-use window 104 | class SSIM(torch.nn.Module): 105 | def __init__(self, window_size=11, size_average=True, val_range=None): 106 | super(SSIM, self).__init__() 107 | self.window_size = window_size 108 | self.size_average = size_average 109 | self.val_range = val_range 110 | 111 | # Assume 1 channel for SSIM 112 | self.channel = 1 113 | self.window = create_window(window_size) 114 | 115 | def forward(self, img1, img2): 116 | (_, channel, _, _) = img1.size() 117 | 118 | if channel == self.channel and self.window.dtype == img1.dtype: 119 | window = self.window 120 | else: 121 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 122 | self.window = window 123 | self.channel = channel 124 | 125 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 126 | 127 | 128 | class MSSSIM(torch.nn.Module): 129 | def __init__(self, window_size=11, size_average=True, channel=3): 130 | super(MSSSIM, self).__init__() 131 | self.window_size = window_size 132 | self.size_average = size_average 133 | self.channel = channel 134 | 135 | def forward(self, img1, img2): 136 | # TODO: 
store window between calls if possible 137 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 138 | 139 | 140 | class MeanShift(nn.Conv2d): 141 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 142 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 143 | std = torch.Tensor(rgb_std) 144 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 145 | self.weight.data.div_(std.view(3, 1, 1, 1)) 146 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 147 | self.bias.data.div_(std) 148 | self.requires_grad = False 149 | 150 | 151 | class VGG(torch.nn.Module): 152 | def __init__(self, conv_index, rgb_range=1): 153 | super(VGG, self).__init__() 154 | vgg_features = models.vgg19(pretrained=True).features 155 | modules = [m for m in vgg_features] 156 | if conv_index == '22': 157 | self.vgg = nn.Sequential(*modules[:8]) 158 | elif conv_index == '54': 159 | self.vgg = nn.Sequential(*modules[:35]) 160 | 161 | vgg_mean = (0.485, 0.456, 0.406) 162 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 163 | self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std) 164 | self.vgg.requires_grad = False 165 | 166 | def forward(self, sr, hr): 167 | def _forward(x): 168 | x = self.sub_mean(x) 169 | x = self.vgg(x) 170 | return x 171 | 172 | vgg_sr = _forward(sr) 173 | with torch.no_grad(): 174 | vgg_hr = _forward(hr.detach()) 175 | 176 | loss = F.l1_loss(vgg_sr, vgg_hr) 177 | 178 | return loss 179 | 180 | 181 | def color_loss(out, target): 182 | out_yuv = rgb_to_yuv(out) 183 | # out_y = out_yuv[:, 0, :, :] 184 | out_u = out_yuv[:, 1, :, :] 185 | out_v = out_yuv[:, 2, :, :] 186 | target_yuv = rgb_to_yuv(target) 187 | # target_y = target_yuv[:, 0, :, :] 188 | target_u = target_yuv[:, 1, :, :] 189 | target_v = target_yuv[:, 2, :, :] 190 | 191 | return torch.div( 192 | torch.mean((out_u - target_u).pow(1)).abs() + torch.mean((out_v - target_v).pow(1)).abs(), 2) 193 | -------------------------------------------------------------------------------- /Single/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import random 5 | import h5py 6 | import torch 7 | import cv2 8 | import glob 9 | import torch.utils.data as udata 10 | from utils import data_augmentation 11 | 12 | 13 | def normalize(data): 14 | return data/255. 
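

# Im2Patch below slides a win x win window over a (C, H, W) image with the given
# stride and returns every patch in one array of shape (C, win, win, TotalPatNum).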
15 | 16 | 17 | def Im2Patch(img, win, stride=1): 18 | k = 0 19 | endc = img.shape[0] 20 | endh = img.shape[1] 21 | endw = img.shape[2] 22 | patch = img[:, 0:endh-win+0+1:stride, 0:endw-win+0+1:stride] 23 | TotalPatNum = patch.shape[1] * patch.shape[2] 24 | Y = np.zeros([endc, win*win, TotalPatNum], np.float32) 25 | for i in range(win): 26 | for j in range(win): 27 | patch = img[:, i:endh-win+i+1:stride, j:endw-win+j+1:stride] 28 | Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum) 29 | k = k + 1 30 | return Y.reshape([endc, win, win, TotalPatNum]) 31 | 32 | 33 | def prepare_data(data_path, patch_size, stride, aug_times=1): 34 | # ''' 35 | # train 36 | print('process training data') 37 | scales = [1] 38 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'HAZY', '*.png')) 39 | # mix = list(range(len(files))) 40 | # random.shuffle(mix) 41 | # mix_train = mix[:int(len(files)*0.96)] 42 | # mix_val = mix[int(len(files)*0.96):] 43 | files.sort() 44 | h5f = h5py.File('D:/train_input.h5', 'w') 45 | train_num = 0 46 | for i in range(len(files)): 47 | Img = cv2.imread(files[i]) 48 | h, w, c = Img.shape 49 | for k in range(len(scales)): 50 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC) 51 | # Img = np.expand_dims(Img[:, :, :].copy(), 0) 52 | Img = np.swapaxes(Img, 0, 2) 53 | Img = np.swapaxes(Img, 1, 2) 54 | Img = np.float32(normalize(Img)) 55 | # print(Img.shape) 56 | patches = Im2Patch(Img, patch_size, stride) 57 | # print(i) 58 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3])) 59 | for n in range(patches.shape[3]): 60 | data = patches[:, :, :, n].copy() 61 | # print(data.shape) 62 | h5f.create_dataset(str(train_num), data=data) 63 | train_num += 1 64 | for m in range(aug_times-1): 65 | data_aug = data_augmentation(data, np.random.randint(1, 8)) 66 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug) 67 | train_num += 1 68 | h5f.close() 69 | print('process training gt') 70 | scales = [1] 71 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'GT', '*.png')) 72 | files.sort() 73 | h5f = h5py.File('D:/train_gt.h5', 'w') 74 | train_num = 0 75 | for i in range(len(files)): 76 | Img = cv2.imread(files[i]) 77 | h, w, c = Img.shape 78 | for k in range(len(scales)): 79 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC) 80 | # Img = np.expand_dims(Img[:, :, :].copy(), 0) 81 | Img = np.swapaxes(Img, 0, 2) 82 | Img = np.swapaxes(Img, 1, 2) 83 | Img = np.float32(normalize(Img)) 84 | patches = Im2Patch(Img, patch_size, stride) 85 | # print(i) 86 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3])) 87 | for n in range(patches.shape[3]): 88 | data = patches[:, :, :, n].copy() 89 | # print(data.shape) 90 | h5f.create_dataset(str(train_num), data=data) 91 | train_num += 1 92 | for m in range(aug_times-1): 93 | data_aug = data_augmentation(data, np.random.randint(1, 8)) 94 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug) 95 | train_num += 1 96 | h5f.close() 97 | # val 98 | print('\nprocess validation data') 99 | # files.clear() 100 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'HAZY', '*.png')) 101 | files.sort() 102 | h5f = h5py.File('D:/val_input.h5', 'w') 103 | val_num = 0 104 | for i in range(len(files)): 105 | print("file: %s" % files[i]) 106 | img = cv2.imread(files[i]) 107 | # img = np.expand_dims(img[:, :, :], 0) 108 | img = np.swapaxes(img, 0, 2) 109 | img = 
np.swapaxes(img, 1, 2) 110 | img = np.float32(normalize(img)) 111 | # print(i) 112 | # print(img.shape) 113 | h5f.create_dataset(str(val_num), data=img) 114 | val_num += 1 115 | h5f.close() 116 | # ''' 117 | print('\nprocess validation gt') 118 | # files.clear() 119 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'GT', '*.png')) 120 | files.sort() 121 | h5f = h5py.File('D:/val_gt.h5', 'w') 122 | val_num = 0 123 | for i in range(len(files)): 124 | print("file: %s" % files[i]) 125 | img = cv2.imread(files[i]) 126 | # img = np.expand_dims(img[:, :, :], 0) 127 | img = np.swapaxes(img, 0, 2) 128 | img = np.swapaxes(img, 1, 2) 129 | img = np.float32(normalize(img)) 130 | # print(i) 131 | # print(img.shape) 132 | h5f.create_dataset(str(val_num), data=img) 133 | val_num += 1 134 | h5f.close() 135 | # print('training set, # samples %d\n' % train_num) 136 | print('val set, # samples %d\n' % val_num) 137 | # ''' 138 | 139 | 140 | class Dataset(udata.Dataset): 141 | def __init__(self, train=True): 142 | super(Dataset, self).__init__() 143 | self.train = train 144 | if self.train: 145 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_input.h5', 'r') 146 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_gt.h5', 'r') 147 | self.keys = list(h5f.keys()) 148 | self.keys_gt = list(h5f_gt.keys()) 149 | h5f.close() 150 | h5f_gt.close() 151 | else: 152 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_input.h5', 'r') 153 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_gt.h5', 'r') 154 | self.keys = list(h5f.keys()) 155 | self.keys_gt = list(h5f_gt.keys()) 156 | h5f.close() 157 | h5f_gt.close() 158 | 159 | def __len__(self): 160 | return len(self.keys) 161 | 162 | def __getitem__(self, index): 163 | if self.train: 164 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_input.h5', 'r') 165 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_gt.h5', 'r') 166 | key = self.keys[index] 167 | key_gt = self.keys_gt[index] 168 | data = np.array(h5f[key]) 169 | gt = np.array(h5f_gt[key_gt]) 170 | h5f.close() 171 | h5f_gt.close() 172 | return torch.Tensor(data), torch.Tensor(gt) 173 | else: 174 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_input.h5', 'r') 175 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_gt.h5', 'r') 176 | key = self.keys[index] 177 | key_gt = self.keys_gt[index] 178 | data = np.array(h5f[key]) 179 | gt = np.array(h5f_gt[key_gt]) 180 | h5f.close() 181 | h5f_gt.close() 182 | return torch.Tensor(data), torch.Tensor(gt) 183 | -------------------------------------------------------------------------------- /Single/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from math import sqrt 3 | import torch 4 | 5 | 6 | class DnCNN(nn.Module): 7 | def __init__(self, channels, num_of_layers=17): 8 | super(DnCNN, self).__init__() 9 | kernel_size = 3 10 | padding = 1 11 | features = 64 12 | layers = [] 13 | layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 14 | layers.append(nn.ReLU(inplace=True)) 15 | for _ in range(num_of_layers-2): 16 | layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 17 | layers.append(nn.BatchNorm2d(features)) 18 | layers.append(nn.ReLU(inplace=True)) 19 | layers.append(nn.Conv2d(in_channels=features, out_channels=channels, 
kernel_size=kernel_size, padding=padding, bias=False)) 20 | self.dncnn = nn.Sequential(*layers) 21 | # weights initialization 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 25 | m.weight.data.normal_(0, sqrt(2. / n)) 26 | 27 | def forward(self, x): 28 | out = self.dncnn(x) 29 | return out 30 | 31 | 32 | class CA(nn.Module): 33 | def __init__(self, channel, reduction=16): 34 | super(CA, self).__init__() 35 | # global average pooling: feature --> point 36 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 37 | # feature channel downscale and upscale --> channel weight 38 | self.conv_du = nn.Sequential( 39 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 40 | nn.ReLU(inplace=True), 41 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 42 | nn.Sigmoid() 43 | ) 44 | 45 | def forward(self, x): 46 | y = self.avg_pool(x) 47 | y = self.conv_du(y) 48 | return x * y 49 | 50 | 51 | class RB(nn.Module): 52 | def __init__(self, features): 53 | super(RB, self).__init__() 54 | layers = [] 55 | kernel_size = 3 56 | for _ in range(1): 57 | layers.append(nn.Conv2d(in_channels=features, out_channels=features 58 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True)) 59 | layers.append(nn.PReLU()) 60 | layers.append(nn.Conv2d(in_channels=features, out_channels=features 61 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True)) 62 | self.res = nn.Sequential(*layers) 63 | self.ca = CA(features) 64 | 65 | def forward(self, x): 66 | out = self.res(x) 67 | out = self.ca(out) 68 | out += x 69 | return out 70 | 71 | 72 | class _down(nn.Module): 73 | def __init__(self, channel_in): 74 | super(_down, self).__init__() 75 | 76 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=2 * channel_in, kernel_size=3, stride=2, padding=1) 77 | self.relu = nn.PReLU() 78 | 79 | def forward(self, x): 80 | 81 | out = self.relu(self.conv(x)) 82 | 83 | return out 84 | 85 | 86 | class _up(nn.Module): 87 | def __init__(self, channel_in): 88 | super(_up, self).__init__() 89 | 90 | self.conv = nn.PixelShuffle(2) 91 | self.relu = nn.PReLU() 92 | 93 | def forward(self, x): 94 | 95 | out = self.relu(self.conv(x)) 96 | 97 | return out 98 | 99 | 100 | class AB(nn.Module): 101 | def __init__(self, features): 102 | super(AB, self).__init__() 103 | 104 | num = 2 105 | self.DCR_block1 = self.make_layer(RB, features, num) 106 | self.down1 = self.make_layer(_down, features, 1) 107 | self.DCR_block2 = self.make_layer(RB, features * 2, num) 108 | self.down2 = self.make_layer(_down, features * 2, 1) 109 | self.DCR_block3 = self.make_layer(RB, features * 4, num) 110 | self.up2 = self.make_layer(_up, features * 8, 1) 111 | self.DCR_block22 = self.make_layer(RB, features * 4, num) 112 | self.up1 = self.make_layer(_up, features * 4, 1) 113 | self.DCR_block11 = self.make_layer(RB, features * 2, num) 114 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0) 115 | self.relu2 = nn.PReLU() 116 | 117 | def make_layer(self, block, channel_in, num): 118 | layers = [] 119 | for _ in range(num): 120 | layers.append(block(channel_in)) 121 | return nn.Sequential(*layers) 122 | 123 | def forward(self, x): 124 | 125 | conc1 = self.DCR_block1(x) 126 | out = self.down1(conc1) 127 | 128 | conc2 = self.DCR_block2(out) 129 | conc3 = self.down2(conc2) 130 | 131 | out = self.DCR_block3(conc3) 132 | out = torch.cat([conc3, out], 1) 133 | 134 | out = self.up2(out) 135 | out = 
torch.cat([conc2, out], 1) 136 | out = self.DCR_block22(out) 137 | 138 | out = self.up1(out) 139 | out = torch.cat([conc1, out], 1) 140 | out = self.DCR_block11(out) 141 | 142 | out = self.relu2(self.conv_f(out)) 143 | out += x 144 | 145 | return out 146 | 147 | 148 | class GAB(nn.Module): 149 | def __init__(self, features): 150 | super(GAB, self).__init__() 151 | 152 | kernel_size = 3 153 | self.res1 = self.make_layer(RB, features, 2) 154 | self.R = 2 155 | self.A = 1 156 | self.RB = nn.ModuleList() 157 | for _ in range(self.R): 158 | self.RB.append(RB(features)) 159 | self.AB = nn.ModuleList() 160 | for _ in range(self.A): 161 | self.AB.append(AB(features)) 162 | self.GFF_R = nn.Sequential( 163 | nn.Conv2d(self.R * features, features, kernel_size=1, padding=0, stride=1), 164 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1), 165 | ) 166 | self.GFF_A = nn.Sequential( 167 | nn.Conv2d(self.A * features, features, kernel_size=1, padding=0, stride=1), 168 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1), 169 | ) 170 | self.softmax = nn.Sigmoid() 171 | self.res2 = self.make_layer(RB, features * 2, 2) 172 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0) 173 | self.relu2 = nn.PReLU() 174 | 175 | def make_layer(self, block, channel_in, num): 176 | layers = [] 177 | for _ in range(num): 178 | layers.append(block(channel_in)) 179 | return nn.Sequential(*layers) 180 | 181 | def forward(self, x): 182 | out = self.res1(x) 183 | 184 | RB_outs = [] 185 | for i in range(self.R): 186 | outR = self.RB[i](out) 187 | RB_outs.append(outR) 188 | outR = torch.cat(RB_outs, 1) 189 | outR = self.GFF_R(outR) 190 | outR += out 191 | AB_outs = [] 192 | for i in range(self.A): 193 | outA = self.AB[i](out) 194 | AB_outs.append(outA) 195 | outA = torch.cat(AB_outs, 1) 196 | outA = self.GFF_A(outA) 197 | outA += out 198 | # outR *= self.softmax(outA) 199 | out = torch.cat([outR, outA], 1) 200 | out = self.relu2(self.conv_f(self.res2(out))) 201 | out *= 0.2 202 | out += x 203 | 204 | return out 205 | 206 | 207 | class Net(nn.Module): 208 | def __init__(self, features=64): 209 | super(Net, self).__init__() 210 | 211 | kernel_size = 3 212 | self.conv_i = nn.Conv2d(in_channels=3, out_channels=features, kernel_size=1, stride=1, padding=0) 213 | self.relu1 = nn.PReLU() 214 | self.GA = 22 215 | self.GAB = nn.ModuleList() 216 | for _ in range(self.GA): 217 | self.GAB.append(GAB(features)) 218 | self.GFF_GA = nn.Sequential( 219 | nn.Conv2d(self.GA * features, features, kernel_size=1, padding=0, stride=1), 220 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1), 221 | ) 222 | self.conv_f = nn.Conv2d(in_channels=features, out_channels=3, kernel_size=1, stride=1, padding=0) 223 | self.relu2 = nn.PReLU() 224 | 225 | def make_layer(self, block, channel_in, num): 226 | layers = [] 227 | for _ in range(num): 228 | layers.append(block(channel_in)) 229 | return nn.Sequential(*layers) 230 | 231 | def forward(self, x): 232 | out = self.relu1(self.conv_i(x)) 233 | 234 | GAB_outs = [] 235 | for i in range(self.GA): 236 | out = self.GAB[i](out) 237 | GAB_outs.append(out) 238 | out = torch.cat(GAB_outs, 1) 239 | out = self.GFF_GA(out) 240 | 241 | out = self.relu2(self.conv_f(out)) 242 | out += x 243 | 244 | return out 245 | -------------------------------------------------------------------------------- /Burst/ssim.py: -------------------------------------------------------------------------------- 
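# Essentially the same as Single/ssim.py, plus the BurstLoss class at the end of
# this file: an L1 loss between Prewitt-filtered image gradients of the output and
# the ground truth, i.e. mean |G(output) - G(target)| over the horizontal and
# vertical gradient axes.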
1 | from math import exp 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | from kornia.color import rgb_to_yuv 7 | from torch.nn.modules.loss import _Loss 8 | import numpy as np 9 | 10 | 11 | def gaussian(window_size, sigma): 12 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 13 | return gauss/gauss.sum() 14 | 15 | 16 | def create_window(window_size, channel=1): 17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 18 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 19 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 20 | return window 21 | 22 | 23 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 24 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 25 | if val_range is None: 26 | if torch.max(img1) > 128: 27 | max_val = 255 28 | else: 29 | max_val = 1 30 | 31 | if torch.min(img1) < -0.5: 32 | min_val = -1 33 | else: 34 | min_val = 0 35 | L = max_val - min_val 36 | else: 37 | L = val_range 38 | 39 | padd = 0 40 | (_, channel, height, width) = img1.size() 41 | if window is None: 42 | real_size = min(window_size, height, width) 43 | window = create_window(real_size, channel=channel).to(img1.device) 44 | 45 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 46 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 47 | 48 | mu1_sq = mu1.pow(2) 49 | mu2_sq = mu2.pow(2) 50 | mu1_mu2 = mu1 * mu2 51 | 52 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 53 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 54 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 55 | 56 | C1 = (0.01 * L) ** 2 57 | C2 = (0.03 * L) ** 2 58 | 59 | v1 = 2.0 * sigma12 + C2 60 | v2 = sigma1_sq + sigma2_sq + C2 61 | cs = torch.mean(v1 / v2) # contrast sensitivity 62 | 63 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 64 | 65 | if size_average: 66 | ret = ssim_map.mean() 67 | else: 68 | ret = ssim_map.mean(1).mean(1).mean(1) 69 | 70 | if full: 71 | return ret, cs 72 | return ret 73 | 74 | 75 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 76 | device = img1.device 77 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 78 | levels = weights.size()[0] 79 | mssim = [] 80 | mcs = [] 81 | for _ in range(levels): 82 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 83 | mssim.append(sim) 84 | mcs.append(cs) 85 | 86 | img1 = F.avg_pool2d(img1, (2, 2)) 87 | img2 = F.avg_pool2d(img2, (2, 2)) 88 | 89 | mssim = torch.stack(mssim) 90 | mcs = torch.stack(mcs) 91 | 92 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 93 | if normalize: 94 | mssim = (mssim + 1) / 2 95 | mcs = (mcs + 1) / 2 96 | 97 | pow1 = mcs ** weights 98 | pow2 = mssim ** weights 99 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 100 | output = torch.prod(pow1[:-1] * pow2[-1]) 101 | return output 102 | 103 | 104 | # Classes to re-use window 105 | class SSIM(torch.nn.Module): 106 | def __init__(self, window_size=11, size_average=True, val_range=None): 107 | super(SSIM, self).__init__() 108 | self.window_size = window_size 109 | 
self.size_average = size_average 110 | self.val_range = val_range 111 | 112 | # Assume 1 channel for SSIM 113 | self.channel = 1 114 | self.window = create_window(window_size) 115 | 116 | def forward(self, img1, img2): 117 | (_, channel, _, _) = img1.size() 118 | 119 | if channel == self.channel and self.window.dtype == img1.dtype: 120 | window = self.window 121 | else: 122 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 123 | self.window = window 124 | self.channel = channel 125 | 126 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 127 | 128 | 129 | class MSSSIM(torch.nn.Module): 130 | def __init__(self, window_size=11, size_average=True, channel=3): 131 | super(MSSSIM, self).__init__() 132 | self.window_size = window_size 133 | self.size_average = size_average 134 | self.channel = channel 135 | 136 | def forward(self, img1, img2): 137 | # TODO: store window between calls if possible 138 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 139 | 140 | 141 | class MeanShift(nn.Conv2d): 142 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 143 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 144 | std = torch.Tensor(rgb_std) 145 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 146 | self.weight.data.div_(std.view(3, 1, 1, 1)) 147 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 148 | self.bias.data.div_(std) 149 | for p in self.parameters(): p.requires_grad = False  # freeze the fixed mean-shift weights; assigning self.requires_grad on the module is a no-op that leaves them trainable 150 | 151 | 152 | class VGG(torch.nn.Module): 153 | def __init__(self, conv_index, rgb_range=1): 154 | super(VGG, self).__init__() 155 | vgg_features = models.vgg19(pretrained=True).features 156 | modules = [m for m in vgg_features] 157 | if conv_index == '22': 158 | self.vgg = nn.Sequential(*modules[:8]) 159 | elif conv_index == '54': 160 | self.vgg = nn.Sequential(*modules[:35]) 161 | 162 | vgg_mean = (0.485, 0.456, 0.406) 163 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 164 | self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std) 165 | for p in self.vgg.parameters(): p.requires_grad = False  # freeze the VGG feature extractor; setting requires_grad on the module itself does nothing 166 | 167 | def forward(self, sr, hr): 168 | def _forward(x): 169 | x = self.sub_mean(x) 170 | x = self.vgg(x) 171 | return x 172 | 173 | vgg_sr = _forward(sr) 174 | with torch.no_grad(): 175 | vgg_hr = _forward(hr.detach()) 176 | 177 | loss = F.l1_loss(vgg_sr, vgg_hr) 178 | 179 | return loss 180 | 181 | 182 | def color_loss(out, target): 183 | out_yuv = rgb_to_yuv(out) 184 | out_u = out_yuv[:, 1, :, :] 185 | out_v = out_yuv[:, 2, :, :] 186 | target_yuv = rgb_to_yuv(target) 187 | target_u = target_yuv[:, 1, :, :] 188 | target_v = target_yuv[:, 2, :, :] 189 | 190 | return torch.div(torch.mean(out_u - target_u).abs() + torch.mean(out_v - target_v).abs(), 2)  # average of |mean(U error)| and |mean(V error)| over the two chroma channels 191 | 192 | 193 | class BurstLoss(_Loss): 194 | 195 | def __init__(self, size_average=None, reduce=None, reduction='mean'): 196 | super(BurstLoss, self).__init__(size_average, reduce, reduction) 197 | 198 | self.reduction = reduction 199 | use_cuda = torch.cuda.is_available() 200 | device = torch.device("cuda:0" if use_cuda else "cpu") 201 | 202 | prewitt_filter = 1 / 6 * np.array([[1, 0, -1], 203 | [1, 0, -1], 204 | [1, 0, -1]]) 205 | 206 | self.prewitt_filter_horizontal = torch.nn.Conv2d(in_channels=1, out_channels=1, 207 | kernel_size=prewitt_filter.shape, 208 | padding=prewitt_filter.shape[0] // 2).to(device) 209 | 210 | self.prewitt_filter_horizontal.weight.data.copy_(torch.from_numpy(prewitt_filter).to(device)) 211 | 
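# copy_ broadcasts the (3, 3) Prewitt kernel into the conv weight of shape
# (out_channels, in_channels, kH, kW) = (1, 1, 3, 3); the 1/6 factor above
# normalizes the kernel so gradient responses stay on the scale of the input.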
self.prewitt_filter_horizontal.bias.data.copy_(torch.from_numpy(np.array([0.0])).to(device)) 212 | 213 | self.prewitt_filter_vertical = torch.nn.Conv2d(in_channels=1, out_channels=1, 214 | kernel_size=prewitt_filter.shape, 215 | padding=prewitt_filter.shape[0] // 2).to(device) 216 | 217 | self.prewitt_filter_vertical.weight.data.copy_(torch.from_numpy(prewitt_filter.T).to(device)) 218 | self.prewitt_filter_vertical.bias.data.copy_(torch.from_numpy(np.array([0.0])).to(device)) 219 | 220 | def get_gradients(self, img): 221 | img_r = img[:, 0:1, :, :] 222 | img_g = img[:, 1:2, :, :] 223 | img_b = img[:, 2:3, :, :] 224 | 225 | grad_x_r = self.prewitt_filter_horizontal(img_r) 226 | grad_y_r = self.prewitt_filter_vertical(img_r) 227 | grad_x_g = self.prewitt_filter_horizontal(img_g) 228 | grad_y_g = self.prewitt_filter_vertical(img_g) 229 | grad_x_b = self.prewitt_filter_horizontal(img_b) 230 | grad_y_b = self.prewitt_filter_vertical(img_b) 231 | 232 | grad_x = torch.stack([grad_x_r[:, 0, :, :], grad_x_g[:, 0, :, :], grad_x_b[:, 0, :, :]], dim=1) 233 | grad_y = torch.stack([grad_y_r[:, 0, :, :], grad_y_g[:, 0, :, :], grad_y_b[:, 0, :, :]], dim=1) 234 | 235 | grad = torch.stack([grad_x, grad_y], dim=1) 236 | 237 | return grad 238 | 239 | def forward(self, input, target): 240 | input_grad = self.get_gradients(input) 241 | target_grad = self.get_gradients(target) 242 | 243 | return F.l1_loss(input_grad, target_grad, reduction=self.reduction) 244 | -------------------------------------------------------------------------------- /Burst/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from math import sqrt 3 | import torch 4 | 5 | 6 | class DnCNN(nn.Module): 7 | def __init__(self, channels, num_of_layers=17): 8 | super(DnCNN, self).__init__() 9 | kernel_size = 3 10 | padding = 1 11 | features = 64 12 | layers = [] 13 | layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 14 | layers.append(nn.ReLU(inplace=True)) 15 | for _ in range(num_of_layers-2): 16 | layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 17 | layers.append(nn.BatchNorm2d(features)) 18 | layers.append(nn.ReLU(inplace=True)) 19 | layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False)) 20 | self.dncnn = nn.Sequential(*layers) 21 | # weights initialization 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 25 | m.weight.data.normal_(0, sqrt(2. 
/ n)) 26 | 27 | def forward(self, x): 28 | out = self.dncnn(x) 29 | return out 30 | 31 | 32 | class GlobalMaxPool(torch.nn.Module): 33 | def __init__(self): 34 | super(GlobalMaxPool, self).__init__() 35 | 36 | def forward(self, input): 37 | output = torch.max(input, dim=1)[0] 38 | 39 | return torch.unsqueeze(output, 1) 40 | 41 | 42 | class CA(nn.Module): 43 | def __init__(self, channel, reduction=16): 44 | super(CA, self).__init__() 45 | # global average pooling: feature --> point 46 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 47 | # feature channel downscale and upscale --> channel weight 48 | self.conv_du = nn.Sequential( 49 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 52 | nn.Sigmoid() 53 | ) 54 | 55 | def forward(self, x): 56 | y = self.avg_pool(x) 57 | y = self.conv_du(y) 58 | return x * y 59 | 60 | 61 | class RB(nn.Module): 62 | def __init__(self, features): 63 | super(RB, self).__init__() 64 | layers = [] 65 | kernel_size = 3 66 | for _ in range(1): 67 | layers.append(nn.Conv2d(in_channels=features, out_channels=features 68 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True)) 69 | layers.append(nn.PReLU()) 70 | layers.append(nn.Conv2d(in_channels=features, out_channels=features 71 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True)) 72 | self.res = nn.Sequential(*layers) 73 | self.ca = CA(features) 74 | 75 | def forward(self, x): 76 | out = self.res(x) 77 | out = self.ca(out) 78 | out += x 79 | return out 80 | 81 | 82 | class _down(nn.Module): 83 | def __init__(self, channel_in): 84 | super(_down, self).__init__() 85 | 86 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=2 * channel_in, kernel_size=3, stride=2, padding=1) 87 | self.relu = nn.PReLU() 88 | 89 | def forward(self, x): 90 | 91 | out = self.relu(self.conv(x)) 92 | 93 | return out 94 | 95 | 96 | class _up(nn.Module): 97 | def __init__(self, channel_in): 98 | super(_up, self).__init__() 99 | 100 | self.conv = nn.PixelShuffle(2) 101 | self.relu = nn.PReLU() 102 | 103 | def forward(self, x): 104 | 105 | out = self.relu(self.conv(x)) 106 | 107 | return out 108 | 109 | 110 | class AB(nn.Module): 111 | def __init__(self, features): 112 | super(AB, self).__init__() 113 | 114 | num = 2 115 | self.DCR_block1 = self.make_layer(RB, features, num) 116 | self.down1 = self.make_layer(_down, features, 1) 117 | self.DCR_block2 = self.make_layer(RB, features * 2, num) 118 | self.down2 = self.make_layer(_down, features * 2, 1) 119 | self.DCR_block3 = self.make_layer(RB, features * 4, num) 120 | self.up2 = self.make_layer(_up, features * 8, 1) 121 | self.DCR_block22 = self.make_layer(RB, features * 4, num) 122 | self.up1 = self.make_layer(_up, features * 4, 1) 123 | self.DCR_block11 = self.make_layer(RB, features * 2, num) 124 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0) 125 | self.relu2 = nn.PReLU() 126 | 127 | def make_layer(self, block, channel_in, num): 128 | layers = [] 129 | for _ in range(num): 130 | layers.append(block(channel_in)) 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | 135 | conc1 = self.DCR_block1(x) 136 | out = self.down1(conc1) 137 | 138 | conc2 = self.DCR_block2(out) 139 | conc3 = self.down2(conc2) 140 | 141 | out = self.DCR_block3(conc3) 142 | out = torch.cat([conc3, out], 1) 143 | 144 | out = self.up2(out) 145 | out = torch.cat([conc2, out], 1) 146 | out = 
self.DCR_block22(out) 147 | 148 | out = self.up1(out) 149 | out = torch.cat([conc1, out], 1) 150 | out = self.DCR_block11(out) 151 | 152 | out = self.relu2(self.conv_f(out)) 153 | out += x 154 | 155 | return out 156 | 157 | 158 | class GAB(nn.Module): 159 | def __init__(self, features): 160 | super(GAB, self).__init__() 161 | 162 | kernel_size = 3 163 | self.res1 = self.make_layer(RB, features, 2) 164 | self.R = 2 165 | self.A = 1 166 | self.RB = nn.ModuleList() 167 | for _ in range(self.R): 168 | self.RB.append(RB(features)) 169 | self.AB = nn.ModuleList() 170 | for _ in range(self.A): 171 | self.AB.append(AB(features)) 172 | self.GFF_R = nn.Sequential( 173 | nn.Conv2d(self.R * features, features, kernel_size=1, padding=0, stride=1), 174 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1), 175 | ) 176 | self.GFF_A = nn.Sequential( 177 | nn.Conv2d(self.A * features, features, kernel_size=1, padding=0, stride=1), 178 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1), 179 | ) 180 | self.softmax = nn.Sigmoid()  # only used by the attention gate that is commented out in forward() 181 | self.res2 = self.make_layer(RB, features * 2, 2) 182 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0) 183 | self.relu2 = nn.PReLU() 184 | 185 | def make_layer(self, block, channel_in, num): 186 | layers = [] 187 | for _ in range(num): 188 | layers.append(block(channel_in)) 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | out = self.res1(x) 193 | 194 | RB_outs = [] 195 | for i in range(self.R): 196 | outR = self.RB[i](out) 197 | RB_outs.append(outR) 198 | outR = torch.cat(RB_outs, 1) 199 | outR = self.GFF_R(outR) 200 | 201 | AB_outs = [] 202 | for i in range(self.A): 203 | outA = self.AB[i](out) 204 | AB_outs.append(outA) 205 | outA = torch.cat(AB_outs, 1) 206 | outA = self.GFF_A(outA) 207 | 208 | # outR *= self.softmax(outA) 209 | out = torch.cat([outR, outA], 1) 210 | out = self.relu2(self.conv_f(self.res2(out))) 211 | 212 | out += x 213 | 214 | return out 215 | 216 | 217 | class Net(nn.Module): 218 | def __init__(self, features=48): 219 | super(Net, self).__init__() 220 | 221 | kernel_size = 3 222 | self.conv_i = nn.Conv2d(in_channels=3, out_channels=features, kernel_size=1, stride=1, padding=0) 223 | self.relu1 = nn.PReLU() 224 | self.GA = 5 225 | self.maxpool = GlobalMaxPool() 226 | self.GAB = nn.ModuleList() 227 | for _ in range(self.GA): 228 | self.GAB.append(GAB(features)) 229 | self.conv_m = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0) 230 | self.relum = nn.PReLU() 231 | self.GFF_GA = nn.Sequential( 232 | nn.Conv2d(self.GA * 1 * features * 7, features, kernel_size=1, padding=0, stride=1), 233 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1), 234 | ) 235 | self.conv_f = nn.Conv2d(in_channels=features, out_channels=3, kernel_size=1, stride=1, padding=0) 236 | self.relu2 = nn.PReLU() 237 | 238 | def make_layer(self, block, channel_in, num): 239 | layers = [] 240 | for _ in range(num): 241 | layers.append(block(channel_in)) 242 | return nn.Sequential(*layers) 243 | 244 | def forward(self, x): 245 | b, im, c, h, w = x.size() 246 | out = self.relu1(self.conv_i(x.view((b*im, c, h, w)))) 247 | residual = out.view((b, im, -1, h, w))[:, 3]  # shallow features of the reference (4th) frame for each sample; plain out[3] would grab a single frame of the first sample only 248 | GAB_outs = [] 249 | for i in range(self.GA): 250 | out = self.GAB[i](out) 251 | out_max = self.maxpool(out.view((b, im, -1, h, w))) 252 | out_max = out_max.repeat(1, im, 1, 1, 1).view(b*im, -1, h, w) 253 | out = 
self.relum(self.conv_m(torch.cat([out, out_max], 1))) 254 | GAB_outs.append(out) 255 | out = torch.cat(GAB_outs, 1) 256 | out = self.GFF_GA(out.view((b, -1, h, w))) 257 | out += residual  # long skip from the reference frame's shallow features 258 | out = self.relu2(self.conv_f(out)) 259 | out += x[:, 3, :, :, :]  # global residual: the network predicts a correction to the reference input frame 260 | 261 | return out 262 | -------------------------------------------------------------------------------- /Burst/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import glob 5 | import h5py 6 | import torch 7 | import cv2 8 | 9 | import torch.utils.data as udata 10 | from utils import data_augmentation 11 | from random import shuffle 12 | 13 | 14 | def normalize(data): 15 | return data/255. 16 | 17 | 18 | def Im2Patch(img, win, stride=1): 19 | k = 0 20 | endc = img.shape[0] 21 | endh = img.shape[1] 22 | endw = img.shape[2] 23 | patch = img[:, 0:endh-win+0+1:stride, 0:endw-win+0+1:stride] 24 | TotalPatNum = patch.shape[1] * patch.shape[2] 25 | Y = np.zeros([endc, win*win, TotalPatNum], np.float32) 26 | for i in range(win): 27 | for j in range(win): 28 | patch = img[:, i:endh-win+i+1:stride, j:endw-win+j+1:stride] 29 | Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum) 30 | k = k + 1 31 | return Y.reshape([endc, win, win, TotalPatNum]) 32 | 33 | 34 | def prepare_data(data_path, patch_size, stride, aug_times=1):  # note: data_path is unused; the source and destination paths below are hard-coded 35 | # ''' 36 | # train 37 | print('process training data') 38 | scales = [1] 39 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'HAZY', '*.png')) 40 | # mix = list(range(len(files))) 41 | # random.shuffle(mix) 42 | # mix_train = mix[:int(len(files)*0.96)] 43 | # mix_val = mix[int(len(files)*0.96):] 44 | files.sort() 45 | h5f = h5py.File('D:/train_input.h5', 'w') 46 | train_num = 0 47 | for i in range(len(files)): 48 | Img = cv2.imread(files[i]) 49 | h, w, c = Img.shape 50 | for k in range(len(scales)): 51 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC) 52 | # Img = np.expand_dims(Img[:, :, :].copy(), 0) 53 | Img = np.swapaxes(Img, 0, 2) 54 | Img = np.swapaxes(Img, 1, 2) 55 | Img = np.float32(normalize(Img)) 56 | # print(Img.shape) 57 | patches = Im2Patch(Img, patch_size, stride) 58 | # print(i) 59 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3])) 60 | for n in range(patches.shape[3]): 61 | data = patches[:, :, :, n].copy() 62 | # print(data.shape) 63 | h5f.create_dataset(str(train_num), data=data) 64 | train_num += 1 65 | for m in range(aug_times-1): 66 | data_aug = data_augmentation(data, np.random.randint(1, 8)) 67 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug) 68 | train_num += 1 69 | h5f.close() 70 | print('process training gt') 71 | scales = [1] 72 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'GT', '*.png')) 73 | files.sort() 74 | h5f = h5py.File('D:/train_gt.h5', 'w') 75 | train_num = 0 76 | for i in range(len(files)): 77 | Img = cv2.imread(files[i]) 78 | h, w, c = Img.shape 79 | for k in range(len(scales)): 80 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC) 81 | # Img = np.expand_dims(Img[:, :, :].copy(), 0) 82 | Img = np.swapaxes(Img, 0, 2) 83 | Img = np.swapaxes(Img, 1, 2) 84 | Img = np.float32(normalize(Img)) 85 | patches = Im2Patch(Img, patch_size, stride) 86 | # print(i) 87 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3])) 88 | for n in 
range(patches.shape[3]): 89 | data = patches[:, :, :, n].copy() 90 | # print(data.shape) 91 | h5f.create_dataset(str(train_num), data=data) 92 | train_num += 1 93 | for m in range(aug_times-1): 94 | data_aug = data_augmentation(data, np.random.randint(1, 8)) 95 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug) 96 | train_num += 1 97 | h5f.close() 98 | # val 99 | print('\nprocess validation data') 100 | # files.clear() 101 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'HAZY', '*.png')) 102 | files.sort() 103 | h5f = h5py.File('D:/val_input.h5', 'w') 104 | val_num = 0 105 | for i in range(len(files)): 106 | print("file: %s" % files[i]) 107 | img = cv2.imread(files[i]) 108 | # img = np.expand_dims(img[:, :, :], 0) 109 | img = np.swapaxes(img, 0, 2) 110 | img = np.swapaxes(img, 1, 2) 111 | img = np.float32(normalize(img)) 112 | # print(i) 113 | # print(img.shape) 114 | h5f.create_dataset(str(val_num), data=img) 115 | val_num += 1 116 | h5f.close() 117 | # ''' 118 | print('\nprocess validation gt') 119 | # files.clear() 120 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'GT', '*.png')) 121 | files.sort() 122 | h5f = h5py.File('D:/val_gt.h5', 'w') 123 | val_num = 0 124 | for i in range(len(files)): 125 | print("file: %s" % files[i]) 126 | img = cv2.imread(files[i]) 127 | # img = np.expand_dims(img[:, :, :], 0) 128 | img = np.swapaxes(img, 0, 2) 129 | img = np.swapaxes(img, 1, 2) 130 | img = np.float32(normalize(img)) 131 | # print(i) 132 | # print(img.shape) 133 | h5f.create_dataset(str(val_num), data=img) 134 | val_num += 1 135 | h5f.close() 136 | # print('training set, # samples %d\n' % train_num) 137 | print('val set, # samples %d\n' % val_num) 138 | # ''' 139 | 140 | 141 | class Dataset(udata.Dataset): 142 | def __init__(self, train=True): 143 | super(Dataset, self).__init__() 144 | self.train = train 145 | if self.train: 146 | h5f = [] 147 | for im in range(7): 148 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_input' + str(im) + '.h5', 'r') 149 | h5f.append(h5) 150 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_gt.h5', 'r') 151 | else: 152 | h5f = [] 153 | for im in range(7): 154 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_input' + str(im) + '.h5', 'r') 155 | h5f.append(h5) 156 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_gt.h5', 'r') 157 | self.keys = [] 158 | for im in range(7): 159 | h5 = h5f[im] 160 | self.keys.append(list(h5.keys())) 161 | h5.close() 162 | self.keys_gt = list(h5f_gt.keys()) 163 | h5f_gt.close() 164 | 165 | def __len__(self): 166 | return len(self.keys_gt) 167 | 168 | def __getitem__(self, index): 169 | if self.train: 170 | h5f = [] 171 | for im in range(7): 172 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_input' + str(im) + '.h5', 'r') 173 | h5f.append(h5) 174 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_gt.h5', 'r') 175 | else: 176 | h5f = [] 177 | for im in range(7): 178 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_input' + str(im) + '.h5', 'r') 179 | h5f.append(h5) 180 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_gt.h5', 'r') 181 | data = [] 182 | for im in range(7): 183 | k = self.keys[im][index] 184 | h5 = h5f[im] 185 | kk = h5[k] 186 | data.append(torch.Tensor(np.array(kk)).unsqueeze(0)) 187 | h5.close() 188 | key_gt = self.keys_gt[index] 189 | gt = np.array(h5f_gt[key_gt]) 190 | h5f_gt.close() 191 | return torch.cat(data, 0), torch.Tensor(gt) 
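# Dataset above pairs the 7 per-frame input archives (train_input0.h5 ...
# train_input6.h5, or their val_ counterparts) with the single gt archive:
# __getitem__ stacks one (C, H, W) slice per burst frame into a (7, C, H, W)
# input tensor, so a DataLoader yields the 5-D (batch, frame, C, H, W) input
# that Net.forward in models.py expects.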
192 | 193 | 194 | class DatasetBurst(udata.Dataset): 195 | def __init__(self, train=True): 196 | super(DatasetBurst, self).__init__() 197 | self.train = train 198 | if self.train: 199 | self.input_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/train/input/*.png") 200 | self.gt_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/train/gt/*.png") 201 | else: 202 | self.input_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/val/input/*.png") 203 | self.gt_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/val/gt/*.png") 204 | self.frame = int(len(self.input_list)/len(self.gt_list))  # burst frames per ground-truth image 205 | self.crop = 128 206 | self.th = 50  # chroma-key threshold used to pad the non-reference frames 207 | 208 | def __len__(self): 209 | return len(self.gt_list) 210 | 211 | def __getitem__(self, index): 212 | order = list(range(len(self.gt_list))) 213 | shuffle(order)  # note: re-shuffling on every call remaps indices at random, so sampling is effectively with replacement and non-deterministic (even in validation mode) 214 | index = order[index] 215 | data_list = [] 216 | self.input_list.sort(key=str.lower) 217 | self.gt_list.sort(key=str.lower) 218 | origin = cv2.imread(self.input_list[index * self.frame + 3])  # frame 3 is the reference frame of the burst 219 | for im in range(self.frame): 220 | data = cv2.imread(self.input_list[index * self.frame + im]) 221 | if im != 3:  # pad every non-reference frame by chroma key 222 | _, bin2 = cv2.threshold(data, self.th, 255, cv2.THRESH_BINARY) 223 | _, bin3 = cv2.threshold(data, self.th, 255, cv2.THRESH_BINARY_INV) 224 | final2 = cv2.bitwise_and(data, bin2, mask=None)  # keep this frame's pixels above the threshold 225 | final3 = cv2.bitwise_and(origin, bin3, mask=None)  # fill the dark (keyed-out) regions from the reference frame 226 | data = cv2.bitwise_or(final3, final2, mask=None) 227 | data = np.float32(normalize(data)) 228 | data = np.transpose(data, (2, 0, 1)) 229 | data = torch.Tensor(data).unsqueeze(0) 230 | data_list.append(data) 231 | gt = cv2.imread(self.gt_list[index]) 232 | gt = np.float32(normalize(gt)) 233 | gt = np.transpose(gt, (2, 0, 1)) 234 | return torch.cat(data_list, 0), torch.Tensor(gt) 235 | --------------------------------------------------------------------------------
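The loss modules in Burst/ssim.py above are building blocks; Burst/train.py defines how they are actually combined. A minimal sketch of wiring them together, assuming the network output and ground truth are RGB tensors in [0, 1] — the particular combination and all weights below are illustrative placeholders, not the authors' settings:

    import torch
    from ssim import MSSSIM, VGG, BurstLoss, color_loss

    l1 = torch.nn.L1Loss()
    ms_ssim = MSSSIM(channel=3)
    perceptual = VGG(conv_index='22')  # L1 between VGG-19 conv2_2 feature maps
    gradient = BurstLoss()             # L1 between Prewitt gradient maps
    # BurstLoss builds its filters on CUDA when available, so keep the
    # network output and ground truth on the same device as the loss.

    def total_loss(out, gt):
        # All weights are placeholders chosen for illustration only.
        return (l1(out, gt)
                + 0.1 * (1.0 - ms_ssim(out, gt))  # MS-SSIM is a similarity, so invert it
                + 0.1 * perceptual(out, gt)
                + 0.1 * gradient(out, gt)
                + 0.1 * color_loss(out, gt))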