├── Figures
│   ├── README.md
│   ├── 3RD.PNG
│   ├── HMA.PNG
│   ├── Figure_1.png
│   ├── Figure_2.png
│   ├── Figure_3.png
│   ├── Figure_4.png
│   ├── Figure_5.png
│   ├── Figure_6.png
│   ├── Final_Results.PNG
│   └── Burst_Results_List.PNG
├── Single
│   ├── README.md
│   ├── utils.py
│   ├── test.py
│   ├── train.py
│   ├── ssim.py
│   ├── dataset.py
│   └── models.py
├── NTIRE2020_Demoireing_Challenge_Factsheet__C3Net_.pdf
├── Burst
│   ├── README.md
│   ├── utils.py
│   ├── train.py
│   ├── test.py
│   ├── ssim.py
│   ├── models.py
│   └── dataset.py
└── README.md
/Figures/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Figures/3RD.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/3RD.PNG
--------------------------------------------------------------------------------
/Figures/HMA.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/HMA.PNG
--------------------------------------------------------------------------------
/Figures/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_1.png
--------------------------------------------------------------------------------
/Figures/Figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_2.png
--------------------------------------------------------------------------------
/Figures/Figure_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_3.png
--------------------------------------------------------------------------------
/Figures/Figure_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_4.png
--------------------------------------------------------------------------------
/Figures/Figure_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_5.png
--------------------------------------------------------------------------------
/Figures/Figure_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Figure_6.png
--------------------------------------------------------------------------------
/Figures/Final_Results.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Final_Results.PNG
--------------------------------------------------------------------------------
/Figures/Burst_Results_List.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/Figures/Burst_Results_List.PNG
--------------------------------------------------------------------------------
/Single/README.md:
--------------------------------------------------------------------------------
1 | # Track 1: Single Image, C3Net
2 | [Reference](https://competitions.codalab.org/competitions/22223#learn_the_details)
3 |
--------------------------------------------------------------------------------
/NTIRE2020_Demoireing_Challenge_Factsheet__C3Net_.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmycheez/C3Net/HEAD/NTIRE2020_Demoireing_Challenge_Factsheet__C3Net_.pdf
--------------------------------------------------------------------------------
/Burst/README.md:
--------------------------------------------------------------------------------
1 | # Track 2: Burst, C3Net-Burst
2 | For Track 2: Burst, we made the following changes to the Track 1: Single Image model (C3Net):
3 | 
4 | 1. Pre-processed the input images with chroma-key padding (see the sketch below)
5 | 
6 | 2. Adjusted the number of channels
7 | 
8 | 3. Used global max pooling
9 |
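10 | A minimal sketch of the chroma-key padding step (adapted from the preprocessing in `test.py`; the file names below are hypothetical):
11 | ~~~
12 | import cv2
13 | 
14 | origin = cv2.imread('000000_3.png')  # center frame of the burst (hypothetical name)
15 | data = cv2.imread('000000_0.png')    # a shifted frame of the same burst (hypothetical name)
16 | 
17 | # Pixels darker than the threshold in the shifted frame are treated as padding
18 | # and are filled in from the center frame.
19 | _, keep = cv2.threshold(data, 50, 255, cv2.THRESH_BINARY)
20 | _, fill = cv2.threshold(data, 50, 255, cv2.THRESH_BINARY_INV)
21 | data = cv2.bitwise_or(cv2.bitwise_and(data, keep), cv2.bitwise_and(origin, fill))
22 | ~~~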
--------------------------------------------------------------------------------
/Burst/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from skimage.measure.simple_metrics import compare_psnr
6 |
7 | def weights_init_kaiming(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Conv') != -1:
10 |         nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
11 |     elif classname.find('Linear') != -1:
12 |         nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
13 |     elif classname.find('BatchNorm') != -1:
14 |         # nn.init.uniform(m.weight.data, 1.0, 0.02)
15 |         m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
16 |         nn.init.constant_(m.bias.data, 0.0)
17 |
18 | def batch_PSNR(img, imclean, data_range):
19 | Img = img.data.cpu().numpy().astype(np.float32)
20 | Iclean = imclean.data.cpu().numpy().astype(np.float32)
21 | PSNR = 0
22 | for i in range(Img.shape[0]):
23 | PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
24 | return (PSNR/Img.shape[0])
25 |
26 | def data_augmentation(image, mode):
27 | out = np.transpose(image, (1,2,0))
28 | if mode == 0:
29 | # original
30 | out = out
31 | elif mode == 1:
32 | # flip up and down
33 | out = np.flipud(out)
34 | elif mode == 2:
35 |         # rotate counterclockwise 90 degrees
36 | out = np.rot90(out)
37 | elif mode == 3:
38 | # rotate 90 degree and flip up and down
39 | out = np.rot90(out)
40 | out = np.flipud(out)
41 | elif mode == 4:
42 | # rotate 180 degree
43 | out = np.rot90(out, k=2)
44 | elif mode == 5:
45 | # rotate 180 degree and flip
46 | out = np.rot90(out, k=2)
47 | out = np.flipud(out)
48 | elif mode == 6:
49 | # rotate 270 degree
50 | out = np.rot90(out, k=3)
51 | elif mode == 7:
52 | # rotate 270 degree and flip
53 | out = np.rot90(out, k=3)
54 | out = np.flipud(out)
55 | return np.transpose(out, (2,0,1))
56 |
--------------------------------------------------------------------------------
/Single/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from skimage.measure.simple_metrics import compare_psnr
6 |
7 | def weights_init_kaiming(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Conv') != -1:
10 |         nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
11 |     elif classname.find('Linear') != -1:
12 |         nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
13 |     elif classname.find('BatchNorm') != -1:
14 |         # nn.init.uniform(m.weight.data, 1.0, 0.02)
15 |         m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
16 |         nn.init.constant_(m.bias.data, 0.0)
17 |
18 | def batch_PSNR(img, imclean, data_range):
19 | Img = img.data.cpu().numpy().astype(np.float32)
20 | Iclean = imclean.data.cpu().numpy().astype(np.float32)
21 | PSNR = 0
22 | for i in range(Img.shape[0]):
23 | PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
24 | return (PSNR/Img.shape[0])
25 |
26 | def data_augmentation(image, mode):
27 | out = np.transpose(image, (1,2,0))
28 | if mode == 0:
29 | # original
30 | out = out
31 | elif mode == 1:
32 | # flip up and down
33 | out = np.flipud(out)
34 | elif mode == 2:
35 |         # rotate counterclockwise 90 degrees
36 | out = np.rot90(out)
37 | elif mode == 3:
38 | # rotate 90 degree and flip up and down
39 | out = np.rot90(out)
40 | out = np.flipud(out)
41 | elif mode == 4:
42 | # rotate 180 degree
43 | out = np.rot90(out, k=2)
44 | elif mode == 5:
45 | # rotate 180 degree and flip
46 | out = np.rot90(out, k=2)
47 | out = np.flipud(out)
48 | elif mode == 6:
49 | # rotate 270 degree
50 | out = np.rot90(out, k=3)
51 | elif mode == 7:
52 | # rotate 270 degree and flip
53 | out = np.rot90(out, k=3)
54 | out = np.flipud(out)
55 | return np.transpose(out, (2,0,1))
56 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C3Net
2 | This is a PyTorch implementation of the [New Trends in Image Restoration and Enhancement workshop and challenges on image and video restoration and enhancement (NTIRE 2020 with CVPR 2020)](https://data.vision.ee.ethz.ch/cvl/ntire20/) paper, [C3Net: Demoireing Network Attentive in Channel, Color and Concatenation](http://openaccess.thecvf.com/content_CVPRW_2020/html/w31/Kim_C3Net_Demoireing_Network_Attentive_in_Channel_Color_and_Concatenation_CVPRW_2020_paper.html).
3 |
4 | If you find our project useful in your research, please consider citing:
5 | ~~~
6 | @InProceedings{Kim_2020_CVPR_Workshops,
7 | author = {Kim, Sangmin and Nam, Hyungjoon and Kim, Jisu and Jeong, Jechang},
8 | title = {C3Net: Demoireing Network Attentive in Channel, Color and Concatenation},
9 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
10 | month = {June},
11 | year = {2020}
12 | }
13 | ~~~
14 |
15 | # Dependencies
16 | Python 3.6.9
17 | PyTorch 1.4.0
18 |
19 | # Data
20 | [Reference](https://competitions.codalab.org/competitions/22223#participate-get_data)
21 |
22 | You have to sign in to CodaLab and register for the **NTIRE 2020 Demoireing Challenge** before you can get the data.
23 |
24 | # Proposed algorithm
25 | 
26 | 
27 | 
28 | 
29 | 
30 | 
31 |
32 | # Training
33 | Use the following command to run our training code:
34 | ~~~
35 | python train.py
36 | ~~~
37 | There are other options you can choose (e.g. `--batchSize`, `--patch`, `--lr`).
38 | Please refer to train.py.
39 |
40 | # Test
41 | Use the following command to run our test code:
42 | ~~~
43 | python test.py
44 | ~~~
45 | There are other options you can choose (e.g. `--logdir`, `--inputdir`).
46 | Please refer to test.py.
47 |
48 | # Performance (PSNR/SSIM)
49 | To use a heavier model, we also read the input data with NumPy instead of HDF5.
50 | [Hyung-Joon](https://github.com/Hyung-Joon) and [jisukim](https://github.com/jisus189) helped with this.
51 | **Our best records can be reproduced with [the code](https://github.com/Hyung-Joon/Demoire-Burst-single-master)** by changing HDF5 to NumPy and reducing GPU memory usage.
52 |
53 | |Validation Server |PSNR |SSIM |Rank |
54 | |:-----------------------------------------------------------------------------------|:-------|:-------|:-------|
55 | |[Track 1: Single Image](https://competitions.codalab.org/competitions/22223#results)|41.30 |0.99 |9th |
56 | |[Track 2: Burst](https://competitions.codalab.org/competitions/22224#results) |40.55 |0.99 |5th |
57 |
58 | 
59 |
60 | [Testing Server Reference](https://arxiv.org/pdf/2005.03155.pdf)
61 | |Testing Server |PSNR |SSIM |Rank |
62 | |:--------------------|:-------|:-------|:------|
63 | |Track 1: Single Image|41.11 |0.99 |4th |
64 | |Track 2: Burst |40.33 |0.99 |5th |
65 |
66 | 
67 |
68 | 
69 |
70 | # Contact
71 | If you have any question about the **Demoireing** model and the CVPR 2020 challenge paper, feel free to ask me.
72 | If you have any question about the **Deblurring** model, visit [here](https://github.com/Hyung-Joon/Deblur-mobile-RCAN-Master) and feel free to ask Hyung-Joon at <013107nam@gmail.com>.
73 | If you have any question about using an even **heavier C3Net**, visit [here](https://github.com/Hyung-Joon/Demoire-Burst-single-master) and feel free to ask jisukim.
74 |
75 | # Acknowledgement
76 | Thanks to [SaoYan](https://github.com/SaoYan/DnCNN-PyTorch), who provided the implementation of DnCNN.
77 | Thanks to [yun_yang](https://github.com/jt827859032/DRRN-pytorch), who provided the implementation of DRRN.
78 | Thanks to [BumjunPark](https://github.com/BumjunPark/DHDN), who provided the implementation of DHDN.
79 |
80 | The hint for the color loss came from [Jorge Pessoa](https://github.com/jorge-pessoa/pytorch-colors).
81 | The hint for concatenation and residual learning came from [RDN (unofficial implementation)](https://github.com/lingtengqiu/RDN-pytorch).
82 | The hint for the U-Net block came from [DIDN (official implementation)](https://github.com/SonghyunYu/DIDN).
83 |
84 | C3Net started from [RUN](https://github.com/bmycheez/RUN).
85 |
86 | # More Details
87 | Also, we won 3rd place in the [**NTIRE 2020 Challenge on Image and Video Deblurring**](https://arxiv.org/pdf/2005.01244.pdf), thanks to [Hyung-Joon](https://github.com/Hyung-Joon) and [jisukim](https://github.com/jisus189).
88 | The code is available [here](https://github.com/Hyung-Joon/Deblur-mobile-RCAN-Master).
89 |
90 | 
91 |
--------------------------------------------------------------------------------
/Single/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import argparse
4 | import glob
5 | import time
6 | from torch.autograd import Variable
7 | from models import Net
8 | from utils import *
9 |
10 |
11 | parser = argparse.ArgumentParser(description="DnCNN_Test")
12 | parser.add_argument("--num", type=int, default=3, help="Number of total layers")
13 | parser.add_argument("--logdir", type=str, default=".", help='path of log files')
14 | parser.add_argument("--gpu", type=str, default='0', help='test on Set12 or Set68')
15 | parser.add_argument("--inputdir", type=str, default='DemoireingTestInputSingle', help='noise level used on test set')
16 | opt = parser.parse_args()
17 |
18 |
19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
20 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
21 |
22 |
23 | def normalize(data):
24 | return data/255.
25 |
26 |
27 | def self_ensemble(out, mode, forward):
28 | if mode == 0:
29 | # original
30 | out = out
31 | elif mode == 1:
32 | # flip up and down
33 | out = np.flipud(out)
34 | elif mode == 2:
35 |         # rotate counterclockwise 90 degrees
36 | if forward == 1:
37 | out = np.rot90(out)
38 | else:
39 | out = np.rot90(out, k=3)
40 | elif mode == 3:
41 | # rotate 90 degree and flip up and down
42 | if forward == 1:
43 | out = np.rot90(out)
44 | out = np.flipud(out)
45 | else:
46 | out = np.flipud(out)
47 | out = np.rot90(out, k=3)
48 | elif mode == 4:
49 | # rotate 180 degree
50 | out = np.rot90(out, k=2)
51 | elif mode == 5:
52 | # rotate 180 degree and flip
53 | if forward == 1:
54 | out = np.rot90(out, k=2)
55 | out = np.flipud(out)
56 | else:
57 | out = np.flipud(out)
58 | out = np.rot90(out, k=2)
59 |     elif mode == 6:  # rotate counterclockwise 270 degrees
60 | if forward == 1:
61 | out = np.rot90(out, k=3)
62 | else:
63 | out = np.rot90(out)
64 | elif mode == 7:
65 | # rotate 270 degree and flip
66 | if forward == 1:
67 | out = np.rot90(out, k=3)
68 | out = np.flipud(out)
69 | else:
70 | out = np.flipud(out)
71 | out = np.rot90(out)
72 | return out
73 |
74 |
75 | def self_ensemble_v2(out, mode, forward):
76 | if mode == 0:
77 | # original
78 | out = out
79 | elif mode == 1:
80 | # flip up and down
81 | out = np.flipud(out)
82 | elif mode == 2:
83 | out = np.fliplr(out)
84 | elif mode == 3:
85 | out = np.flipud(out)
86 | out = np.fliplr(out)
87 | return out
88 |
89 |
90 | def main():
91 | # Build model
92 | print('Loading model ...\n')
93 | model = Net().cuda()
94 | # device_ids = [0]
95 | # model = nn.DataParallel(net, device_ids=device_ids).cuda()
96 | a = torch.load(glob.glob(os.path.join(opt.logdir, '*.pth'))[0])
97 | print(glob.glob(os.path.join(opt.logdir, '*.pth'))[0])
98 | ok = input("Right model? ")
99 | if ok == 'n':
100 | return
101 | model.load_state_dict(a)
102 | DHDN_flag = 4
103 | ensemble_flag = 4
104 | model.eval()
105 | # load data info
106 | print('Loading data info ...\n')
107 | files_source = glob.glob(os.path.join('D:/', opt.inputdir, '*_%d.png'
108 | % opt.num))
109 | files_source.sort()
110 | # process data
111 |     runtime_total = 0
112 | c = 0
113 | for f in files_source:
114 | # image
115 | start = time.time()
116 | final = np.zeros(cv2.imread(f).shape)
117 | for mode in range(ensemble_flag):
118 | Img = cv2.imread(f)
119 | hh, ww, cc = Img.shape
120 | Img = self_ensemble_v2(Img, mode, 1)
121 | Img = np.swapaxes(Img, 0, 2)
122 | Img = np.swapaxes(Img, 1, 2)
123 | Img = np.float32(normalize(Img))
124 | a = Img.shape[1]
125 | b = Img.shape[2]
126 | if a % DHDN_flag != 0 or b % DHDN_flag != 0:
127 | h = DHDN_flag - (a % DHDN_flag)
128 | w = DHDN_flag - (b % DHDN_flag)
129 | Img = np.pad(Img, [(0, 0), (h//2, h-h//2), (w//2, w-w//2)], mode='edge')
130 | Img = np.expand_dims(Img, 0)
131 | ISource = torch.Tensor(Img)
132 | INoisy = Variable(ISource.cuda())
133 | with torch.no_grad(): # this can save much memory
134 | Out = torch.clamp(model(INoisy), 0., 1.)
135 | if a % DHDN_flag != 0 or b % DHDN_flag != 0:
136 | h = DHDN_flag - (a % DHDN_flag)
137 | w = DHDN_flag - (b % DHDN_flag)
138 |                 Out = Out[:, :, h//2:h//2+a, w//2:w//2+b]  # crop the padding back off
139 | name = str(c)
140 |             if len(str(c)) != 6:
141 | for i in range(6 - len(str(c))):
142 | name = '0' + name
143 | out = Out.squeeze(0).permute(1, 2, 0) * 255
144 | out = out.cpu().numpy()
145 | out = self_ensemble_v2(out, mode, 0)
146 | final += out
147 | cv2.imwrite(name + "_gt.png", final/ensemble_flag)
148 | mytime = time.time() - start
149 |         runtime_total += mytime
150 | print("%s" % f)
151 | c += 1
152 |     runtime_total /= len(files_source)
153 |     print("\nAverage runtime on test data: %.2f s" % runtime_total)
154 |
155 |
156 | if __name__ == "__main__":
157 | main()
158 |
--------------------------------------------------------------------------------
/Single/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.optim as optim
4 | from torch.autograd import Variable
5 | from torch.utils.data import DataLoader
6 | from torchsummary import *
7 | from dataset import prepare_data, Dataset
8 | from utils import *
9 | from datetime import datetime
10 | from ssim import *
11 | from models import *
12 |
13 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
14 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
15 |
16 | parser = argparse.ArgumentParser(description="DnCNN")
17 | parser.add_argument("--preprocess", type=bool, default=False, help='run prepare_data or not')
18 | parser.add_argument("--batchSize", type=int, default=1, help="Training batch size")
19 | parser.add_argument("--patch", type=int, default=128, help="Number of total layers")
20 | parser.add_argument("--epochs", type=int, default=300, help="Number of training epochs")
21 | parser.add_argument("--start_epochs", type=int, default=27, help="Number of training epochs")
22 | parser.add_argument("--start_iters", type=int, default=5998, help="Number of training epochs")
23 | parser.add_argument("--resume", type=str, default="net_38.0006.pth", help="Number of training epochs")
24 | parser.add_argument("--step", type=int, default=30, help="When to decay learning rate; should be less than epochs")
25 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
26 | parser.add_argument("--decay", type=int, default=10, help="Initial learning rate")
27 | parser.add_argument("--outf", type=str, default="./checkpoint", help='path of log files')
28 | parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)')
29 | opt = parser.parse_args()
30 |
31 |
32 | def main():
33 | # Load dataset
34 | print('Loading dataset ...\n')
35 | dataset_train = Dataset(train=True)
36 | dataset_val = Dataset(train=False)
37 | loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=True)
38 | loader_val = DataLoader(dataset=dataset_val, num_workers=4, batch_size=1, shuffle=False)
39 | # print(opt.batchSize)
40 | print("# of training samples: %d\n" % int(len(dataset_train)))
41 | # Build model
42 | # net = DnCNN(channels=1, num_of_layers=opt.num_of_layers)
43 | model = Net().cuda()
44 | # s = MSSSIM()
45 | criterion = nn.L1Loss().cuda()
46 | # vgg = Vgg16(requires_grad=False).cuda()
47 | # vgg = VGG('54').cuda()
48 | # Move to GPU
49 | # model = nn.DataParallel(net, device_ids=device_ids).cuda()
50 | # '''
51 | if opt.resume:
52 | model.load_state_dict(torch.load(opt.resume))
53 | # '''
54 | summary(model, (3, 128, 128))
55 | # Optimizer
56 | optimizer = optim.Adam(model.parameters(), lr=opt.lr)
57 | for epoch in range(opt.start_epochs, opt.epochs):
58 | current_lr = opt.lr * ((1 / opt.decay) ** ((epoch - opt.start_epochs) // opt.step))
59 | # set learning rate
60 | for param_group in optimizer.param_groups:
61 | param_group["lr"] = current_lr
62 | print('learning rate %f' % current_lr)
63 | # train
64 | for i, (imgn_train, img_train) in enumerate(loader_train, 0):
65 | if i < opt.start_iters and epoch == opt.start_epochs:
66 | continue
67 | # training step
68 | model.train()
69 | model.zero_grad()
70 | optimizer.zero_grad()
71 | img_train, imgn_train = Variable(img_train.cuda()), Variable(imgn_train.cuda())
72 | out_train = model(imgn_train)
73 | # feat_x = vgg(imgn_train)
74 | # feat_y = vgg(out_train)
75 | # perceptual_loss = criterion(feat_y.relu2_2, feat_x.relu2_2)
76 | # perceptual_loss = vgg(out_train, img_train)
77 | loss = color_loss(out_train, img_train) + criterion(out_train, img_train)
78 | # + 1e-4 * ((1 - s(out_train, img_train)) / 2.)
79 | loss /= 2
80 | loss.backward()
81 | optimizer.step()
82 | # '''
83 | # if you are using older version of PyTorch, you may need to change loss.item() to loss.data[0]
84 | if i % int(len(loader_train)//5) == 0:
85 | # the end of each epoch
86 | model.eval()
87 | # validate
88 | psnr_val = 0
89 | for _, (imgn_val, img_val) in enumerate(loader_val, 0):
90 | with torch.no_grad():
91 | img_val, imgn_val = Variable(img_val.cuda()), Variable(imgn_val.cuda())
92 | out_val = torch.clamp(model(imgn_val), 0., 1.)
93 | psnr_val += batch_PSNR(out_val, img_val, 1.)
94 | psnr_val /= len(dataset_val)
95 | now = datetime.now()
96 | print("[epoch %d][%d/%d] loss: %.6f PSNR_val: %.4f" %
97 | (epoch+1, i+1, len(loader_train), loss.item(), psnr_val), end='')
98 | print(' ', now.year, now.month, now.day, now.hour, now.minute, now.second)
99 | if psnr_val > 38:
100 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth'))
101 | # '''
102 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth'))
103 |
104 |
105 | if __name__ == "__main__":
106 | if opt.preprocess:
107 | if opt.mode == 'S':
108 | prepare_data(data_path='data', patch_size=opt.patch, stride=opt.patch, aug_times=1)
109 | if opt.mode == 'B':
110 | prepare_data(data_path='data', patch_size=50, stride=10, aug_times=2)
111 | main()
112 |
--------------------------------------------------------------------------------
/Burst/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.optim as optim
3 | from torch.autograd import Variable
4 | from torch.utils.data import DataLoader
5 | from models import *
6 | from dataset import *
7 | from utils import *
8 | from datetime import datetime
9 | from ssim import *
10 |
11 |
12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
14 |
15 | parser = argparse.ArgumentParser(description="DnCNN")
16 | parser.add_argument("--preprocess", type=bool, default=False, help='run prepare_data or not')
17 | parser.add_argument("--batchSize", type=int, default=1, help="Training batch size")
18 | parser.add_argument("--patch", type=int, default=128, help="Number of total layers")
19 | parser.add_argument("--epochs", type=int, default=300, help="Number of training epochs")
20 | parser.add_argument("--start_epochs", type=int, default=60, help="Number of training epochs")
21 | parser.add_argument("--start_iters", type=int, default=0, help="Number of training epochs")
22 | parser.add_argument("--resume", type=str, default="/home/user/depthMap/ksm/CVPR/demoire/logs/48_40.50146.pth",
23 | help="Number of training epochs")
24 | parser.add_argument("--step", type=int, default=30, help="When to decay learning rate; should be less than epochs")
25 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
26 | parser.add_argument("--decay", type=int, default=10, help="Initial learning rate")
27 | parser.add_argument("--outf", type=str, default="/home/user/depthMap/ksm/CVPR/demoire/checkpoint",
28 | help='path of log files')
29 | parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)')
30 | opt = parser.parse_args()
31 |
32 |
33 | def main():
34 | # Load dataset
35 | print('Loading dataset ...\n')
36 | dataset_train = DatasetBurst(train=True)
37 | dataset_val = DatasetBurst(train=False)
38 | loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=False)
39 | loader_val = DataLoader(dataset=dataset_val, num_workers=4, batch_size=1, shuffle=False)
40 | # print(opt.batchSize)
41 | print("# of training samples: %d\n" % int(len(dataset_train)))
42 | # Build model
43 | # net = DnCNN(channels=1, num_of_layers=opt.num_of_layers)
44 | model = Net().cuda()
45 | # s = MSSSIM()
46 | criterion = nn.L1Loss().cuda()
47 | burst = BurstLoss().cuda()
48 | # vgg = Vgg16(requires_grad=False).cuda()
49 | # vgg = VGG('54').cuda()
50 | # Move to GPU
51 | # model = nn.DataParallel(net, device_ids=device_ids).cuda()
52 | # '''
53 | if opt.resume:
54 | model.load_state_dict(torch.load(opt.resume))
55 | # test.main(model)
56 | # return
57 | # '''
58 | # summary(model, (3, 128, 128))
59 | # Optimizer
60 | optimizer = optim.Adam(model.parameters(), lr=opt.lr)
61 | psnr_max = 0
62 | loss_min = 1
63 | for epoch in range(opt.start_epochs, opt.epochs):
64 | # current_lr = opt.lr * ((1 / opt.decay) ** ((epoch - opt.start_epochs) // opt.step))
65 | current_lr = opt.lr * ((1 / opt.decay) ** (epoch // opt.step))
66 | # set learning rate
67 | for param_group in optimizer.param_groups:
68 | param_group["lr"] = current_lr
69 | print('learning rate %f' % current_lr)
70 | # train
71 | for i, (imgn_train, img_train) in enumerate(loader_train, 0):
72 | if i < opt.start_iters:
73 | continue
74 | # training step
75 | model.train()
76 | model.zero_grad()
77 | optimizer.zero_grad()
78 | img_train, imgn_train = Variable(img_train.cuda()), Variable(imgn_train.cuda())
79 | out_train = model(imgn_train)
80 | # feat_x = vgg(imgn_train)
81 | # feat_y = vgg(out_train)
82 | # perceptual_loss = criterion(feat_y.relu2_2, feat_x.relu2_2)
83 | # perceptual_loss = vgg(out_train, img_train)
84 | loss_color = color_loss(out_train, img_train)
85 | loss_content = criterion(out_train, img_train)
86 | loss_burst = burst(out_train, img_train)
87 |             m = [5, 5, 0]  # loss weights: color, content, burst (the burst term is disabled here)
88 | loss = torch.div(m[0] * loss_color.cuda() + m[1] * loss_content.cuda() + m[2] * loss_burst.cuda(), 10)
89 | loss.backward()
90 | optimizer.step()
91 | # '''
92 | # if you are using older version of PyTorch, you may need to change loss.item() to loss.data[0]
93 | if i % int(len(loader_train)//5) == 0:
94 | # the end of each epoch
95 | model.eval()
96 | # validate
97 | psnr_val = 0
98 | for _, (imgn_val, img_val) in enumerate(loader_val, 0):
99 | with torch.no_grad():
100 | img_val, imgn_val = Variable(img_val.cuda()), Variable(imgn_val.cuda())
101 | out_val = torch.clamp(model(imgn_val), 0., 1.)
102 | psnr_val += batch_PSNR(out_val, img_val, 1.)
103 | psnr_val /= len(dataset_val)
104 | now = datetime.now()
105 | print("[epoch %d][%d/%d] loss: %.6f PSNR_val: %.4f" %
106 | (epoch+1, i+1, len(loader_train), loss.item(), psnr_val), end=' ')
107 | print(now.year, now.month, now.day, now.hour, now.minute, now.second)
108 | if psnr_val > psnr_max or loss < loss_min:
109 | psnr_max = psnr_val
110 | loss_min = loss
111 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth'))
112 | # '''
113 | torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth'))
114 |
115 |
116 | if __name__ == "__main__":
117 | if opt.preprocess:
118 | if opt.mode == 'S':
119 | prepare_data(data_path='data', patch_size=opt.patch, stride=opt.patch, aug_times=1)
120 | if opt.mode == 'B':
121 | prepare_data(data_path='data', patch_size=50, stride=10, aug_times=2)
122 | main()
123 |
--------------------------------------------------------------------------------
/Burst/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import argparse
4 | import glob
5 | import time
6 | from torch.autograd import Variable
7 | from models import *
8 | from utils import *
9 |
10 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12 |
13 | parser = argparse.ArgumentParser(description="DnCNN_Test")
14 | parser.add_argument("--logdir", type=str, default="/home/user/depthMap/ksm/CVPR/demoire/logs", help='path of log files')
15 | opt = parser.parse_args()
16 |
17 |
18 | def normalize(data):
19 | return data/255.
20 |
21 |
22 | def self_ensemble(out, mode, forward):
23 | if mode == 0:
24 | # original
25 | out = out
26 | elif mode == 1:
27 | # flip up and down
28 | out = np.flipud(out)
29 | elif mode == 2:
30 |         # rotate counterclockwise 90 degrees
31 | if forward == 1:
32 | out = np.rot90(out)
33 | else:
34 | out = np.rot90(out, k=3)
35 | elif mode == 3:
36 | # rotate 90 degree and flip up and down
37 | if forward == 1:
38 | out = np.rot90(out)
39 | out = np.flipud(out)
40 | else:
41 | out = np.flipud(out)
42 | out = np.rot90(out, k=3)
43 | elif mode == 4:
44 | # rotate 180 degree
45 | out = np.rot90(out, k=2)
46 | elif mode == 5:
47 | # rotate 180 degree and flip
48 | if forward == 1:
49 | out = np.rot90(out, k=2)
50 | out = np.flipud(out)
51 | else:
52 | out = np.flipud(out)
53 | out = np.rot90(out, k=2)
54 |     elif mode == 6:  # rotate counterclockwise 270 degrees
55 | if forward == 1:
56 | out = np.rot90(out, k=3)
57 | else:
58 | out = np.rot90(out)
59 | elif mode == 7:
60 | # rotate 270 degree and flip
61 | if forward == 1:
62 | out = np.rot90(out, k=3)
63 | out = np.flipud(out)
64 | else:
65 | out = np.flipud(out)
66 | out = np.rot90(out)
67 | return out
68 |
69 |
70 | def main(model=0):
71 | # Build model
72 |
73 | print('Loading model ...\n')
74 | model = Net().cuda()
75 | # device_ids = [0]
76 | # model = nn.DataParallel(net, device_ids=device_ids).cuda()
77 | a = torch.load(glob.glob(os.path.join(opt.logdir, '*.pth'))[0])
78 | print(glob.glob(os.path.join(opt.logdir, '*.pth'))[0])
79 | ok = input("Right model? ")
80 | if ok == 'n':
81 | return
82 | model.load_state_dict(a)
83 |
84 | # DHDN_flag = 4
85 | frame = 7
86 | ensemble_flag = 1
87 | model.eval()
88 | # load data info
89 | print('Loading data info ...\n')
90 | files_source = glob.glob(os.path.join('/home/user/depthMap/ksm/CVPR/demoire', 'ValidationInput', '*.png'))
91 | files_source.sort()
92 | # process data
93 |     runtime_total = 0
94 | c = 0
95 | for f in range(len(files_source)//frame):
96 | # image
97 | start = time.time()
98 | ISource = []
99 | # final = np.zeros(cv2.imread(f).shape)
100 | origin = cv2.imread(files_source[f * frame + 3])
101 | for mode in range(ensemble_flag):
102 | for im in range(frame):
103 | data = cv2.imread(files_source[f * frame + im])
104 | if im != 3:
105 | _, bin2 = cv2.threshold(data, 50, 255, cv2.THRESH_BINARY)
106 | _, bin3 = cv2.threshold(data, 50, 255, cv2.THRESH_BINARY_INV)
107 | final2 = cv2.bitwise_and(data, bin2, mask=None)
108 | final3 = cv2.bitwise_and(origin, bin3, mask=None)
109 | data = cv2.bitwise_or(final3, final2, mask=None)
110 | data = np.float32(normalize(data))
111 | data = np.transpose(data, (2, 0, 1))
112 | data = torch.Tensor(data).unsqueeze(0)
113 | ISource.append(data)
114 | """
115 | data = cv2.imread(files_source[f * frame + 3])
116 | data = np.float32(normalize(data))
117 | data = np.transpose(data, (2, 0, 1))
118 | data = torch.Tensor(data).unsqueeze(0)
119 | ISource.append(data)
120 | """
121 | ISource = torch.cat(ISource, 0)
122 | """
123 | hh, ww, cc = Img.shape
124 | for ch in range(cc):
125 | pl = Img[:, :, ch]
126 | Img[:, :, ch] = self_ensemble(pl, mode, 1)
127 | Img = np.swapaxes(Img, 0, 2)
128 | Img = np.swapaxes(Img, 1, 2)
129 | Img = np.float32(normalize(Img))
130 | a = Img.shape[1]
131 | b = Img.shape[2]
132 | if a % DHDN_flag != 0 or b % DHDN_flag != 0:
133 | h = DHDN_flag - (a % DHDN_flag)
134 | w = DHDN_flag - (b % DHDN_flag)
135 | Img = np.pad(Img, [(0, 0), (h//2, h-h//2), (w//2, w-w//2)], mode='edge')
136 | Img = np.expand_dims(Img, 0)
137 | ISource = torch.Tensor(Img)
138 | """
139 | INoisy = Variable(ISource.unsqueeze(0).cuda())
140 | # print(INoisy.size())
141 | with torch.no_grad(): # this can save much memory
142 | Out = torch.clamp(model(INoisy), 0., 1.)
143 | """
144 | if a % DHDN_flag != 0 or b % DHDN_flag != 0:
145 | h = DHDN_flag - (a % DHDN_flag)
146 | w = DHDN_flag - (b % DHDN_flag)
147 | Out = Out[:, :, h//2:Img.shape[0]-(h-h//2+1), w//2:Img.shape[1]-(w-w//2+1)]
148 | """
149 | c = f
150 | name = str(c)
151 |         if len(str(c)) != 6:
152 | for i in range(6 - len(str(c))):
153 | name = '0' + name
154 | out = Out.squeeze(0).permute(1, 2, 0) * 255
155 | out = out.cpu().numpy()
156 | """
157 | for ch in range(cc):
158 | out[:, :, ch] = self_ensemble(out[:, :, ch], mode, 0)
159 | final += out
160 | """
161 | cv2.imwrite("/home/user/depthMap/ksm/CVPR/demoire/" + name + "_gt.png", out/ensemble_flag)
162 | mytime = time.time() - start
163 |         runtime_total += mytime
164 | print("%s" % f)
165 | c += 1
166 |     runtime_total /= (len(files_source)//frame)
167 |     print("\nAverage runtime on test data: %.2f s" % runtime_total)
168 |
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
--------------------------------------------------------------------------------
/Single/ssim.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.models as models
6 | from torch.autograd import Variable
7 | from kornia.color import rgb_to_yuv
8 |
9 |
10 | def gaussian(window_size, sigma):
11 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
12 | return gauss/gauss.sum()
13 |
14 |
15 | def create_window(window_size, channel=1):
16 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
18 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
19 | return window
20 |
21 |
22 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
23 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
24 | if val_range is None:
25 | if torch.max(img1) > 128:
26 | max_val = 255
27 | else:
28 | max_val = 1
29 |
30 | if torch.min(img1) < -0.5:
31 | min_val = -1
32 | else:
33 | min_val = 0
34 | L = max_val - min_val
35 | else:
36 | L = val_range
37 |
38 | padd = 0
39 | (_, channel, height, width) = img1.size()
40 | if window is None:
41 | real_size = min(window_size, height, width)
42 | window = create_window(real_size, channel=channel).to(img1.device)
43 |
44 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
45 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
46 |
47 | mu1_sq = mu1.pow(2)
48 | mu2_sq = mu2.pow(2)
49 | mu1_mu2 = mu1 * mu2
50 |
51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
53 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
54 |
55 | C1 = (0.01 * L) ** 2
56 | C2 = (0.03 * L) ** 2
57 |
58 | v1 = 2.0 * sigma12 + C2
59 | v2 = sigma1_sq + sigma2_sq + C2
60 | cs = torch.mean(v1 / v2) # contrast sensitivity
61 |
62 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
63 |
64 | if size_average:
65 | ret = ssim_map.mean()
66 | else:
67 | ret = ssim_map.mean(1).mean(1).mean(1)
68 |
69 | if full:
70 | return ret, cs
71 | return ret
72 |
73 |
74 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
75 | device = img1.device
76 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
77 | levels = weights.size()[0]
78 | mssim = []
79 | mcs = []
80 | for _ in range(levels):
81 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
82 | mssim.append(sim)
83 | mcs.append(cs)
84 |
85 | img1 = F.avg_pool2d(img1, (2, 2))
86 | img2 = F.avg_pool2d(img2, (2, 2))
87 |
88 | mssim = torch.stack(mssim)
89 | mcs = torch.stack(mcs)
90 |
91 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
92 | if normalize:
93 | mssim = (mssim + 1) / 2
94 | mcs = (mcs + 1) / 2
95 |
96 | pow1 = mcs ** weights
97 | pow2 = mssim ** weights
98 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
99 |     output = torch.prod(pow1[:-1]) * pow2[-1]
100 | return output
101 |
102 |
103 | # Classes to re-use window
104 | class SSIM(torch.nn.Module):
105 | def __init__(self, window_size=11, size_average=True, val_range=None):
106 | super(SSIM, self).__init__()
107 | self.window_size = window_size
108 | self.size_average = size_average
109 | self.val_range = val_range
110 |
111 | # Assume 1 channel for SSIM
112 | self.channel = 1
113 | self.window = create_window(window_size)
114 |
115 | def forward(self, img1, img2):
116 | (_, channel, _, _) = img1.size()
117 |
118 | if channel == self.channel and self.window.dtype == img1.dtype:
119 | window = self.window
120 | else:
121 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
122 | self.window = window
123 | self.channel = channel
124 |
125 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
126 |
127 |
128 | class MSSSIM(torch.nn.Module):
129 | def __init__(self, window_size=11, size_average=True, channel=3):
130 | super(MSSSIM, self).__init__()
131 | self.window_size = window_size
132 | self.size_average = size_average
133 | self.channel = channel
134 |
135 | def forward(self, img1, img2):
136 | # TODO: store window between calls if possible
137 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
138 |
139 |
140 | class MeanShift(nn.Conv2d):
141 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
142 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
143 | std = torch.Tensor(rgb_std)
144 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
145 | self.weight.data.div_(std.view(3, 1, 1, 1))
146 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
147 | self.bias.data.div_(std)
148 |         for p in self.parameters(): p.requires_grad = False
149 |
150 |
151 | class VGG(torch.nn.Module):
152 | def __init__(self, conv_index, rgb_range=1):
153 | super(VGG, self).__init__()
154 | vgg_features = models.vgg19(pretrained=True).features
155 | modules = [m for m in vgg_features]
156 | if conv_index == '22':
157 | self.vgg = nn.Sequential(*modules[:8])
158 | elif conv_index == '54':
159 | self.vgg = nn.Sequential(*modules[:35])
160 |
161 | vgg_mean = (0.485, 0.456, 0.406)
162 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
163 | self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)
164 |         for p in self.vgg.parameters(): p.requires_grad = False
165 |
166 | def forward(self, sr, hr):
167 | def _forward(x):
168 | x = self.sub_mean(x)
169 | x = self.vgg(x)
170 | return x
171 |
172 | vgg_sr = _forward(sr)
173 | with torch.no_grad():
174 | vgg_hr = _forward(hr.detach())
175 |
176 | loss = F.l1_loss(vgg_sr, vgg_hr)
177 |
178 | return loss
179 |
180 |
181 | def color_loss(out, target):
182 | out_yuv = rgb_to_yuv(out)
183 | # out_y = out_yuv[:, 0, :, :]
184 | out_u = out_yuv[:, 1, :, :]
185 | out_v = out_yuv[:, 2, :, :]
186 | target_yuv = rgb_to_yuv(target)
187 | # target_y = target_yuv[:, 0, :, :]
188 | target_u = target_yuv[:, 1, :, :]
189 | target_v = target_yuv[:, 2, :, :]
190 |
191 | return torch.div(
192 | torch.mean((out_u - target_u).pow(1)).abs() + torch.mean((out_v - target_v).pow(1)).abs(), 2)
193 |
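194 | # Minimal usage sketch (illustrative only; not part of the original training code):
195 | if __name__ == "__main__":
196 |     pred = torch.rand(1, 3, 64, 64)
197 |     target = torch.rand(1, 3, 64, 64)
198 |     print('SSIM: %.4f' % ssim(pred, target).item())
199 |     print('color loss: %.6f' % color_loss(pred, target).item())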
--------------------------------------------------------------------------------
/Single/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import numpy as np
4 | import random
5 | import h5py
6 | import torch
7 | import cv2
8 | import glob
9 | import torch.utils.data as udata
10 | from utils import data_augmentation
11 |
12 |
13 | def normalize(data):
14 | return data/255.
15 |
16 |
17 | def Im2Patch(img, win, stride=1):
18 | k = 0
19 | endc = img.shape[0]
20 | endh = img.shape[1]
21 | endw = img.shape[2]
22 | patch = img[:, 0:endh-win+0+1:stride, 0:endw-win+0+1:stride]
23 | TotalPatNum = patch.shape[1] * patch.shape[2]
24 | Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
25 | for i in range(win):
26 | for j in range(win):
27 | patch = img[:, i:endh-win+i+1:stride, j:endw-win+j+1:stride]
28 | Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
29 | k = k + 1
30 | return Y.reshape([endc, win, win, TotalPatNum])
31 |
32 |
33 | def prepare_data(data_path, patch_size, stride, aug_times=1):
34 | # '''
35 | # train
36 | print('process training data')
37 | scales = [1]
38 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'HAZY', '*.png'))
39 | # mix = list(range(len(files)))
40 | # random.shuffle(mix)
41 | # mix_train = mix[:int(len(files)*0.96)]
42 | # mix_val = mix[int(len(files)*0.96):]
43 | files.sort()
44 | h5f = h5py.File('D:/train_input.h5', 'w')
45 | train_num = 0
46 | for i in range(len(files)):
47 | Img = cv2.imread(files[i])
48 | h, w, c = Img.shape
49 | for k in range(len(scales)):
50 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)
51 | # Img = np.expand_dims(Img[:, :, :].copy(), 0)
52 | Img = np.swapaxes(Img, 0, 2)
53 | Img = np.swapaxes(Img, 1, 2)
54 | Img = np.float32(normalize(Img))
55 | # print(Img.shape)
56 | patches = Im2Patch(Img, patch_size, stride)
57 | # print(i)
58 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3]))
59 | for n in range(patches.shape[3]):
60 | data = patches[:, :, :, n].copy()
61 | # print(data.shape)
62 | h5f.create_dataset(str(train_num), data=data)
63 | train_num += 1
64 | for m in range(aug_times-1):
65 | data_aug = data_augmentation(data, np.random.randint(1, 8))
66 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
67 | train_num += 1
68 | h5f.close()
69 | print('process training gt')
70 | scales = [1]
71 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'GT', '*.png'))
72 | files.sort()
73 | h5f = h5py.File('D:/train_gt.h5', 'w')
74 | train_num = 0
75 | for i in range(len(files)):
76 | Img = cv2.imread(files[i])
77 | h, w, c = Img.shape
78 | for k in range(len(scales)):
79 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)
80 | # Img = np.expand_dims(Img[:, :, :].copy(), 0)
81 | Img = np.swapaxes(Img, 0, 2)
82 | Img = np.swapaxes(Img, 1, 2)
83 | Img = np.float32(normalize(Img))
84 | patches = Im2Patch(Img, patch_size, stride)
85 | # print(i)
86 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3]))
87 | for n in range(patches.shape[3]):
88 | data = patches[:, :, :, n].copy()
89 | # print(data.shape)
90 | h5f.create_dataset(str(train_num), data=data)
91 | train_num += 1
92 | for m in range(aug_times-1):
93 | data_aug = data_augmentation(data, np.random.randint(1, 8))
94 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
95 | train_num += 1
96 | h5f.close()
97 | # val
98 | print('\nprocess validation data')
99 | # files.clear()
100 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'HAZY', '*.png'))
101 | files.sort()
102 | h5f = h5py.File('D:/val_input.h5', 'w')
103 | val_num = 0
104 | for i in range(len(files)):
105 | print("file: %s" % files[i])
106 | img = cv2.imread(files[i])
107 | # img = np.expand_dims(img[:, :, :], 0)
108 | img = np.swapaxes(img, 0, 2)
109 | img = np.swapaxes(img, 1, 2)
110 | img = np.float32(normalize(img))
111 | # print(i)
112 | # print(img.shape)
113 | h5f.create_dataset(str(val_num), data=img)
114 | val_num += 1
115 | h5f.close()
116 | # '''
117 | print('\nprocess validation gt')
118 | # files.clear()
119 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'GT', '*.png'))
120 | files.sort()
121 | h5f = h5py.File('D:/val_gt.h5', 'w')
122 | val_num = 0
123 | for i in range(len(files)):
124 | print("file: %s" % files[i])
125 | img = cv2.imread(files[i])
126 | # img = np.expand_dims(img[:, :, :], 0)
127 | img = np.swapaxes(img, 0, 2)
128 | img = np.swapaxes(img, 1, 2)
129 | img = np.float32(normalize(img))
130 | # print(i)
131 | # print(img.shape)
132 | h5f.create_dataset(str(val_num), data=img)
133 | val_num += 1
134 | h5f.close()
135 | # print('training set, # samples %d\n' % train_num)
136 | print('val set, # samples %d\n' % val_num)
137 | # '''
138 |
139 |
140 | class Dataset(udata.Dataset):
141 | def __init__(self, train=True):
142 | super(Dataset, self).__init__()
143 | self.train = train
144 | if self.train:
145 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_input.h5', 'r')
146 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_gt.h5', 'r')
147 | self.keys = list(h5f.keys())
148 | self.keys_gt = list(h5f_gt.keys())
149 | h5f.close()
150 | h5f_gt.close()
151 | else:
152 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_input.h5', 'r')
153 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_gt.h5', 'r')
154 | self.keys = list(h5f.keys())
155 | self.keys_gt = list(h5f_gt.keys())
156 | h5f.close()
157 | h5f_gt.close()
158 |
159 | def __len__(self):
160 | return len(self.keys)
161 |
162 | def __getitem__(self, index):
163 | if self.train:
164 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_input.h5', 'r')
165 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/train_gt.h5', 'r')
166 | key = self.keys[index]
167 | key_gt = self.keys_gt[index]
168 | data = np.array(h5f[key])
169 | gt = np.array(h5f_gt[key_gt])
170 | h5f.close()
171 | h5f_gt.close()
172 | return torch.Tensor(data), torch.Tensor(gt)
173 | else:
174 | h5f = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_input.h5', 'r')
175 | h5f_gt = h5py.File('/home/user/depthMap/ksm/AIM/demoire/data/val_gt.h5', 'r')
176 | key = self.keys[index]
177 | key_gt = self.keys_gt[index]
178 | data = np.array(h5f[key])
179 | gt = np.array(h5f_gt[key_gt])
180 | h5f.close()
181 | h5f_gt.close()
182 | return torch.Tensor(data), torch.Tensor(gt)
183 |
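184 | # The top-level README mentions reading the input data with NumPy/cv2 instead
185 | # of HDF5 for heavier models. A minimal sketch of that variant (the directory
186 | # names below are placeholders, not the original paths):
187 | class NumpyDataset(udata.Dataset):
188 |     def __init__(self, input_dir='data/input', gt_dir='data/gt'):
189 |         super(NumpyDataset, self).__init__()
190 |         self.inputs = sorted(glob.glob(os.path.join(input_dir, '*.png')))
191 |         self.gts = sorted(glob.glob(os.path.join(gt_dir, '*.png')))
192 | 
193 |     def __len__(self):
194 |         return len(self.inputs)
195 | 
196 |     def __getitem__(self, index):
197 |         data = np.transpose(np.float32(normalize(cv2.imread(self.inputs[index]))), (2, 0, 1))
198 |         gt = np.transpose(np.float32(normalize(cv2.imread(self.gts[index]))), (2, 0, 1))
199 |         return torch.Tensor(data), torch.Tensor(gt)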
--------------------------------------------------------------------------------
/Single/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from math import sqrt
3 | import torch
4 |
5 |
6 | class DnCNN(nn.Module):
7 | def __init__(self, channels, num_of_layers=17):
8 | super(DnCNN, self).__init__()
9 | kernel_size = 3
10 | padding = 1
11 | features = 64
12 | layers = []
13 | layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
14 | layers.append(nn.ReLU(inplace=True))
15 | for _ in range(num_of_layers-2):
16 | layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
17 | layers.append(nn.BatchNorm2d(features))
18 | layers.append(nn.ReLU(inplace=True))
19 | layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
20 | self.dncnn = nn.Sequential(*layers)
21 | # weights initialization
22 | for m in self.modules():
23 | if isinstance(m, nn.Conv2d):
24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
25 | m.weight.data.normal_(0, sqrt(2. / n))
26 |
27 | def forward(self, x):
28 | out = self.dncnn(x)
29 | return out
30 |
31 |
32 | class CA(nn.Module):
33 | def __init__(self, channel, reduction=16):
34 | super(CA, self).__init__()
35 | # global average pooling: feature --> point
36 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
37 | # feature channel downscale and upscale --> channel weight
38 | self.conv_du = nn.Sequential(
39 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
40 | nn.ReLU(inplace=True),
41 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
42 | nn.Sigmoid()
43 | )
44 |
45 | def forward(self, x):
46 | y = self.avg_pool(x)
47 | y = self.conv_du(y)
48 | return x * y
49 |
50 |
51 | class RB(nn.Module):
52 | def __init__(self, features):
53 | super(RB, self).__init__()
54 | layers = []
55 | kernel_size = 3
56 | for _ in range(1):
57 | layers.append(nn.Conv2d(in_channels=features, out_channels=features
58 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True))
59 | layers.append(nn.PReLU())
60 | layers.append(nn.Conv2d(in_channels=features, out_channels=features
61 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True))
62 | self.res = nn.Sequential(*layers)
63 | self.ca = CA(features)
64 |
65 | def forward(self, x):
66 | out = self.res(x)
67 | out = self.ca(out)
68 | out += x
69 | return out
70 |
71 |
72 | class _down(nn.Module):
73 | def __init__(self, channel_in):
74 | super(_down, self).__init__()
75 |
76 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=2 * channel_in, kernel_size=3, stride=2, padding=1)
77 | self.relu = nn.PReLU()
78 |
79 | def forward(self, x):
80 |
81 | out = self.relu(self.conv(x))
82 |
83 | return out
84 |
85 |
86 | class _up(nn.Module):
87 | def __init__(self, channel_in):
88 | super(_up, self).__init__()
89 |
90 | self.conv = nn.PixelShuffle(2)
91 | self.relu = nn.PReLU()
92 |
93 | def forward(self, x):
94 |
95 | out = self.relu(self.conv(x))
96 |
97 | return out
98 |
99 |
100 | class AB(nn.Module):
101 | def __init__(self, features):
102 | super(AB, self).__init__()
103 |
104 | num = 2
105 | self.DCR_block1 = self.make_layer(RB, features, num)
106 | self.down1 = self.make_layer(_down, features, 1)
107 | self.DCR_block2 = self.make_layer(RB, features * 2, num)
108 | self.down2 = self.make_layer(_down, features * 2, 1)
109 | self.DCR_block3 = self.make_layer(RB, features * 4, num)
110 | self.up2 = self.make_layer(_up, features * 8, 1)
111 | self.DCR_block22 = self.make_layer(RB, features * 4, num)
112 | self.up1 = self.make_layer(_up, features * 4, 1)
113 | self.DCR_block11 = self.make_layer(RB, features * 2, num)
114 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0)
115 | self.relu2 = nn.PReLU()
116 |
117 | def make_layer(self, block, channel_in, num):
118 | layers = []
119 | for _ in range(num):
120 | layers.append(block(channel_in))
121 | return nn.Sequential(*layers)
122 |
123 | def forward(self, x):
124 |
125 | conc1 = self.DCR_block1(x)
126 | out = self.down1(conc1)
127 |
128 | conc2 = self.DCR_block2(out)
129 | conc3 = self.down2(conc2)
130 |
131 | out = self.DCR_block3(conc3)
132 | out = torch.cat([conc3, out], 1)
133 |
134 | out = self.up2(out)
135 | out = torch.cat([conc2, out], 1)
136 | out = self.DCR_block22(out)
137 |
138 | out = self.up1(out)
139 | out = torch.cat([conc1, out], 1)
140 | out = self.DCR_block11(out)
141 |
142 | out = self.relu2(self.conv_f(out))
143 | out += x
144 |
145 | return out
146 |
147 |
148 | class GAB(nn.Module):
149 | def __init__(self, features):
150 | super(GAB, self).__init__()
151 |
152 | kernel_size = 3
153 | self.res1 = self.make_layer(RB, features, 2)
154 | self.R = 2
155 | self.A = 1
156 | self.RB = nn.ModuleList()
157 | for _ in range(self.R):
158 | self.RB.append(RB(features))
159 | self.AB = nn.ModuleList()
160 | for _ in range(self.A):
161 | self.AB.append(AB(features))
162 | self.GFF_R = nn.Sequential(
163 | nn.Conv2d(self.R * features, features, kernel_size=1, padding=0, stride=1),
164 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1),
165 | )
166 | self.GFF_A = nn.Sequential(
167 | nn.Conv2d(self.A * features, features, kernel_size=1, padding=0, stride=1),
168 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1),
169 | )
170 | self.softmax = nn.Sigmoid()
171 | self.res2 = self.make_layer(RB, features * 2, 2)
172 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0)
173 | self.relu2 = nn.PReLU()
174 |
175 | def make_layer(self, block, channel_in, num):
176 | layers = []
177 | for _ in range(num):
178 | layers.append(block(channel_in))
179 | return nn.Sequential(*layers)
180 |
181 | def forward(self, x):
182 | out = self.res1(x)
183 |
184 | RB_outs = []
185 | for i in range(self.R):
186 | outR = self.RB[i](out)
187 | RB_outs.append(outR)
188 | outR = torch.cat(RB_outs, 1)
189 | outR = self.GFF_R(outR)
190 | outR += out
191 | AB_outs = []
192 | for i in range(self.A):
193 | outA = self.AB[i](out)
194 | AB_outs.append(outA)
195 | outA = torch.cat(AB_outs, 1)
196 | outA = self.GFF_A(outA)
197 | outA += out
198 | # outR *= self.softmax(outA)
199 | out = torch.cat([outR, outA], 1)
200 | out = self.relu2(self.conv_f(self.res2(out)))
201 | out *= 0.2
202 | out += x
203 |
204 | return out
205 |
206 |
207 | class Net(nn.Module):
208 | def __init__(self, features=64):
209 | super(Net, self).__init__()
210 |
211 | kernel_size = 3
212 | self.conv_i = nn.Conv2d(in_channels=3, out_channels=features, kernel_size=1, stride=1, padding=0)
213 | self.relu1 = nn.PReLU()
214 | self.GA = 22
215 | self.GAB = nn.ModuleList()
216 | for _ in range(self.GA):
217 | self.GAB.append(GAB(features))
218 | self.GFF_GA = nn.Sequential(
219 | nn.Conv2d(self.GA * features, features, kernel_size=1, padding=0, stride=1),
220 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1),
221 | )
222 | self.conv_f = nn.Conv2d(in_channels=features, out_channels=3, kernel_size=1, stride=1, padding=0)
223 | self.relu2 = nn.PReLU()
224 |
225 | def make_layer(self, block, channel_in, num):
226 | layers = []
227 | for _ in range(num):
228 | layers.append(block(channel_in))
229 | return nn.Sequential(*layers)
230 |
231 | def forward(self, x):
232 | out = self.relu1(self.conv_i(x))
233 |
234 | GAB_outs = []
235 | for i in range(self.GA):
236 | out = self.GAB[i](out)
237 | GAB_outs.append(out)
238 | out = torch.cat(GAB_outs, 1)
239 | out = self.GFF_GA(out)
240 |
241 | out = self.relu2(self.conv_f(out))
242 | out += x
243 |
244 | return out
245 |
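246 | # Quick shape check (illustrative): the input height and width should be
247 | # multiples of 4, since each AB block downsamples twice before upsampling back.
248 | if __name__ == "__main__":
249 |     net = Net(features=64)
250 |     x = torch.randn(1, 3, 64, 64)
251 |     print(net(x).shape)  # expected: torch.Size([1, 3, 64, 64])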
--------------------------------------------------------------------------------
/Burst/ssim.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.models as models
6 | from kornia.color import rgb_to_yuv
7 | from torch.nn.modules.loss import _Loss
8 | import numpy as np
9 |
10 |
11 | def gaussian(window_size, sigma):
12 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
13 | return gauss/gauss.sum()
14 |
15 |
16 | def create_window(window_size, channel=1):
17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
18 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
19 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
20 | return window
21 |
22 |
23 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
24 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
25 | if val_range is None:
26 | if torch.max(img1) > 128:
27 | max_val = 255
28 | else:
29 | max_val = 1
30 |
31 | if torch.min(img1) < -0.5:
32 | min_val = -1
33 | else:
34 | min_val = 0
35 | L = max_val - min_val
36 | else:
37 | L = val_range
38 |
39 | padd = 0
40 | (_, channel, height, width) = img1.size()
41 | if window is None:
42 | real_size = min(window_size, height, width)
43 | window = create_window(real_size, channel=channel).to(img1.device)
44 |
45 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
46 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
47 |
48 | mu1_sq = mu1.pow(2)
49 | mu2_sq = mu2.pow(2)
50 | mu1_mu2 = mu1 * mu2
51 |
52 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
53 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
54 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
55 |
56 | C1 = (0.01 * L) ** 2
57 | C2 = (0.03 * L) ** 2
58 |
59 | v1 = 2.0 * sigma12 + C2
60 | v2 = sigma1_sq + sigma2_sq + C2
61 | cs = torch.mean(v1 / v2) # contrast sensitivity
62 |
63 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
64 |
65 | if size_average:
66 | ret = ssim_map.mean()
67 | else:
68 | ret = ssim_map.mean(1).mean(1).mean(1)
69 |
70 | if full:
71 | return ret, cs
72 | return ret
73 |
74 |
75 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
76 | device = img1.device
77 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
78 | levels = weights.size()[0]
79 | mssim = []
80 | mcs = []
81 | for _ in range(levels):
82 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
83 | mssim.append(sim)
84 | mcs.append(cs)
85 |
86 | img1 = F.avg_pool2d(img1, (2, 2))
87 | img2 = F.avg_pool2d(img2, (2, 2))
88 |
89 | mssim = torch.stack(mssim)
90 | mcs = torch.stack(mcs)
91 |
92 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
93 | if normalize:
94 | mssim = (mssim + 1) / 2
95 | mcs = (mcs + 1) / 2
96 |
97 | pow1 = mcs ** weights
98 | pow2 = mssim ** weights
99 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
100 | output = torch.prod(pow1[:-1]) * pow2[-1]  # product of cs terms over the coarse scales times the finest-scale ssim term
101 | return output
102 |
103 |
104 | # Classes to re-use window
105 | class SSIM(torch.nn.Module):
106 | def __init__(self, window_size=11, size_average=True, val_range=None):
107 | super(SSIM, self).__init__()
108 | self.window_size = window_size
109 | self.size_average = size_average
110 | self.val_range = val_range
111 |
112 | # Assume 1 channel for SSIM
113 | self.channel = 1
114 | self.window = create_window(window_size)
115 |
116 | def forward(self, img1, img2):
117 | (_, channel, _, _) = img1.size()
118 |
119 | if channel == self.channel and self.window.dtype == img1.dtype and self.window.device == img1.device:
120 | window = self.window
121 | else:
122 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
123 | self.window = window
124 | self.channel = channel
125 |
126 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average, val_range=self.val_range)
127 |
128 |
129 | class MSSSIM(torch.nn.Module):
130 | def __init__(self, window_size=11, size_average=True, channel=3):
131 | super(MSSSIM, self).__init__()
132 | self.window_size = window_size
133 | self.size_average = size_average
134 | self.channel = channel
135 |
136 | def forward(self, img1, img2):
137 | # TODO: store window between calls if possible
138 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
139 |
140 |
141 | class MeanShift(nn.Conv2d):
142 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
143 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
144 | std = torch.Tensor(rgb_std)
145 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
146 | self.weight.data.div_(std.view(3, 1, 1, 1))
147 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
148 | self.bias.data.div_(std)
149 | for p in self.parameters(): p.requires_grad = False  # freeze; assigning to self.requires_grad alone does not stop gradients
150 |
151 |
152 | class VGG(torch.nn.Module):
153 | def __init__(self, conv_index, rgb_range=1):
154 | super(VGG, self).__init__()
155 | vgg_features = models.vgg19(pretrained=True).features
156 | modules = [m for m in vgg_features]
157 | if conv_index == '22':
158 | self.vgg = nn.Sequential(*modules[:8])
159 | elif conv_index == '54':
160 | self.vgg = nn.Sequential(*modules[:35])
161 |
162 | vgg_mean = (0.485, 0.456, 0.406)
163 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
164 | self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)
165 | for p in self.vgg.parameters(): p.requires_grad = False  # freeze the VGG feature extractor
166 |
167 | def forward(self, sr, hr):
168 | def _forward(x):
169 | x = self.sub_mean(x)
170 | x = self.vgg(x)
171 | return x
172 |
173 | vgg_sr = _forward(sr)
174 | with torch.no_grad():
175 | vgg_hr = _forward(hr.detach())
176 |
177 | loss = F.l1_loss(vgg_sr, vgg_hr)
178 |
179 | return loss
180 |
181 |
182 | def color_loss(out, target):
183 | out_yuv = rgb_to_yuv(out)
184 | out_u = out_yuv[:, 1, :, :]
185 | out_v = out_yuv[:, 2, :, :]
186 | target_yuv = rgb_to_yuv(target)
187 | target_u = target_yuv[:, 1, :, :]
188 | target_v = target_yuv[:, 2, :, :]
189 |
190 | return ((out_u - target_u).mean().abs() + (out_v - target_v).mean().abs()) / 2  # magnitude of the mean U/V error, averaged over the two chroma channels
191 |
192 |
193 | class BurstLoss(_Loss):
194 |
195 | def __init__(self, size_average=None, reduce=None, reduction='mean'):
196 | super(BurstLoss, self).__init__(size_average, reduce, reduction)
197 |
198 | self.reduction = reduction
199 | use_cuda = torch.cuda.is_available()
200 | device = torch.device("cuda:0" if use_cuda else "cpu")
201 |
202 | prewitt_filter = 1 / 6 * np.array([[1, 0, -1],
203 | [1, 0, -1],
204 | [1, 0, -1]])
205 |
206 | self.prewitt_filter_horizontal = torch.nn.Conv2d(in_channels=1, out_channels=1,
207 | kernel_size=prewitt_filter.shape,
208 | padding=prewitt_filter.shape[0] // 2).to(device)
209 |
210 | self.prewitt_filter_horizontal.weight.data.copy_(torch.from_numpy(prewitt_filter).to(device))
211 | self.prewitt_filter_horizontal.bias.data.copy_(torch.from_numpy(np.array([0.0])).to(device))
212 |
213 | self.prewitt_filter_vertical = torch.nn.Conv2d(in_channels=1, out_channels=1,
214 | kernel_size=prewitt_filter.shape,
215 | padding=prewitt_filter.shape[0] // 2).to(device)
216 |
217 | self.prewitt_filter_vertical.weight.data.copy_(torch.from_numpy(prewitt_filter.T).to(device))
218 | self.prewitt_filter_vertical.bias.data.copy_(torch.from_numpy(np.array([0.0])).to(device))
219 |
220 | def get_gradients(self, img):
221 | img_r = img[:, 0:1, :, :]
222 | img_g = img[:, 1:2, :, :]
223 | img_b = img[:, 2:3, :, :]
224 |
225 | grad_x_r = self.prewitt_filter_horizontal(img_r)
226 | grad_y_r = self.prewitt_filter_vertical(img_r)
227 | grad_x_g = self.prewitt_filter_horizontal(img_g)
228 | grad_y_g = self.prewitt_filter_vertical(img_g)
229 | grad_x_b = self.prewitt_filter_horizontal(img_b)
230 | grad_y_b = self.prewitt_filter_vertical(img_b)
231 |
232 | grad_x = torch.stack([grad_x_r[:, 0, :, :], grad_x_g[:, 0, :, :], grad_x_b[:, 0, :, :]], dim=1)
233 | grad_y = torch.stack([grad_y_r[:, 0, :, :], grad_y_g[:, 0, :, :], grad_y_b[:, 0, :, :]], dim=1)
234 |
235 | grad = torch.stack([grad_x, grad_y], dim=1)
236 |
237 | return grad
238 |
239 | def forward(self, input, target):
240 | input_grad = self.get_gradients(input)
241 | target_grad = self.get_gradients(target)
242 |
243 | return F.l1_loss(input_grad, target_grad, reduction=self.reduction)
244 |
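245 | # Usage sketch (illustrative; the loss weights below are placeholders, not
246 | # the recipe in train.py, and VGG is skipped to avoid the weight download):
247 | #
248 | #     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
249 | #     pred = torch.rand(1, 3, 64, 64, device=device)  # BurstLoss builds its filters on this device
250 | #     gt = torch.rand(1, 3, 64, 64, device=device)
251 | #     total = (F.l1_loss(pred, gt)
252 | #              + 0.5 * (1 - msssim(pred, gt, normalize=True))
253 | #              + 0.1 * BurstLoss()(pred, gt)
254 | #              + 0.1 * color_loss(pred, gt))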
--------------------------------------------------------------------------------
/Burst/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from math import sqrt
3 | import torch
4 |
5 |
6 | class DnCNN(nn.Module):
7 | def __init__(self, channels, num_of_layers=17):
8 | super(DnCNN, self).__init__()
9 | kernel_size = 3
10 | padding = 1
11 | features = 64
12 | layers = []
13 | layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
14 | layers.append(nn.ReLU(inplace=True))
15 | for _ in range(num_of_layers-2):
16 | layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
17 | layers.append(nn.BatchNorm2d(features))
18 | layers.append(nn.ReLU(inplace=True))
19 | layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
20 | self.dncnn = nn.Sequential(*layers)
21 | # weights initialization
22 | for m in self.modules():
23 | if isinstance(m, nn.Conv2d):
24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
25 | m.weight.data.normal_(0, sqrt(2. / n))
26 |
27 | def forward(self, x):
28 | out = self.dncnn(x)
29 | return out
30 |
31 |
32 | class GlobalMaxPool(torch.nn.Module):
33 | def __init__(self):
34 | super(GlobalMaxPool, self).__init__()
35 |
36 | def forward(self, input):
37 | output = torch.max(input, dim=1)[0]
38 |
39 | return torch.unsqueeze(output, 1)
40 |
41 |
42 | class CA(nn.Module):
43 | def __init__(self, channel, reduction=16):
44 | super(CA, self).__init__()
45 | # global average pooling: feature --> point
46 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
47 | # feature channel downscale and upscale --> channel weight
48 | self.conv_du = nn.Sequential(
49 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
52 | nn.Sigmoid()
53 | )
54 |
55 | def forward(self, x):
56 | y = self.avg_pool(x)
57 | y = self.conv_du(y)
58 | return x * y
59 |
60 |
61 | class RB(nn.Module):
62 | def __init__(self, features):
63 | super(RB, self).__init__()
64 | layers = []
65 | kernel_size = 3
66 | for _ in range(1):
67 | layers.append(nn.Conv2d(in_channels=features, out_channels=features
68 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True))
69 | layers.append(nn.PReLU())
70 | layers.append(nn.Conv2d(in_channels=features, out_channels=features
71 | , kernel_size=kernel_size, padding=kernel_size//2, bias=True))
72 | self.res = nn.Sequential(*layers)
73 | self.ca = CA(features)
74 |
75 | def forward(self, x):
76 | out = self.res(x)
77 | out = self.ca(out)
78 | out += x
79 | return out
80 |
81 |
82 | class _down(nn.Module):
83 | def __init__(self, channel_in):
84 | super(_down, self).__init__()
85 |
86 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=2 * channel_in, kernel_size=3, stride=2, padding=1)
87 | self.relu = nn.PReLU()
88 |
89 | def forward(self, x):
90 |
91 | out = self.relu(self.conv(x))
92 |
93 | return out
94 |
95 |
96 | class _up(nn.Module):
97 | def __init__(self, channel_in):
98 | super(_up, self).__init__()
99 |
100 | self.conv = nn.PixelShuffle(2)
101 | self.relu = nn.PReLU()
102 |
103 | def forward(self, x):
104 |
105 | out = self.relu(self.conv(x))
106 |
107 | return out
108 |
109 |
110 | class AB(nn.Module):
111 | def __init__(self, features):
112 | super(AB, self).__init__()
113 |
114 | num = 2
115 | self.DCR_block1 = self.make_layer(RB, features, num)
116 | self.down1 = self.make_layer(_down, features, 1)
117 | self.DCR_block2 = self.make_layer(RB, features * 2, num)
118 | self.down2 = self.make_layer(_down, features * 2, 1)
119 | self.DCR_block3 = self.make_layer(RB, features * 4, num)
120 | self.up2 = self.make_layer(_up, features * 8, 1)
121 | self.DCR_block22 = self.make_layer(RB, features * 4, num)
122 | self.up1 = self.make_layer(_up, features * 4, 1)
123 | self.DCR_block11 = self.make_layer(RB, features * 2, num)
124 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0)
125 | self.relu2 = nn.PReLU()
126 |
127 | def make_layer(self, block, channel_in, num):
128 | layers = []
129 | for _ in range(num):
130 | layers.append(block(channel_in))
131 | return nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 |
135 | conc1 = self.DCR_block1(x)
136 | out = self.down1(conc1)
137 |
138 | conc2 = self.DCR_block2(out)
139 | conc3 = self.down2(conc2)
140 |
141 | out = self.DCR_block3(conc3)
142 | out = torch.cat([conc3, out], 1)
143 |
144 | out = self.up2(out)
145 | out = torch.cat([conc2, out], 1)
146 | out = self.DCR_block22(out)
147 |
148 | out = self.up1(out)
149 | out = torch.cat([conc1, out], 1)
150 | out = self.DCR_block11(out)
151 |
152 | out = self.relu2(self.conv_f(out))
153 | out += x
154 |
155 | return out
156 |
157 |
158 | class GAB(nn.Module):
159 | def __init__(self, features):
160 | super(GAB, self).__init__()
161 |
162 | kernel_size = 3
163 | self.res1 = self.make_layer(RB, features, 2)
164 | self.R = 2
165 | self.A = 1
166 | self.RB = nn.ModuleList()
167 | for _ in range(self.R):
168 | self.RB.append(RB(features))
169 | self.AB = nn.ModuleList()
170 | for _ in range(self.A):
171 | self.AB.append(AB(features))
172 | self.GFF_R = nn.Sequential(
173 | nn.Conv2d(self.R * features, features, kernel_size=1, padding=0, stride=1),
174 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1),
175 | )
176 | self.GFF_A = nn.Sequential(
177 | nn.Conv2d(self.A * features, features, kernel_size=1, padding=0, stride=1),
178 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1),
179 | )
180 | self.softmax = nn.Sigmoid()  # a sigmoid despite the name; only referenced by the commented-out gating in forward
181 | self.res2 = self.make_layer(RB, features * 2, 2)
182 | self.conv_f = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0)
183 | self.relu2 = nn.PReLU()
184 |
185 | def make_layer(self, block, channel_in, num):
186 | layers = []
187 | for _ in range(num):
188 | layers.append(block(channel_in))
189 | return nn.Sequential(*layers)
190 |
191 | def forward(self, x):
192 | out = self.res1(x)
193 |
194 | RB_outs = []
195 | for i in range(self.R):
196 | outR = self.RB[i](out)
197 | RB_outs.append(outR)
198 | outR = torch.cat(RB_outs, 1)
199 | outR = self.GFF_R(outR)
200 |
201 | AB_outs = []
202 | for i in range(self.A):
203 | outA = self.AB[i](out)
204 | AB_outs.append(outA)
205 | outA = torch.cat(AB_outs, 1)
206 | outA = self.GFF_A(outA)
207 |
208 | # outR *= self.softmax(outA)
209 | out = torch.cat([outR, outA], 1)
210 | out = self.relu2(self.conv_f(self.res2(out)))
211 |
212 | out += x
213 |
214 | return out
215 |
216 |
217 | class Net(nn.Module):
218 | def __init__(self, features=48):
219 | super(Net, self).__init__()
220 |
221 | kernel_size = 3
222 | self.conv_i = nn.Conv2d(in_channels=3, out_channels=features, kernel_size=1, stride=1, padding=0)
223 | self.relu1 = nn.PReLU()
224 | self.GA = 5
225 | self.maxpool = GlobalMaxPool()
226 | self.GAB = nn.ModuleList()
227 | for _ in range(self.GA):
228 | self.GAB.append(GAB(features))
229 | self.conv_m = nn.Conv2d(in_channels=features * 2, out_channels=features, kernel_size=1, stride=1, padding=0)
230 | self.relum = nn.PReLU()
231 | self.GFF_GA = nn.Sequential(
232 | nn.Conv2d(self.GA * 1 * features * 7, features, kernel_size=1, padding=0, stride=1),  # 7 = burst length; fuses features across all frames
233 | nn.Conv2d(features, features, kernel_size, padding=kernel_size//2, stride=1),
234 | )
235 | self.conv_f = nn.Conv2d(in_channels=features, out_channels=3, kernel_size=1, stride=1, padding=0)
236 | self.relu2 = nn.PReLU()
237 |
238 | def make_layer(self, block, channel_in, num):
239 | layers = []
240 | for _ in range(num):
241 | layers.append(block(channel_in))
242 | return nn.Sequential(*layers)
243 |
244 | def forward(self, x):
245 | b, im, c, h, w = x.size()  # im: number of burst frames (7 in this setup)
246 | out = self.relu1(self.conv_i(x.view((b*im, c, h, w))))
247 | residual = out[3, :, :, :]  # reference-frame features; indexing the flattened (b*im) batch assumes b == 1
248 | GAB_outs = []
249 | for i in range(self.GA):
250 | out = self.GAB[i](out)
251 | out_max = self.maxpool(out.view((b, im, -1, h, w)))  # channel-wise max over the burst frames
252 | out_max = out_max.repeat(1, im, 1, 1, 1).view(b*im, -1, h, w)
253 | out = self.relum(self.conv_m(torch.cat([out, out_max], 1)))
254 | GAB_outs.append(out)
255 | out = torch.cat(GAB_outs, 1)
256 | out = self.GFF_GA(out.view((b, -1, h, w)))
257 | out += residual
258 | out = self.relu2(self.conv_f(out))
259 | out += x[:, 3, :, :, :]  # global skip from the reference input frame
260 |
261 | return out
262 |
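263 | # Usage sketch (illustrative). The forward pass hard-codes a 7-frame burst
264 | # with frame 3 as the reference, and the residual indexing assumes batch
265 | # size 1; H and W must again be divisible by 4:
266 | #
267 | #     net = Net(features=48)
268 | #     x = torch.randn(1, 7, 3, 64, 64)   # (batch, frames, channels, H, W)
269 | #     assert net(x).shape == (1, 3, 64, 64)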
--------------------------------------------------------------------------------
/Burst/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import numpy as np
4 | import glob
5 | import h5py
6 | import torch
7 | import cv2
8 | import torch.utils.data as udata
9 | from utils import data_augmentation
10 | from random import shuffle
11 | 
12 |
13 |
14 | def normalize(data):
15 | return data/255.
16 |
17 |
18 | def Im2Patch(img, win, stride=1):
19 | k = 0
20 | endc = img.shape[0]
21 | endh = img.shape[1]
22 | endw = img.shape[2]
23 | patch = img[:, 0:endh-win+1:stride, 0:endw-win+1:stride]
24 | TotalPatNum = patch.shape[1] * patch.shape[2]
25 | Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
26 | for i in range(win):
27 | for j in range(win):
28 | patch = img[:, i:endh-win+i+1:stride, j:endw-win+j+1:stride]
29 | Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
30 | k = k + 1
31 | return Y.reshape([endc, win, win, TotalPatNum])
32 |
33 |
34 | def prepare_data(data_path, patch_size, stride, aug_times=1):
35 | # '''
36 | # train
37 | print('process training data')
38 | scales = [1]
39 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'HAZY', '*.png'))
40 | # mix = list(range(len(files)))
41 | # random.shuffle(mix)
42 | # mix_train = mix[:int(len(files)*0.96)]
43 | # mix_val = mix[int(len(files)*0.96):]
44 | files.sort()
45 | h5f = h5py.File('D:/train_input.h5', 'w')
46 | train_num = 0
47 | for i in range(len(files)):
48 | Img = cv2.imread(files[i])
49 | h, w, c = Img.shape
50 | for k in range(len(scales)):
51 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)
52 | # Img = np.expand_dims(Img[:, :, :].copy(), 0)
53 | Img = np.swapaxes(Img, 0, 2)
54 | Img = np.swapaxes(Img, 1, 2)
55 | Img = np.float32(normalize(Img))
56 | # print(Img.shape)
57 | patches = Im2Patch(Img, patch_size, stride)
58 | # print(i)
59 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3]))
60 | for n in range(patches.shape[3]):
61 | data = patches[:, :, :, n].copy()
62 | # print(data.shape)
63 | h5f.create_dataset(str(train_num), data=data)
64 | train_num += 1
65 | for m in range(aug_times-1):
66 | data_aug = data_augmentation(data, np.random.randint(1, 8))
67 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
68 | train_num += 1
69 | h5f.close()
70 | print('process training gt')
71 | scales = [1]
72 | files = glob.glob(os.path.join('D:', 'NH-HAZE_train', 'GT', '*.png'))
73 | files.sort()
74 | h5f = h5py.File('D:/train_gt.h5', 'w')
75 | train_num = 0
76 | for i in range(len(files)):
77 | Img = cv2.imread(files[i])
78 | h, w, c = Img.shape
79 | for k in range(len(scales)):
80 | # Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)
81 | # Img = np.expand_dims(Img[:, :, :].copy(), 0)
82 | Img = np.swapaxes(Img, 0, 2)
83 | Img = np.swapaxes(Img, 1, 2)
84 | Img = np.float32(normalize(Img))
85 | patches = Im2Patch(Img, patch_size, stride)
86 | # print(i)
87 | print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], aug_times*patches.shape[3]))
88 | for n in range(patches.shape[3]):
89 | data = patches[:, :, :, n].copy()
90 | # print(data.shape)
91 | h5f.create_dataset(str(train_num), data=data)
92 | train_num += 1
93 | for m in range(aug_times-1):
94 | data_aug = data_augmentation(data, np.random.randint(1, 8))
95 | h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
96 | train_num += 1
97 | h5f.close()
98 | # val
99 | print('\nprocess validation data')
100 | # files.clear()
101 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'HAZY', '*.png'))
102 | files.sort()
103 | h5f = h5py.File('D:/val_input.h5', 'w')
104 | val_num = 0
105 | for i in range(len(files)):
106 | print("file: %s" % files[i])
107 | img = cv2.imread(files[i])
108 | # img = np.expand_dims(img[:, :, :], 0)
109 | img = np.swapaxes(img, 0, 2)
110 | img = np.swapaxes(img, 1, 2)
111 | img = np.float32(normalize(img))
112 | # print(i)
113 | # print(img.shape)
114 | h5f.create_dataset(str(val_num), data=img)
115 | val_num += 1
116 | h5f.close()
117 | # '''
118 | print('\nprocess validation gt')
119 | # files.clear()
120 | files = glob.glob(os.path.join('D:', 'NH-HAZE_validation', 'GT', '*.png'))
121 | files.sort()
122 | h5f = h5py.File('D:/val_gt.h5', 'w')
123 | val_num = 0
124 | for i in range(len(files)):
125 | print("file: %s" % files[i])
126 | img = cv2.imread(files[i])
127 | # img = np.expand_dims(img[:, :, :], 0)
128 | img = np.swapaxes(img, 0, 2)
129 | img = np.swapaxes(img, 1, 2)
130 | img = np.float32(normalize(img))
131 | # print(i)
132 | # print(img.shape)
133 | h5f.create_dataset(str(val_num), data=img)
134 | val_num += 1
135 | h5f.close()
136 | # print('training set, # samples %d\n' % train_num)
137 | print('val set, # samples %d\n' % val_num)
138 | # '''
139 |
140 |
141 | class Dataset(udata.Dataset):
142 | def __init__(self, train=True):
143 | super(Dataset, self).__init__()
144 | self.train = train
145 | if self.train:
146 | h5f = []
147 | for im in range(7):
148 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_input' + str(im) + '.h5', 'r')
149 | h5f.append(h5)
150 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_gt.h5', 'r')
151 | else:
152 | h5f = []
153 | for im in range(7):
154 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_input' + str(im) + '.h5', 'r')
155 | h5f.append(h5)
156 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_gt.h5', 'r')
157 | self.keys = []
158 | for im in range(7):
159 | h5 = h5f[im]
160 | self.keys.append(list(h5.keys()))
161 | h5.close()
162 | self.keys_gt = list(h5f_gt.keys())
163 | h5f_gt.close()
164 |
165 | def __len__(self):
166 | return len(self.keys_gt)
167 |
168 | def __getitem__(self, index):
169 | if self.train:
170 | h5f = []
171 | for im in range(7):
172 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_input' + str(im) + '.h5', 'r')
173 | h5f.append(h5)
174 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/train_gt.h5', 'r')
175 | else:
176 | h5f = []
177 | for im in range(7):
178 | h5 = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_input' + str(im) + '.h5', 'r')
179 | h5f.append(h5)
180 | h5f_gt = h5py.File('/home/user/depthMap/ksm/CVPR/demoire/data/val_gt.h5', 'r')
181 | data = []
182 | for im in range(7):
183 | k = self.keys[im][index]
184 | h5 = h5f[im]
185 | kk = h5[k]
186 | data.append(torch.Tensor(np.array(kk)).unsqueeze(0))
187 | h5.close()
188 | key_gt = self.keys_gt[index]
189 | gt = np.array(h5f_gt[key_gt])
190 | h5f_gt.close()
191 | return torch.cat(data, 0), torch.Tensor(gt)
192 |
193 |
194 | class DatasetBurst(udata.Dataset):
195 | def __init__(self, train=True):
196 | super(DatasetBurst, self).__init__()
197 | self.train = train
198 | if self.train:
199 | self.input_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/train/input/*.png")
200 | self.gt_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/train/gt/*.png")
201 | else:
202 | self.input_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/val/input/*.png")
203 | self.gt_list = glob.glob("/home/user/depthMap/ksm/CVPR/demoire/data/val/gt/*.png")
204 | self.frame = int(len(self.input_list)/len(self.gt_list))
205 | self.crop = 128
206 | self.th = 50
207 |
208 | def __len__(self):
209 | return len(self.gt_list)
210 |
211 | def __getitem__(self, index):
212 | order = list(range(len(self.gt_list)))
213 | shuffle(order)  # re-map the requested index to a random one each call (sampling with replacement)
214 | index = order[index]
215 | data_list = []
216 | self.input_list.sort(key=str.lower)
217 | self.gt_list.sort(key=str.lower)
218 | origin = cv2.imread(self.input_list[index * self.frame + 3])  # frame 3 of the burst is the reference
219 | for im in range(self.frame):
220 | data = cv2.imread(self.input_list[index * self.frame + im])
221 | if im != 3:  # fuse non-reference frames: keep pixels above th, take darker pixels from the reference frame
222 | _, bin2 = cv2.threshold(data, self.th, 255, cv2.THRESH_BINARY)
223 | _, bin3 = cv2.threshold(data, self.th, 255, cv2.THRESH_BINARY_INV)
224 | final2 = cv2.bitwise_and(data, bin2, mask=None)
225 | final3 = cv2.bitwise_and(origin, bin3, mask=None)
226 | data = cv2.bitwise_or(final3, final2, mask=None)
227 | data = np.float32(normalize(data))
228 | data = np.transpose(data, (2, 0, 1))
229 | data = torch.Tensor(data).unsqueeze(0)
230 | data_list.append(data)
231 | gt = cv2.imread(self.gt_list[index])
232 | gt = np.float32(normalize(gt))
233 | gt = np.transpose(gt, (2, 0, 1))
234 | return torch.cat(data_list, 0), torch.Tensor(gt)
235 |
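236 | # Usage sketch (illustrative; batch size and worker count are assumptions).
237 | # DatasetBurst yields a (frames, 3, H, W) burst and its (3, H, W) ground truth:
238 | #
239 | #     loader = udata.DataLoader(DatasetBurst(train=True), batch_size=1, shuffle=True)
240 | #     burst, gt = next(iter(loader))
241 | #     print(burst.shape, gt.shape)   # (1, frames, 3, H, W), (1, 3, H, W)
242 | #
243 | # prepare_data writes the H5 files consumed by Dataset; its input and output
244 | # paths are hard-coded above, so the data_path argument is effectively ignored:
245 | #
246 | #     prepare_data(data_path='D:/', patch_size=64, stride=64, aug_times=1)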
--------------------------------------------------------------------------------