├── .DS_Store
├── .gitattributes
├── .gitignore
├── .idea
├── 20181218Deep-Residual-Network-for-Joint-Demosaicing-and-Super-Resolution.iml
├── encodings.xml
├── misc.xml
├── modules.xml
└── workspace.xml
├── 256x256实验结果对比.xlsx
├── Code_V1.0
├── .DS_Store
├── 2_Level_ResNet_Train.py
├── Bayer_to_HR.py
├── Create_fast_test_img.py
├── DataSet.py
├── HD_test.py
├── NewResNet.py
├── ResNet_Reference.py
├── Test.py
├── Test_class.py
├── Text_Create.py
├── ToBayer.py
├── __pycache__
│ └── .DS_Store
├── fast_train.py
├── main.py
└── temp.py
├── DataSet.py
├── Final_test
├── .DS_Store
└── Result.txt
├── Model.py
├── README.md
├── Saved_Models
├── .DS_Store
└── 20190226Traned_Model
│ └── .DS_Store
├── TEST_DATA
├── .DS_Store
├── 0data.TIF
├── 0label.TIF
├── 1data.TIF
├── 1label.TIF
├── 2data.TIF
├── 2label.TIF
├── 3data.TIF
├── 3label.TIF
├── 4data.TIF
├── 4label.TIF
├── 5data.TIF
├── 5label.TIF
├── 6data.TIF
├── 6label.TIF
├── 7data.TIF
├── 7label.TIF
├── 8data.TIF
├── 8label.TIF
├── 9data.TIF
├── 9label.TIF
└── TEST_DATA.txt
├── Test_class.py
├── Train.py
├── Train_control_pannel.py
├── __pycache__
├── .DS_Store
├── DataSet.cpython-37.pyc
├── Model.cpython-37.pyc
├── Test_class.cpython-37.pyc
└── Train.cpython-37.pyc
├── img_to_imgblk.py
└── 测试图片名单.xlsx
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/.DS_Store
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pkl filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pkl
2 | .idea
3 | __*/
4 | __*
5 | .D*
6 |
--------------------------------------------------------------------------------
/.idea/20181218Deep-Residual-Network-for-Joint-Demosaicing-and-Super-Resolution.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
82 |
83 |
84 |
85 | K
86 | eta
87 | i
88 | N
89 | Pe
90 | Me
91 |
92 | ,
93 | block
94 | col_real
95 | row_real
96 | blk_size
97 | clock()
98 | process
99 | Process_time
100 | process_time
101 | A
102 | to_PIL_image
103 | SUB_EPOCH
104 | batch_counter
105 | optimize
106 | self.
107 | Net
108 | save
109 | BEST_MODEL_SAVE_PATH
110 | TEST_DATA
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 | $USER_HOME$/.subversion
345 |
346 |
347 |
348 |
349 |
350 | 1550211700126
351 |
352 |
353 | 1550211700126
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
--------------------------------------------------------------------------------
/256x256实验结果对比.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/256x256实验结果对比.xlsx
--------------------------------------------------------------------------------
/Code_V1.0/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/Code_V1.0/.DS_Store
--------------------------------------------------------------------------------
/Code_V1.0/2_Level_ResNet_Train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from multiprocessing import Process
8 |
9 | from NewResNet import Net
10 | from DataSet import CustomDataset
11 | from Test_class import Run_test
12 |
# *** Hyper-parameters ***
# NOTE(review): BATC_SIZE is not used below; the DataLoader passes batch_size=10.
BATC_SIZE = 1

# Saving and restoring models
# https://www.cnblogs.com/nkh222/p/7656623.html
# https://blog.csdn.net/quincuntial/article/details/78045036
#
# Save:
#     torch.save(the_model.state_dict(), PATH)
# Restore:
#     the_model = TheModelClass(*args, **kwargs)
#     the_model.load_state_dict(torch.load(PATH))
#
# Save only the network parameters (the officially recommended way):
#     torch.save(net.state_dict(), 'net_params.pkl')
# Load the network parameters:
#     net.load_state_dict(torch.load('net_params.pkl'))

print("Loading the saving Model...")
MyNet = Net(2)
try:
    MyNet.load_state_dict(torch.load('./Model_ResNet=2.pkl'))
except Exception as err:
    # Top-level best effort: a missing/incompatible checkpoint means training
    # starts from fresh weights. Unlike a bare `except:`, this does not swallow
    # KeyboardInterrupt/SystemExit, and the cause is reported.
    print("Loading Fail.", err)

# Paths of the training / test / full data list files
train_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Train_Data.txt"
test_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Test_Data.txt"
all_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Data_Read.txt"

print("Loading the Training data...")
# CustomDataset(file_path, block_size=64) has no `random_augment` parameter;
# passing it raised TypeError, so it is no longer passed.
MyData = CustomDataset(train_data_path, block_size=32)

# CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
# sampler=None, batch_sampler=None, num_workers=0, collate_fn=,
# pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
train_data = data.DataLoader(dataset=MyData,
                             batch_size=10,
                             shuffle=True)

# CLASS torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
Optimizer = torch.optim.Adam(MyNet.parameters(), lr=0.00000001, betas=(0.9, 0.999), eps=1e-08)
# CLASS torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
Loss_Func = nn.MSELoss()

EPOCH = 1000
to_PIL_image = transforms.ToPILImage()
62 |
63 |
def adjust_learning_rate(optimizer):
    """Halve, in place, the learning rate of every parameter group of *optimizer*."""
    for group in optimizer.param_groups:
        group['lr'] /= 2
67 |
68 |
print("Start training...")
for epoch in range(EPOCH):
    # `inputs`/`targets` rather than `data`/`label`: the old name `data`
    # rebound the module-level alias of torch.utils.data.
    for step, (inputs, targets) in enumerate(train_data):
        out = MyNet(inputs)
        loss = Loss_Func(out, targets)
        Optimizer.zero_grad()
        loss.backward()
        Optimizer.step()
        print(loss.item())
        print(epoch, step)

        # Periodic checkpoint.
        if step != 0 and step % 10 == 0:
            print("Saving the model...")
            torch.save(MyNet.state_dict(), './Model_ResNet=2.pkl')

        # Periodic evaluation. A separate `if` (not `elif`) so evaluation still
        # happens on steps divisible by both 10 and 99 (e.g. 990).
        if step != 0 and step % 99 == 0:
            # Evaluate the current weights, not the last 10-step checkpoint.
            torch.save(MyNet.state_dict(), './Model_ResNet=2.pkl')
            # The old `Process(target=Run_test(...))` evaluated Run_test right
            # here (synchronously) and then started a process with no target.
            # Call it directly. Run_test is used with float64 as the default
            # tensor type; restore float32 afterwards so that setting does not
            # leak into the rest of training.
            torch.set_default_tensor_type('torch.DoubleTensor')
            try:
                Run_test(net=Net(2), model_path='./Model_ResNet=2.pkl',
                         test_data_path="/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Fast_Test_Data.txt",
                         save_to="./test_result/2层ResNet模型/", as_name="test_result_PSNR.txt",
                         _block_size=128)
            finally:
                torch.set_default_tensor_type('torch.FloatTensor')
99 |
--------------------------------------------------------------------------------
/Code_V1.0/Bayer_to_HR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import numpy as np
4 | from Test_class import LR_to_HR
5 | from NewResNet import Net
6 | import random
7 |
# Make float64 the default tensor type before the network is built.
torch.set_default_tensor_type('torch.DoubleTensor')

img_path = "/Users/chenlinwei/Desktop/2.jpg"
MODEL_PATH = './Final_Model.pkl'

# Test_class.LR_to_HR(net, model_path, img_path, save_to, _block_size=32)
# Reconstruct one image with the 24-block network. The output goes to a
# randomly named location (presumably so repeated runs do not collide).
LR_to_HR(
    net=Net(24),
    model_path=MODEL_PATH,
    img_path=img_path,
    save_to="./" + str(random.random()),
    _block_size=64,
)
17 |
--------------------------------------------------------------------------------
/Code_V1.0/Create_fast_test_img.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 |
# PIL resampling filters (https://www.jianshu.com/p/85eba2a51142):
#   PIL.Image.NEAREST:   lowest quality
#   PIL.Image.BILINEAR:  bilinear
#   PIL.Image.BICUBIC:   cubic spline interpolation
#   PIL.Image.ANTIALIAS: highest quality (alias of LANCZOS; removed in Pillow 10)

def DownSize_256x256(img, steps=6, factor=1.25, out_size=256):
    """Shrink *img* step by step, crop a top-left square and resize it to a fixed size.

    The image is resized *steps* times, each time to ``int(size / factor)``.
    A square is then cropped from the top-left corner. Its side is half the
    smaller dimension the image had *before the final shrink step*. The square
    is resized to ``out_size x out_size``. The defaults reproduce the original
    fixed behaviour: 6 steps of 1/1.25 and a 256x256 output.

    ``Image.LANCZOS`` is used in place of ``Image.ANTIALIAS``. It is the same
    filter, but the ``ANTIALIAS`` name was removed in Pillow 10.

    Args:
        img: PIL image to downsize.
        steps: number of successive shrink steps.
        factor: divisor applied to width and height at each step.
        out_size: side length of the returned square image.

    Returns:
        A new ``out_size x out_size`` PIL image.
    """
    prev_w, prev_h = img.size
    for _ in range(steps):
        prev_w, prev_h = img.size
        img = img.resize((int(prev_w / factor), int(prev_h / factor)), Image.LANCZOS)

    # NOTE(review): the crop side is computed from the size *before* the last
    # shrink step, as in the original code; confirm this is intended.
    min_size = min(prev_w / 2, prev_h / 2)
    img = img.crop((0, 0, min_size, min_size))
    return img.resize((out_size, out_size), Image.LANCZOS)
21 |
22 |
def DownSize_2(img):
    """Placeholder for a 2x downsizing helper: not implemented, always returns None."""
    return None
25 |
26 |
def bayer_mosaic_rgb(data):
    """Zero out, in place, the channels not sampled by an RGGB Bayer pattern.

    The 2x2 pattern repeated over the image is::

        R G
        G B

    At each pixel only the channel of its pattern colour is kept; the other two
    of channels 0-2 are set to 0. Channels beyond index 2 are left untouched.

    Args:
        data: writable array of shape (H, W, C) with C >= 3 (channel order R, G, B).

    Returns:
        The same array object, modified in place.
    """
    data[0::2, 0::2, 1:3] = 0    # even row, even col: keep R
    data[0::2, 1::2, 0:3:2] = 0  # even row, odd col:  keep G
    data[1::2, 0::2, 0:3:2] = 0  # odd row, even col:  keep G
    data[1::2, 1::2, 0:2] = 0    # odd row, odd col:   keep B
    return data


def To_Bayer(img):
    """Halve *img* in size and return a 3-channel RGGB Bayer-mosaicked copy.

    Vectorised with NumPy slicing instead of a per-pixel Python loop. The
    unused single-channel mosaic the old loop also filled has been removed.
    ``Image.LANCZOS`` replaces ``Image.ANTIALIAS``, which is the same filter
    but was removed in Pillow 10.

    Args:
        img: RGB PIL image.

    Returns:
        A PIL image of size ``(int(w / 2), int(h / 2))``: a 3-channel Bayer
        image whose unsampled channels are zero.
    """
    w, h = img.size
    img = img.resize((int(w / 2), int(h / 2)), Image.LANCZOS)
    # 3-channel Bayer image
    return Image.fromarray(bayer_mosaic_rgb(np.array(img)))
70 |
71 |
resize_dir = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Resize/"
save_path = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Fast_test/"

# For images 1..1000: write a 256x256 reference image and its Bayer mosaic.
for idx in range(1, 1001):
    source = Image.open(resize_dir + "/real" + str(idx) + ".TIF", "r")
    small = DownSize_256x256(source)
    small.save(save_path + str(idx) + "real" + ".TIF", "TIFF")
    To_Bayer(small).save(save_path + str(idx) + "bayer" + ".TIF", "TIFF")
84 |
--------------------------------------------------------------------------------
/Code_V1.0/DataSet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | import torch.utils.data as data
6 | from PIL import Image
7 | import random
8 | import numpy as np
9 |
10 |
11 | # Reference link:
12 | # 如何构建数据集
13 | # https://oidiotlin.com/create-custom-dataset-in-pytorch/
14 | # https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/
15 |
16 | # transforms 函数的使用
17 | # https://www.jianshu.com/p/13e31d619c15
18 | # ToTensor:convert a PIL image to tensor (H*W*C) in range [0,255] to a torch.Tensor(C*H*W) in the range [0.0,1.0]
19 |
20 | # torch.set_default_tensor_type('torch.DoubleTensor')
class CustomDataset(data.Dataset):
    """Pairs of (Bayer input, RGB label) images listed in a text file.

    Each non-blank line of the list file holds two whitespace-separated paths:
    the input image (loaded as single-channel 'L') and the label image (loaded
    as 'RGB').

    Args:
        file_path: path of the list (.txt) file.
        block_size: intended random-crop size. NOTE: random cropping is currently
            disabled, so this value is stored as ``self.Block_size`` but not
            applied; whole images are returned. This class has no
            ``random_augment`` parameter.
    """

    def __init__(self, file_path, block_size=64):
        self.imgs = self._read_pairs(file_path)
        self.Block_size = block_size
        # Built once instead of on every __getitem__ call.
        # ToTensor: PIL image (H x W x C, [0, 255]) -> tensor (C x H x W, [0.0, 1.0]).
        self.trans = transforms.Compose([transforms.ToTensor()])
        print("DataSet Size is: ", self.__len__())

    @staticmethod
    def _read_pairs(file_path):
        """Return a list of [data_path, label_path] records, skipping blank lines."""
        with open(file_path, 'r') as file:
            return [line.split() for line in file if line.strip()]

    def __getitem__(self, index):
        """Return the (data, label) tensors of sample *index*.

        Note: the top-left 2x2 of each input Bayer image is laid out as
            R G
            G B
        """
        data_path, label_path = self.imgs[index]
        # Open with `with` so the file handles are closed once converted.
        with Image.open(data_path) as raw:
            data_img = self.trans(raw.convert('L'))
        with Image.open(label_path) as raw:
            label_img = self.trans(raw.convert('RGB'))
        return data_img, label_img

    def __len__(self):
        return len(self.imgs)
81 |
--------------------------------------------------------------------------------
/Code_V1.0/HD_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from Test_class import Run_test
3 | from NewResNet import Net
4 |
test_data = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Train_Data.txt"

# This line is required; without it Run_test raises an error.
torch.set_default_tensor_type('torch.DoubleTensor')

# True: evaluate the 2-block model; False: evaluate the 24-block model.
USE_TWO_LEVEL_MODEL = True

if USE_TWO_LEVEL_MODEL:
    Run_test(
        net=Net(2),
        model_path='./Model_ResNet=2.pkl',
        test_data_path=test_data,
        save_to="./test_result_HD/2-level/",
        as_name="test_result_PSNR.txt",
        _block_size=256,
        _test_list=range(0, 10),
    )
else:
    Run_test(
        net=Net(24),
        model_path='./Model.pkl',
        test_data_path=test_data,
        save_to="./test_result_HD/",
        as_name="test_result_PSNR.txt",
        _block_size=256,
        _test_list=range(1, 10),
    )
21 |
--------------------------------------------------------------------------------
/Code_V1.0/NewResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 |
7 |
# ResNet building block, see
# https://blog.csdn.net/sunqiande88/article/details/80100891
class ResidualBlock(nn.Module):
    """256-channel residual block: PReLU(conv3x3 -> PReLU -> conv3x3 (x) + x).

    All convolutions are 3x3, stride 1, padding 1, so the channel count (256)
    and spatial size are preserved.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(in_channels=256, out_channels=256, kernel_size=3,
                           stride=1, padding=1, bias=True)
        # Attribute names and registration order form the state_dict keys of
        # saved checkpoints; keep them unchanged.
        self.left = nn.Sequential(nn.Conv2d(**conv_kwargs), nn.PReLU(), nn.Conv2d(**conv_kwargs))
        self.shortcut = nn.Sequential()  # an empty Sequential is the identity
        self.active_f = nn.PReLU()

    def forward(self, x):
        residual = self.shortcut(x)
        return self.active_f(self.left(x) + residual)
26 |
27 |
class Net(nn.Module):
    """Joint demosaicing / 2x super-resolution network.

    Input:  N x 1 x H x W single-channel (Bayer mosaic) image.
    Output: N x 3 x 2H x 2W image, for even H and W.

    Stage 1: a 4x4 stride-2 conv halves H and W (to 256 channels). A pixel
             shuffle restores H and W (to 64 channels). Then a 3x3 conv back
             to 256 channels, and a PReLU.
    Stage 2: ``resnet_level`` residual blocks at 256 channels.
    Stage 3: a pixel shuffle doubles H and W (to 64 channels). Then a 3x3 conv
             to 256 channels, a PReLU, and a 3x3 conv to 3 output channels.

    Args:
        resnet_level: number of residual blocks in stage 2.
    """

    def __init__(self, resnet_level=2):
        super().__init__()
        # NOTE: attribute names and their registration order are the state_dict
        # keys of saved checkpoints; keep them unchanged even where misleading
        # (stage1_2_conv4x4 actually uses a 3x3 kernel).

        # *** Stage 1 ***
        self.stage1_1_conv4x4 = nn.Conv2d(1, 256, kernel_size=4, stride=2, padding=1, bias=True)
        self.stage1_2_SP_conv = nn.PixelShuffle(2)
        self.stage1_2_conv4x4 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.stage1_2_PReLU = nn.PReLU()

        # *** Stage 2 ***
        self.stage2_ResNetBlock = nn.Sequential(*[ResidualBlock() for _ in range(resnet_level)])

        # *** Stage 3 ***
        self.stage3_1_SP_conv = nn.PixelShuffle(2)
        self.stage3_2_conv3x3 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.stage3_2_PReLU = nn.PReLU()
        self.stage3_3_conv3x3 = nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        pipeline = (
            self.stage1_1_conv4x4,
            self.stage1_2_SP_conv,
            self.stage1_2_conv4x4,
            self.stage1_2_PReLU,
            self.stage2_ResNetBlock,
            self.stage3_1_SP_conv,
            self.stage3_2_conv3x3,
            self.stage3_2_PReLU,
            self.stage3_3_conv3x3,
        )
        for layer in pipeline:
            x = layer(x)
        return x
82 |
--------------------------------------------------------------------------------
/Code_V1.0/ResNet_Reference.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # ---------------------------------------------------------------------------- #
4 | # An implementation of https://arxiv.org/pdf/1512.03385.pdf #
5 | # See section 4.2 for the model architecture on CIFAR-10 #
6 | # Some part of the code was referenced from below #
7 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py #
8 | # ---------------------------------------------------------------------------- #
9 |
10 | import torch
11 | import torchvision
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torchvision
17 | import torchvision.transforms as transforms
18 |
19 |
# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters
num_epochs = 80
learning_rate = 0.001

# Training-time augmentation: pad by 4, random horizontal flip, random 32x32 crop.
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])

# CIFAR-10 dataset (the training split is downloaded if missing)
train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())

# Data loaders: batches of 100; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
52 |
def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 and the given stride."""
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
57 |
class ResidualBlock(nn.Module):
    """Basic ResNet block: ReLU(BN(conv(ReLU(BN(conv(x))))) + shortcut(x)).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: stride of the first convolution.
        downsample: optional module applied to ``x`` to form the shortcut
            (used when the stride or channel count changes); otherwise the
            shortcut is ``x`` itself.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = self.downsample(x) if self.downsample else x
        out += shortcut
        return self.relu(out)
81 |
class ResNet(nn.Module):
    """Small CIFAR-style ResNet with three stages of 16, 32 and 64 channels.

    Args:
        block: residual block class, called as ``block(in, out, stride, downsample)``
            and ``block(in, out)``.
        layers: number of blocks in each of the three stages.
        num_classes: size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Stages 2 and 3 downsample by 2 in their first block.
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        # 8x8 average pooling: a 32x32 input is 8x8 after the two stride-2 stages.
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of *blocks* blocks; only the first may change stride/width."""
        needs_projection = stride != 1 or self.in_channels != out_channels
        downsample = (
            nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),
                          nn.BatchNorm2d(out_channels))
            if needs_projection else None
        )
        first = block(self.in_channels, out_channels, stride, downsample)
        self.in_channels = out_channels
        rest = [block(out_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        features = self.relu(self.bn(self.conv(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.avg_pool(features)
        return self.fc(pooled.view(pooled.size(0), -1))
120 |
# Three stages of two residual blocks each, placed on the selected device.
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

# Loss and optimizer: cross-entropy classification loss, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
127 |
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of *optimizer* to *lr*."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
132 |
# Train the model; the learning rate is divided by 3 every 20 epochs.
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        loss = criterion(model(images), labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Decay learning rate
    if (epoch + 1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)
158 |
# Measure top-1 accuracy on the test split without tracking gradients.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        _, predicted = torch.max(model(images).data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'resnet.ckpt')
176 |
--------------------------------------------------------------------------------
/Code_V1.0/Test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from Test_class import Run_test
3 | from NewResNet import Net
4 |
test_data = "./TEST_DATA/TEST_DATA.txt"
cross_test_data = "./8K_CROSS_DATA/8K_CROSS_DATA_10PIC.txt"
test_data_128x128 = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Fast_Test_Data.txt"
MODEL_PATH = './Final_Model.pkl'
HD_train_data = '/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Train_Data.txt'
final_test = "./8K_TEST_DATA/8K_TEST_DATA.txt"

# This line is required; without it Run_test raises an error.
torch.set_default_tensor_type('torch.DoubleTensor')

# Evaluation runs: (enabled, test list, output dir, result name, extra kwargs).
# Every run uses the 24-block model with _block_size=64. Only the enabled
# runs execute, in list order.
TEST_RUNS = [
    (True, test_data_128x128, "./test_result/", "Final_Model", {}),
    (False, test_data, "./test_result/", "Final_Model_HD", {}),
    (False, cross_test_data, "./test_result/", "cross_test", {"_test_list": range(0, 10)}),
    (False, final_test, "./Final_test/", "final_test", {"_test_list": range(0, 50)}),
    (False, HD_train_data, "./test_result_HD/", "HD_train_data", {"_test_list": range(0, 10)}),
]

for enabled, list_path, out_dir, result_name, extra in TEST_RUNS:
    if not enabled:
        continue
    Run_test(net=Net(24), model_path=MODEL_PATH,
             test_data_path=list_path,
             save_to=out_dir, as_name=result_name, _block_size=64,
             **extra)
41 |
--------------------------------------------------------------------------------
/Code_V1.0/Test_class.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from PIL import Image
4 | import numpy as np
5 | import time
6 | import math
7 | from skimage.measure import compare_ssim, compare_psnr
8 | import torchvision.transforms as transforms
9 | import torch.utils.data as data
10 |
# Width (pixels) of the reflect padding added around every tile.
PADDING = 2
# Number of tiles passed through the network per forward call.
FORWORD_BLOCK_NUM = 1
# Inference device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("cuda:", torch.cuda.is_available(), "GPUs", torch.cuda.device_count())
15 |
16 |
class TestDataset(data.Dataset):
    """Indexable view over pre-cut image tiles, divided by 255 and stored as float64.

    Args:
        img_blks: NumPy array whose first axis indexes tiles (presumably
            0-255 pixel values — confirm with the caller).
    """

    def __init__(self, img_blks):
        scaled = img_blks / 255.
        self.img_blks = torch.from_numpy(scaled).double()

    def __len__(self):
        return self.img_blks.size(0)

    def __getitem__(self, index):
        return self.img_blks[index]
26 |
27 |
28 | class Test:
29 | def __init__(self, file_path, Net, block_size=64):
30 | self.Net = Net.to(device)
31 | self.block_size = block_size
32 | if file_path != '':
33 | with open(file_path, 'r') as file:
34 | self.imgs = list(map(lambda line: line.strip().split(' '), file))
35 |
36 | def padding_and_img_to_blks(self, raw_img):
37 | # *** step1: 图像预处理
38 | # s为padding大小
39 | s = PADDING
40 | block_size = self.block_size
41 | # print("block_size:", block_size)
42 | # print("raw_img.shape: ",raw_img.shape)
43 | w, h = raw_img.size
44 | print(w, h)
45 | raw_data_arr = np.array(raw_img)
46 | print("raw_data_arr shape", raw_data_arr.shape)
47 |
48 | # 将图片分割成块,向上取整
49 | row_block = math.ceil(h / block_size)
50 | col_block = math.ceil(w / block_size)
51 | block_total = row_block * col_block
52 |
53 | # 根据取整情况,对原图进行镜像Padding
54 | raw_data_arr = np.pad(raw_data_arr, ((s, row_block * block_size - h + s),
55 | (s, col_block * block_size - w + s)), mode='reflect')
56 |
57 | # *** step2: 将原图切割成块保存于blks
58 | blks = np.zeros((block_total, 1, block_size + 2 * s, block_size + 2 * s))
59 | blk_counter = 0
60 | for i in range(row_block):
61 | for j in range(col_block):
62 | r = i * block_size
63 | c = j * block_size
64 | blks[blk_counter, :, :, :] = raw_data_arr[r: r + block_size + 2 * s, c:c + block_size + 2 * s]
65 | blk_counter = blk_counter + 1
66 | return blks, h, w
67 |
def run_forward(self, blks, h, w):
    """Run the network over padded Bayer tiles and stitch the results into one RGB image.

    ``blks`` is the tile array built by ``padding_and_img_to_blks``: one entry per
    tile, each ``(1, block_size + 2*PADDING, block_size + 2*PADDING)`` holding raw
    8-bit pixel values (scaled to [0, 1] here before inference).  Tiles are fed to
    ``self.Net`` in batches of ``FORWORD_BLOCK_NUM`` on ``device``.  The network
    output is assumed to have 3 channels at twice the tile resolution (the buffer
    and the stitching below rely on that — TODO confirm against the model).

    Args:
        blks: numpy array of tiles, shape ``(N, 1, S, S)``.
        h: height of the original, unpadded input image.
        w: width of the original, unpadded input image.

    Returns:
        An RGB ``PIL.Image`` of size ``(2*w, 2*h)``.
    """
    block_size = self.block_size
    row_block = math.ceil(h / block_size)
    col_block = math.ceil(w / block_size)
    s = PADDING
    blk_num = blks.shape[0]

    # Network outputs are collected directly in a numpy buffer (float64).
    out_size = (block_size + 2 * s) * 2
    res = np.zeros((blk_num, 3, out_size, out_size))

    start = time.perf_counter()
    with torch.no_grad():
        for first in range(0, blk_num, FORWORD_BLOCK_NUM):
            last = min(first + FORWORD_BLOCK_NUM, blk_num)
            batch = torch.from_numpy(blks[first:last] / 255.).to(device)
            res[first:last] = self.Net(batch).cpu().numpy()

            # Progress report with a simple linear estimate of the remaining time.
            used_time = time.perf_counter() - start
            print("Processing:", last, '/', blk_num)
            print("Now:", used_time)
            print("Remain:", used_time / last * blk_num - used_time)

    # Stitch the tiles back together: drop the 2*PADDING border of every output
    # tile and crop the tiles on the right/bottom edge to the image size.
    print("BLOCKS:", blk_num)
    start = time.perf_counter()
    res *= 255.0
    res_img = np.zeros((3, 2 * h, 2 * w))
    for i in range(row_block):
        for j in range(col_block):
            r = i * block_size
            c = j * block_size
            block_h = min(h, r + block_size) - r
            block_w = min(w, c + block_size) - c
            res_img[:, 2 * r:2 * (r + block_h), 2 * c:2 * (c + block_w)] = \
                res[i * col_block + j, :, 2 * s:2 * (block_h + s), 2 * s:2 * (block_w + s)]
    print('Re-range use:', time.perf_counter() - start)

    print("Output shape:", res_img.shape)
    # (3, H, W) -> (H, W, 3).  Clamp and round explicitly to 8 bits rather than
    # depending on how PIL converts float ('F') bands to 'L'.
    rgb = np.clip(np.rint(res_img), 0, 255).astype(np.uint8).transpose(1, 2, 0)
    return Image.fromarray(np.ascontiguousarray(rgb)).convert('RGB')
150 |
def run(self, raw_img):
    """Split ``raw_img`` into padded tiles, run the network, and return the stitched image."""
    tiles, height, width = self.padding_and_img_to_blks(raw_img)
    result = self.run_forward(blks=tiles, h=height, w=width)
    return result
154 |
def test(self, save_path="./test_result/", filename="test_result_PSNR", test_list=range(0, 10)):
    """Evaluate the network on selected test pairs and log PSNR / SSIM.

    For every index ``i`` in ``test_list`` the input ``self.imgs[i][0]`` is read as
    greyscale and passed through ``self.run``.  The result is compared with the RGB
    label ``self.imgs[i][1]`` after both are cropped to their common size.
    Per-image scores and the final averages are appended to
    ``save_path + filename + '.txt'``.  The network output and the label are
    also saved as TIFF files.

    ``save_path`` is concatenated directly with file names, so it must end with a
    path separator and the directory must already exist.  ``test_list`` may be
    any iterable of indices; when it is empty, no averages are written.
    """
    psnr_sum = 0.0
    ssim_sum = 0.0
    count = 0
    for i in test_list:
        data_path, label_path = self.imgs[i]
        data = Image.open(data_path).convert('L')
        label = Image.open(label_path).convert('RGB')
        net_img = self.run(data)

        label_np, net_img_np = self.to_same_size_ndarray(label, net_img)
        psnr = compare_psnr(im_true=label_np, im_test=net_img_np)
        # http://www.voidcn.com/article/p-auyocqzg-bac.html
        ssim = compare_ssim(X=label_np, Y=net_img_np, win_size=11, multichannel=True)
        print("PSNR:", psnr, "SSIM:", ssim)

        psnr_sum += psnr
        ssim_sum += ssim
        count += 1

        psnr_str = str(psnr)
        ssim_str = str(ssim)
        str_write = ("No." + str(i + 1) + ": " + psnr_str + "\n"
                     + "No." + str(i + 1) + ": " + ssim_str + "\n")
        self.save_res(str_write, save_path, filename)

        prefix = save_path + filename + str(i + 1)
        net_img.save(prefix + '_PSNR=' + psnr_str[:7] + ".TIF", "TIFF")
        label.save(prefix + "_Real" + ".TIF", "TIFF")

    if count == 0:
        # Nothing was evaluated; the averages below would divide by zero.
        print("AVG: no test images selected")
        return

    avg_str = ""
    avg_str += "PSNR_AVG=" + str(psnr_sum / count)[:7] + "\n"
    avg_str += "SSIM_AVG=" + str(ssim_sum / count)[:7] + "\n"
    print("AVG: ", avg_str)
    self.save_res(avg_str, save_path, filename)
196 |
def save_res(self, contents, save_path, filename):
    """Append ``contents`` to the text file ``save_path + filename + '.txt'``.

    The file is created if it does not exist.  It is always closed, even if the
    write fails.
    """
    with open(save_path + filename + '.txt', 'a') as fh:
        fh.write(contents)
202 |
def PSNR(self, A, B):
    """Return the PSNR in dB between two 8-bit images (peak value 255).

    Both inputs are converted with ``np.array`` and cropped to their common
    top-left ``(h, w)`` region before comparison.  Identical inputs return
    ``10e4`` (= 100000.0) instead of infinity.
    """
    A = np.array(A)
    B = np.array(B)
    h = min(A.shape[0], B.shape[0])
    w = min(A.shape[1], B.shape[1])
    # Compute in float64 to avoid uint8 wrap-around.  ``np.float`` was removed
    # in NumPy 1.24; ``np.float64`` is the same dtype.
    diff = A[:h, :w].astype(np.float64) - B[:h, :w].astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 10e4
    print("MSE: ", mse)
    return 10 * np.log10((255.0 ** 2) / mse)
214 |
def to_same_size_ndarray(self, A: "Image.Image", B: "Image.Image") -> "tuple[np.ndarray, np.ndarray]":
    """Convert two images to arrays and crop both to their common top-left size.

    Only the first two axes (rows, columns) are cropped; any trailing channel
    axis is kept as is.

    Returns:
        ``(A_array, B_array)`` with equal first two dimensions.
    """
    A = np.array(A)
    B = np.array(B)
    h = min(A.shape[0], B.shape[0])
    w = min(A.shape[1], B.shape[1])
    return A[:h, :w], B[:h, :w]
227 |
228 |
def Run_test(net, model_path, test_data_path, save_to, as_name, _block_size=32, _test_list=range(0, 10)):
    """Load weights from ``model_path`` into ``net`` and evaluate it with ``Test``.

    The network is converted to double precision and moved to ``device``.  The
    process-wide default tensor type is switched to ``torch.DoubleTensor``.
    Results are written by ``Test.test`` under ``save_to`` using ``as_name`` as
    the file prefix.
    """
    torch.set_default_tensor_type('torch.DoubleTensor')

    print("Creating the model...")
    model = net.double().to(device)

    print("Loading the model data...")
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)

    print("Init...")
    tester = Test(test_data_path, model, block_size=_block_size)
    tester.test(save_path=save_to, filename=as_name, test_list=_test_list)
238 |
239 |
240 | def To_Bayer(img):
241 | w, h = img.size
242 | # img=img.resize((int(w/2),int(h/2)), Image.ANTIALIAS)
243 | # w,h=img.size
244 | # r,g,b=img.split()
245 | data = np.array(img)
246 | """
247 | R G R G
248 | G B G B
249 | R G R G
250 | G B G B
251 | """
252 | bayer_mono = np.zeros((h, w))
253 | for r in range(h):
254 | for c in range(w):
255 | if (0 == r % 2):
256 | if (1 == c % 2):
257 | data[r, c, 0] = 0
258 | data[r, c, 2] = 0
259 |
260 | bayer_mono[r, c] = data[r, c, 1]
261 | else:
262 | data[r, c, 1] = 0
263 | data[r, c, 2] = 0
264 |
265 | bayer_mono[r, c] = data[r, c, 0]
266 | else:
267 | if (0 == c % 2):
268 | data[r, c, 0] = 0
269 | data[r, c, 2] = 0
270 |
271 | bayer_mono[r, c] = data[r, c, 1]
272 | else:
273 | data[r, c, 0] = 0
274 | data[r, c, 1] = 0
275 |
276 | bayer_mono[r, c] = data[r, c, 2]
277 |
278 | # 三通道Bayer图像
279 | bayer = Image.fromarray(data)
280 |
281 | return bayer
282 |
283 | # Bayer_mono=Image.fromarray(bayer_mono)
284 | # Bayer_mono.convert('L')
285 | # Bayer_mono.convert('RGB')
286 | # Bayer_mono.show()
287 | # return Bayer_mono
288 |
289 |
def LR_to_HR(net, model_path, img_path, save_to, _block_size=32):
    """Mosaic an RGB image to RGGB Bayer, super-resolve it with ``net``, and save it.

    The image at ``img_path`` is converted with ``To_Bayer``, then collapsed to
    one channel, processed by ``Test.run``, and written to ``save_to + ".TIF"``
    as a TIFF file.
    """
    torch.set_default_tensor_type('torch.DoubleTensor')
    print("Creating the model...")
    # Test.run_forward sends its input batches to ``device``, so the model must
    # live there too; otherwise a CUDA host fails with a device mismatch.
    MyNet = net.double().to(device)
    print("Loading the model data...")
    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
    MyNet.load_state_dict(torch.load(model_path, map_location=device))
    print("Init...")
    Mytest = Test(Net=MyNet, file_path="", block_size=_block_size)
    print("Read image...")
    img = Image.open(img_path).convert("RGB")
    img = To_Bayer(img)
    img = Mytest.run(raw_img=img.convert('L'))
    img.save(save_to + ".TIF", "TIFF")
303 |
--------------------------------------------------------------------------------
/Code_V1.0/Text_Create.py:
--------------------------------------------------------------------------------
def save_txt(filename, contents):
    """Write ``contents`` to ``filename``, replacing any existing file.

    NOTE: this uses the platform default encoding.  The paths written by this
    script contain non-ASCII characters, so readers must use the same encoding.
    """
    with open(filename, 'w') as fh:
        fh.write(contents)
5 |
6 |
# Output list files.  test_data_path is currently unused (test-list generation
# is disabled in this script).
train_data_path = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Train_Data.txt"
test_data_path = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Test_Data.txt"

# Bayer inputs are <data_dir><i>.TIF; ground-truth labels are <label_dir><i>.TIF.
data_dir = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Bayer/Bayer"
label_dir = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Resize/real"

# One "<input path> <label path>" line per image, for images 1..1000.
# Built with join: repeated `+=` on a str is quadratic in the output size.
contents = "".join(
    data_dir + str(i) + ".TIF" + " " + label_dir + str(i) + ".TIF\n"
    for i in range(1, 1001)
)

print(contents)

save_txt(train_data_path, contents)
48 |
--------------------------------------------------------------------------------
/Code_V1.0/ToBayer.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
def DownSize(img, times=3, factor=1.25):
    """Shrink ``img`` ``times`` times, dividing width and height by ``factor`` each time.

    With the defaults the image ends up at roughly 0.512x its original size.
    Each step truncates the new size to integers, prints it, and resamples with
    the Lanczos filter.  ``Image.ANTIALIAS`` (used previously) was an alias of
    ``Image.LANCZOS`` and was removed in Pillow 10.

    Returns:
        The resized image (the input itself when ``times`` is 0).
    """
    for _ in range(times):
        w, h = img.size
        new_size = (int(w / factor), int(h / factor))
        print(*new_size)
        img = img.resize(new_size, Image.LANCZOS)
    return img
10 |
11 | def To_Bayer(img):
12 | w,h=img.size
13 | img=img.resize((int(w/2),int(h/2)), Image.ANTIALIAS)
14 | w,h=img.size
15 | # r,g,b=img.split()
16 | data=np.array(img)
17 | """
18 | R G R G
19 | G B G B
20 | R G R G
21 | G B G B
22 | """
23 | bayer_mono=np.zeros((h,w))
24 | for r in range(h):
25 | for c in range(w):
26 | if(0==r%2):
27 | if(1==c%2):
28 | data[r,c,0]=0
29 | data[r,c,2]=0
30 |
31 | bayer_mono[r,c]=data[r,c,1]
32 | else:
33 | data[r,c,1]=0
34 | data[r,c,2]=0
35 |
36 | bayer_mono[r,c]=data[r,c,0]
37 | else:
38 | if(0==c%2):
39 | data[r,c,0]=0
40 | data[r,c,2]=0
41 |
42 | bayer_mono[r,c]=data[r,c,1]
43 | else:
44 | data[r,c,0]=0
45 | data[r,c,1]=0
46 |
47 | bayer_mono[r,c]=data[r,c,2]
48 |
49 | # 三通道Bayer图像
50 | bayer=Image.fromarray(data)
51 | #bayer.show()
52 |
53 | return bayer
54 |
55 |
# Dataset directories.  ``raw_dir`` was previously named ``dir``, which
# shadowed the builtin.
raw_dir = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Rename"
resize_dir = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Resize/real"
bayer_dir = "/Users/linweichen/Desktop/计算机学习资料/TrainData/RAISE_1K/Bayer/Bayer"

# Stage 1: downscale raw<i>.TIF with DownSize and save it as the ground-truth
# image real<i>.TIF.  NOTE: range(1001, 1001) is empty, so this stage is
# currently disabled.
for i in range(1001, 1001):
    img = Image.open(raw_dir + "/raw" + str(i) + ".TIF", "r")
    img = DownSize(img)
    img.save(resize_dir + str(i) + ".TIF", "TIFF")

# Stage 2: convert each ground-truth image real<i>.TIF (i = 1..1000) into a
# 3-channel Bayer image Bayer<i>.TIF (To_Bayer also halves the size).
for i in range(1, 1001):
    img = Image.open(resize_dir + str(i) + ".TIF", "r")
    img = To_Bayer(img)
    img.save(bayer_dir + str(i) + ".TIF", "TIFF")
103 |
--------------------------------------------------------------------------------
/Code_V1.0/__pycache__/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/Code_V1.0/__pycache__/.DS_Store
--------------------------------------------------------------------------------
/Code_V1.0/fast_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | import time
8 | from PIL import Image
9 | from DataSet import CustomDataset
10 | from NewResNet import Net
11 | from multiprocessing import Process
12 | from Test_class import Run_test
13 |
# *** Hyper-parameters ***
Parameter_path = './Final_train_LR.txt'  # learning rate persisted on its first line
MODEL_PATH = './Final_Model.pkl'  # checkpoint loaded at start, saved periodically
EPOCH = 1
HALF_LR_STEP = 40000  # halve the learning rate every this many steps
LR = 0.0001  # used when Parameter_path cannot be read
SAVE_EVERY = 100  # save a checkpoint every this many steps
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Paths of the training and cross-validation lists
train_data_path = "../8K_TRAIN_DATA/8K_TRAIN_DATA.txt"
test_data_path = "../8K_CROSS_DATA/8K_CROSS_DATA.txt"
BATCH_BLOCK_SIZE = 64
BATCH_SIZE = 12
DATA_SHUFFLE = True

# Saving / restoring models: only the state_dict is stored (the officially
# recommended way):
#   torch.save(net.state_dict(), PATH)
#   net.load_state_dict(torch.load(PATH))
# https://www.cnblogs.com/nkh222/p/7656623.html
# https://blog.csdn.net/quincuntial/article/details/78045036


def load_lr(path, default):
    """Return the learning rate stored on the first line of ``path``, or ``default``."""
    try:
        with open(path) as f:
            return float(f.readline())
    except (OSError, ValueError):
        print("Loading LR fail...")
        return default


def set_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` in place.

    Unlike building a new optimizer, this keeps Adam's moment estimates and its
    ``amsgrad`` setting.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr


def main():
    """Train ``Net(24)`` with MSE loss, halving the learning rate every HALF_LR_STEP steps."""
    # Check whether a GPU is available
    print("cuda:", torch.cuda.is_available(), "GPUs", torch.cuda.device_count())

    print("Loading the LR...")
    lr = load_lr(Parameter_path, LR)

    print("Loading the saving Model...")
    net = Net(24).to(device)
    try:
        net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    except Exception as err:  # best effort: train from scratch without a usable checkpoint
        print("Loading Fail.", err)

    print("Loading the Training data...")
    dataset = CustomDataset(file_path=train_data_path, block_size=BATCH_BLOCK_SIZE)
    train_data = data.DataLoader(dataset=dataset,
                                 batch_size=BATCH_SIZE,
                                 shuffle=DATA_SHUFFLE,
                                 num_workers=4,
                                 pin_memory=True)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, amsgrad=True)
    loss_func = nn.MSELoss()

    counter = 0
    print("Start training...")
    for epoch in range(EPOCH):
        for step, (inputs, label) in enumerate(train_data):
            counter += 1
            if counter % HALF_LR_STEP == 0:
                lr /= 2
                set_lr(optimizer, lr)
                with open(Parameter_path, 'w') as f:
                    f.write(str(lr))
                print('LR:', lr)

            inputs, label = inputs.to(device), label.to(device)
            start = time.perf_counter()
            out = net(inputs)
            loss = loss_func(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss.item())
            print(epoch, step)
            print("Time:", time.perf_counter() - start)
            if counter % SAVE_EVERY == 0:
                print("Saving the model...")
                torch.save(net.state_dict(), MODEL_PATH)

    # Keep the steps taken since the last periodic checkpoint.
    print("Saving the model...")
    torch.save(net.state_dict(), MODEL_PATH)


# DataLoader workers (num_workers=4) re-import this module under the "spawn"
# start method (the default on macOS and Windows).  The guard stops them from
# re-running the training loop.
if __name__ == '__main__':
    main()
109 |
--------------------------------------------------------------------------------
/Code_V1.0/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | import time
8 |
9 | from NewResNet import Net
10 | from DataSet import CustomDataset
11 | from multiprocessing import Process
12 | from Test_class import Run_test
13 |
# *** Hyper-parameters ***
BATC_SIZE = 1  # NOTE(review): not used below; the DataLoader uses batch_size=16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = './Model.pkl'

# Saving / restoring models: only the state_dict is stored (the officially
# recommended way):
#   torch.save(net.state_dict(), PATH)
#   net.load_state_dict(torch.load(PATH))
# https://www.cnblogs.com/nkh222/p/7656623.html
# https://blog.csdn.net/quincuntial/article/details/78045036

# Paths of the training / test lists
train_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Train_Data.txt"
test_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Test_Data.txt"
all_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Data_Read.txt"
fast_test_data_path = "/Users/chenlinwei/Desktop/计算机学习资料/TrainData/RAISE_1K/Fast_Test_Data.txt"

EPOCH = 1000000
to_PIL_image = transforms.ToPILImage()


def adjust_learning_rate(optimizer):
    """Halve, in place, the learning rate of every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] / 2


def set_learning_rate(optimizer, lr):
    """Set, in place, the learning rate of every parameter group of ``optimizer``.

    Unlike constructing a new optimizer, this keeps Adam's moment estimates.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def main():
    """Train ``Net(24)``, checkpoint every 10 steps, and run a background test every 99 steps."""
    import multiprocessing

    print("cuda:", torch.cuda.is_available(), "GPUs", torch.cuda.device_count())

    print("Loading the saving Model...")
    MyNet = Net(24).to(device)
    try:
        # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
        MyNet.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    except Exception as err:  # best effort: train from scratch without a usable checkpoint
        print("Loading Fail.", err)

    print("Loading the Training data...")
    MyData = CustomDataset(train_data_path, random_augment=10, block_size=32)
    train_data = data.DataLoader(dataset=MyData, batch_size=16, shuffle=True)

    Optimizer = torch.optim.Adam(MyNet.parameters(), lr=0.0000001, betas=(0.9, 0.999), eps=1e-08)
    Loss_Func = nn.MSELoss()

    # The periodic test runs in a child process.  CUDA cannot be used in a
    # forked child, so always use "spawn".
    mp_context = multiprocessing.get_context('spawn')

    counter = 0
    LR = 0.0001
    print("Start training...")
    for epoch in range(EPOCH):
        for step, (inputs, label) in enumerate(train_data):
            counter += 1
            if counter % 10000 == 0:
                LR = LR / 2
                # NOTE(review): training starts at lr=1e-7 (Optimizer above), but the
                # first decay sets 5e-5 — confirm which schedule is intended.
                set_learning_rate(Optimizer, LR)

            inputs, label = inputs.to(device), label.to(device)
            start = time.perf_counter()  # wall clock; process_time ignores GPU time
            out = MyNet(inputs)
            loss = Loss_Func(out, label)
            Optimizer.zero_grad()
            loss.backward()
            Optimizer.step()
            print(loss)
            print(epoch, step)
            print("Time:", time.perf_counter() - start)
            if counter % 10 == 0:
                print("Saving the model...")
                torch.save(MyNet.state_dict(), MODEL_PATH)
            if counter % 99 == 0:
                # Pass Run_test itself as the target.  Calling it here would run
                # the test synchronously and start a process with target=None.
                # Run_test sets the double default tensor type in the child.
                tester = mp_context.Process(
                    target=Run_test,
                    kwargs=dict(net=Net(24),
                                model_path=MODEL_PATH,
                                test_data_path=fast_test_data_path,
                                save_to="./test_result/",
                                as_name="test_result_PSNR.txt",
                                _block_size=64))
                tester.start()


# Child processes started with "spawn" re-import this module.  The guard keeps
# them from re-running the training loop.
if __name__ == '__main__':
    main()
111 |
--------------------------------------------------------------------------------
/Code_V1.0/temp.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import os
4 | import time
5 | import math
6 | from skimage.measure import compare_ssim, compare_psnr
7 |
# Folder holding the network outputs and ground-truth images to compare, e.g.
# final_test99_PSNR=28.5774.TIF next to final_test99_Real.TIF.
path = ('/Users/chenlinwei/Desktop/计算机学习资料/'
        '20181218 Deep Residual Network for Joint Demosaicing and Super-Resolution/'
        'Final_test/final')
13 |
14 |
def file_name(path):
    """Pair up result / ground-truth ``.TIF`` files found under ``path``.

    For each directory visited by ``os.walk``, the file names are sorted by
    length.  Every ``.TIF`` name (with a leading ``'._'`` stripped) is matched
    with the first later name that shares the same index substring
    ``name[10:name[6:].find('_') + 6]``.  Each match is appended as a
    two-element list sorted by length.  Only bare file names are returned, not
    full paths.
    """
    def index_key(name):
        # Positional slice of the image number, e.g. "99" in "final_test99_Real.TIF".
        return name[10:name[6:].find('_') + 6]

    pairs = []
    for _root, _dirs, names in os.walk(path):
        count = len(names)
        print("len:", count)
        names.sort(key=len)
        for i, first in enumerate(names):
            for second in names[i + 1:]:
                if os.path.splitext(first)[1] != '.TIF':
                    continue
                if first.startswith('._'):
                    first = first[2:]
                if index_key(second) == index_key(first):
                    pairs.append(sorted([first, second], key=len))
                    break
    return pairs
41 |
42 |
def sort_key(item):
    """Return the image number encoded in the first file name of a pair, as an int."""
    name = item[0]
    end = name[6:].find('_') + 6
    return int(name[10:end])
46 |
47 |
# Compute SSIM for up to 100 (output, ground truth) pairs, ordered by image
# number, and print the SSIM values and the PSNR values encoded in the file names.
pairs = file_name(path)
print(pairs)
pairs.sort(key=sort_key, reverse=False)
pairs = pairs[:100]
SSIM_SUM = 0.0
SSIM_STR = ''
PSNR_STR = ''
for pair in pairs:
    print(pair)
    a = np.array(Image.open(os.path.join(path, pair[0])).convert('RGB'))
    b = np.array(Image.open(os.path.join(path, pair[1])).convert('RGB'))
    # Crop both images to their common size.
    h = min(a.shape[0], b.shape[0])
    w = min(a.shape[1], b.shape[1])
    a = a[:h, :w]
    b = b[:h, :w]
    SSIM = compare_ssim(X=a, Y=b, full=False, gaussian_weights=False, win_size=11, multichannel=True)
    print(SSIM)
    SSIM_SUM += SSIM
    SSIM_STR += str(SSIM) + '\n'
    # pair[1] is the longer name, e.g. "final_test99_PSNR=28.5774.TIF";
    # characters [-11:-4] hold the PSNR value.
    PSNR_STR += pair[1][-11:-4] + '\n'

if pairs:
    print("SSIM_AVG:", SSIM_SUM / len(pairs))
else:
    # Avoid dividing by zero when no pairs were found.
    print("SSIM_AVG: no image pairs found in", path)
print(SSIM_STR)
print(PSNR_STR)
79 |
--------------------------------------------------------------------------------
/DataSet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | import torch.utils.data as data
6 | from PIL import Image
7 | import numpy as np
8 | import os
9 |
10 |
def is_tif_file(filename):
    """Return True if ``filename`` ends with the upper-case ``.TIF`` extension (case-sensitive)."""
    tif_extensions = (".TIF",)
    return filename.endswith(tif_extensions)
13 |
14 |
def bayer2mono(bayer_img):
    """Collapse a 3-channel Bayer image to one channel by taking the per-pixel maximum.

    In a Bayer image each pixel has at most one non-zero channel, so the maximum
    recovers the sampled value.  The result has shape ``(H, W, 1)`` and the input
    dtype.
    """
    return np.array(bayer_img).max(axis=2, keepdims=True)
19 |
20 |
21 | # bayer2mono()
22 |
23 |
class CustomDataset(data.Dataset):
    """Dataset of (Bayer input, RGB label) image pairs stored under one directory.

    ``file_path_list`` is a sequence of ``(data_file, label_file)`` pairs, relative
    to ``data_dir``.  Items are ``(data_tensor, label_tensor, pair)``.
    """

    def __init__(self, data_dir, file_path_list):
        self.data_dir = data_dir
        self.imgs = file_path_list
        print('Dataset size is : ', len(self.imgs))

    def __getitem__(self, index):
        # NOTE: the Bayer input is expected to start with this tile at the top-left:
        #   R G
        #   G B
        pair = self.imgs[index]
        data_name, label_name = pair
        data_path = os.path.join(self.data_dir, data_name)
        label_path = os.path.join(self.data_dir, label_name)

        to_tensor = transforms.ToTensor()
        # The input is read as RGB and reduced to one channel (H, W, 1).
        data_img = to_tensor(bayer2mono(Image.open(data_path).convert('RGB')))
        label_img = to_tensor(Image.open(label_path).convert('RGB'))
        return data_img, label_img, pair

    def __len__(self):
        return len(self.imgs)
49 |
--------------------------------------------------------------------------------
/Final_test/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/Final_test/.DS_Store
--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 |
7 |
8 | # ResNet
9 | # https://blog.csdn.net/sunqiande88/article/details/80100891
class ResidualBlock(nn.Module):
    """Residual unit: ``PReLU(conv3x3(PReLU(conv3x3(x))) + x)`` at 256 channels.

    Spatial size and channel count are preserved (3x3 kernels, stride 1,
    padding 1).
    """

    def __init__(self):
        super().__init__()
        conv_args = dict(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
        # Submodule names and order ("left.0", "left.1", "left.2", "active_f")
        # are the state_dict keys of saved checkpoints; keep them stable.
        self.left = nn.Sequential(
            nn.Conv2d(**conv_args),
            nn.PReLU(),
            nn.Conv2d(**conv_args),
        )
        self.shortcut = nn.Sequential()  # an empty Sequential is the identity
        self.active_f = nn.PReLU()

    def forward(self, x):
        residual = self.left(x)
        return self.active_f(residual + self.shortcut(x))
26 |
27 |
class Net(nn.Module):
    """Joint demosaicing and 2x super-resolution network.

    Input is a one-channel Bayer mosaic ``(N, 1, H, W)``.  Output is RGB
    ``(N, 3, 2H, 2W)`` (shape arithmetic assumes even ``H`` and ``W`` — TODO
    confirm).  Attribute names are state_dict keys of saved checkpoints and
    must not be renamed.
    """

    def __init__(self, resnet_level=24):
        super().__init__()

        # ***Stage 1***: stride-2 4x4 conv (stride matches the 2x2 Bayer period),
        # then PixelShuffle(2) back to the input resolution (256 -> 64 channels).
        self.stage1_1_conv4x4 = nn.Conv2d(in_channels=1, out_channels=256,
                                          kernel_size=4, stride=2, padding=1, bias=True)
        self.stage1_2_SP_conv = nn.PixelShuffle(2)
        # NOTE: despite its name this is a 3x3 convolution; the name is kept
        # because it is a state_dict key.
        self.stage1_2_conv4x4 = nn.Conv2d(in_channels=64, out_channels=256,
                                          kernel_size=3, stride=1, padding=1, bias=True)
        self.stage1_2_PReLU = nn.PReLU()

        # ***Stage 2***: a stack of residual blocks at 256 channels.
        self.stage2_ResNetBlock = nn.Sequential(*[ResidualBlock() for _ in range(resnet_level)])

        # ***Stage 3***: PixelShuffle(2) doubles the resolution, then two convs map to RGB.
        self.stage3_1_SP_conv = nn.PixelShuffle(2)
        self.stage3_2_conv3x3 = nn.Conv2d(in_channels=64, out_channels=256,
                                          kernel_size=3, stride=1, padding=1, bias=True)
        self.stage3_2_PReLU = nn.PReLU()
        self.stage3_3_conv3x3 = nn.Conv2d(in_channels=256, out_channels=3,
                                          kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        pipeline = (
            self.stage1_1_conv4x4,
            self.stage1_2_SP_conv,
            self.stage1_2_conv4x4,
            self.stage1_2_PReLU,
            self.stage2_ResNetBlock,
            self.stage3_1_SP_conv,
            self.stage3_2_conv3x3,
            self.stage3_2_PReLU,
            self.stage3_3_conv3x3,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out
82 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution
2 |
3 | See https://zhuanlan.zhihu.com/p/56493507 for details.
4 |
5 | The trained model weights can be downloaded from https://pan.baidu.com/s/1QtwMmSbTNTfosQ3eEDW2Pg.
6 |
7 | In each comparison below, the left image is the ground truth and the right image is the model output.
8 |
9 | 
10 |
11 | PSNR: 31.197238996689617 SSIM: 0.9097831587657645
12 |
13 | 
14 |
15 | PSNR: 32.89967806219095 SSIM: 0.9294818208128227
16 |
17 | 
18 |
19 | PSNR: 33.15050503169419 SSIM: 0.9472909901611216
20 |
21 | 
22 |
23 | PSNR: 30.873442524392864 SSIM: 0.9473571002561766
24 |
25 | 
26 |
27 | PSNR: 25.052382881653507 SSIM: 0.9404708529075997
28 |
29 | 
30 |
31 | PSNR: 38.69040333179672 SSIM: 0.9570685066296898
32 |
--------------------------------------------------------------------------------
/Saved_Models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/Saved_Models/.DS_Store
--------------------------------------------------------------------------------
/Saved_Models/20190226Traned_Model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/Saved_Models/20190226Traned_Model/.DS_Store
--------------------------------------------------------------------------------
/TEST_DATA/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/.DS_Store
--------------------------------------------------------------------------------
/TEST_DATA/0data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/0data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/0label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/0label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/1data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/1data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/1label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/1label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/2data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/2data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/2label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/2label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/3data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/3data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/3label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/3label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/4data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/4data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/4label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/4label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/5data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/5data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/5label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/5label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/6data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/6data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/6label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/6label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/7data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/7data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/7label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/7label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/8data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/8data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/8label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/8label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/9data.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/9data.TIF
--------------------------------------------------------------------------------
/TEST_DATA/9label.TIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/TEST_DATA/9label.TIF
--------------------------------------------------------------------------------
/TEST_DATA/TEST_DATA.txt:
--------------------------------------------------------------------------------
1 | ./TEST_DATA/0data.TIF ./TEST_DATA/0label.TIF
2 | ./TEST_DATA/1data.TIF ./TEST_DATA/1label.TIF
3 | ./TEST_DATA/2data.TIF ./TEST_DATA/2label.TIF
4 | ./TEST_DATA/3data.TIF ./TEST_DATA/3label.TIF
5 | ./TEST_DATA/4data.TIF ./TEST_DATA/4label.TIF
6 | ./TEST_DATA/5data.TIF ./TEST_DATA/5label.TIF
7 | ./TEST_DATA/6data.TIF ./TEST_DATA/6label.TIF
8 | ./TEST_DATA/7data.TIF ./TEST_DATA/7label.TIF
9 | ./TEST_DATA/8data.TIF ./TEST_DATA/8label.TIF
10 | ./TEST_DATA/9data.TIF ./TEST_DATA/9label.TIF
11 | ./TEST_DATA/0data.TIF ./TEST_DATA/0label.TIF
12 | ./TEST_DATA/1data.TIF ./TEST_DATA/1label.TIF
13 | ./TEST_DATA/2data.TIF ./TEST_DATA/2label.TIF
14 | ./TEST_DATA/3data.TIF ./TEST_DATA/3label.TIF
15 | ./TEST_DATA/4data.TIF ./TEST_DATA/4label.TIF
16 | ./TEST_DATA/5data.TIF ./TEST_DATA/5label.TIF
17 | ./TEST_DATA/6data.TIF ./TEST_DATA/6label.TIF
18 | ./TEST_DATA/7data.TIF ./TEST_DATA/7label.TIF
19 | ./TEST_DATA/8data.TIF ./TEST_DATA/8label.TIF
20 | ./TEST_DATA/9data.TIF ./TEST_DATA/9label.TIF
21 | ./TEST_DATA/0data.TIF ./TEST_DATA/0label.TIF
22 | ./TEST_DATA/1data.TIF ./TEST_DATA/1label.TIF
23 | ./TEST_DATA/2data.TIF ./TEST_DATA/2label.TIF
24 | ./TEST_DATA/3data.TIF ./TEST_DATA/3label.TIF
25 | ./TEST_DATA/4data.TIF ./TEST_DATA/4label.TIF
26 | ./TEST_DATA/5data.TIF ./TEST_DATA/5label.TIF
27 | ./TEST_DATA/6data.TIF ./TEST_DATA/6label.TIF
28 | ./TEST_DATA/7data.TIF ./TEST_DATA/7label.TIF
29 | ./TEST_DATA/8data.TIF ./TEST_DATA/8label.TIF
30 | ./TEST_DATA/9data.TIF ./TEST_DATA/9label.TIF
31 | ./TEST_DATA/0data.TIF ./TEST_DATA/0label.TIF
32 | ./TEST_DATA/1data.TIF ./TEST_DATA/1label.TIF
33 | ./TEST_DATA/2data.TIF ./TEST_DATA/2label.TIF
34 | ./TEST_DATA/3data.TIF ./TEST_DATA/3label.TIF
35 | ./TEST_DATA/4data.TIF ./TEST_DATA/4label.TIF
36 | ./TEST_DATA/5data.TIF ./TEST_DATA/5label.TIF
37 | ./TEST_DATA/6data.TIF ./TEST_DATA/6label.TIF
38 | ./TEST_DATA/7data.TIF ./TEST_DATA/7label.TIF
39 | ./TEST_DATA/8data.TIF ./TEST_DATA/8label.TIF
40 | ./TEST_DATA/9data.TIF ./TEST_DATA/9label.TIF
41 |
--------------------------------------------------------------------------------
/Test_class.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from PIL import Image
4 | import numpy as np
5 | import time
6 | import math
7 | from skimage.measure import compare_psnr, compare_ssim
8 | import torchvision.transforms as transforms
9 | import torch.utils.data as data
10 | from tqdm import tqdm
11 | # from SSIM_PIL import compare_ssim as SSIM_PIL_compare_ssim
12 |
# Reflect-padding margin (pixels) added around every tile before inference;
# it is cropped off again when the tiles are stitched back together.
# FORWORD_BLOCK_NUM: tiles pushed through the model per forward call in
# run_forward() (the misspelt name is kept because callers import it).
PADDING, FORWORD_BLOCK_NUM = 2, 1
15 |
16 |
def padding_and_to_blks(bayer_rgb_img, block_size=64):
    """Split a Bayer mosaic image into reflect-padded square tiles.

    A 3-D array (H x W x C) is collapsed to one plane by taking the per-pixel
    channel maximum, which presumably recovers the single sampled value of a
    3-channel Bayer image such as the one ``To_Bayer`` produces. A
    single-channel (2-D) input is used unchanged.

    Args:
        bayer_rgb_img: PIL image; ``.size`` is read as ``(w, h)``.
        block_size: tile edge length in pixels. ``<= 0`` disables tiling.

    Returns:
        ``(blks, h, w)``. With tiling, ``blks`` has shape
        ``(row_block * col_block, 1, block_size + 2*PADDING, block_size + 2*PADDING)``
        with tiles in row-major order. Without tiling it is the unpadded
        plane with shape ``(1, 1, h, w)``.
    """
    # Step 1: preprocess. s is the padding margin.
    s = PADDING
    w, h = bayer_rgb_img.size
    raw_data_arr = np.array(bayer_rgb_img)
    if raw_data_arr.ndim == 3:
        raw_data_arr = np.max(raw_data_arr, 2)
    print("===> Shape", raw_data_arr.shape)

    if block_size <= 0:
        return raw_data_arr[np.newaxis, np.newaxis, :, :], h, w

    # Number of tiles per axis, rounded up so partial edge tiles are covered.
    row_block = math.ceil(h / block_size)
    col_block = math.ceil(w / block_size)
    tile = block_size + 2 * s

    # Mirror-pad so the plane covers a whole number of tiles, plus an
    # s-pixel margin on every side.
    raw_data_arr = np.pad(raw_data_arr,
                          ((s, row_block * block_size - h + s),
                           (s, col_block * block_size - w + s)),
                          mode='reflect')

    # Step 2: cut the padded plane into overlapping tiles of edge `tile`.
    blks = np.zeros((row_block * col_block, 1, tile, tile))
    for i in range(row_block):
        for j in range(col_block):
            r, c = i * block_size, j * block_size
            blks[i * col_block + j, 0] = raw_data_arr[r:r + tile, c:c + tile]
    return blks, h, w
48 |
49 |
def run_forward(model, device, blks, h, w, block_size=64):
    """Run ``model`` over every tile and stitch the 2x-upscaled outputs.

    Args:
        model: callable taking a float tensor of tiles scaled to [0, 1].
            It is assumed to return ``(n, 3, 2*H, 2*W)`` for ``(n, 1, H, W)``
            input, as the output indexing below implies (TODO confirm
            against Model.Net).
        device: torch device the tiles are moved to before inference.
        blks, h, w: output of ``padding_and_to_blks`` called with the same
            ``block_size``.
        block_size: tile edge. ``<= 0`` means ``blks`` holds the whole,
            unpadded image as a single tile.

    Returns:
        PIL RGB image of size ``(2*w, 2*h)``. Values are truncated and
        clipped to 0..255 by PIL's 'F' -> 'L' conversion.
    """
    s = PADDING
    blk_num = blks.shape[0]
    tiled = block_size > 0

    # Size the output buffer from the actual tile shape. In the untiled case
    # the single "tile" is the whole h x w image, not block_size + 2s.
    # The buffer is built via NumPy so it stays on the CPU even if a CUDA
    # default tensor type is active.
    res = torch.from_numpy(np.zeros((blk_num, 3, 2 * blks.shape[2], 2 * blks.shape[3])))
    with torch.no_grad():
        for i in tqdm(range(math.ceil(blk_num / FORWORD_BLOCK_NUM))):
            lo = i * FORWORD_BLOCK_NUM
            hi = min(lo + FORWORD_BLOCK_NUM, blk_num)
            batch = torch.from_numpy(blks[lo:hi] / 255.).to(device).float()
            res[lo:hi] = model(batch).detach().cpu()

    res = res.numpy() * 255.0
    if not tiled:
        # Whole image processed without padding: nothing to crop or stitch.
        res_img = res[0]
    else:
        row_block = math.ceil(h / block_size)
        col_block = math.ceil(w / block_size)
        res_img = np.zeros((3, 2 * h, 2 * w))
        for i in range(row_block):
            for j in range(col_block):
                r, c = i * block_size, j * block_size
                block_h = min(h, r + block_size) - r
                block_w = min(w, c + block_size) - c
                # Drop the (upscaled) 2s padding margin and any overhang past the image edge.
                res_img[:, 2 * r:2 * (r + block_h), 2 * c:2 * (c + block_w)] = \
                    res[i * col_block + j, :, 2 * s:2 * (block_h + s), 2 * s:2 * (block_w + s)]

    # (3, H, W) float planes -> per-channel 8-bit bands -> RGB image.
    bands = [Image.fromarray(res_img[k]).convert('L') for k in range(3)]
    print("===> Output shape:", res_img.shape)
    return Image.merge('RGB', bands).convert('RGB')
111 |
112 |
def to_same_size_ndarray(A: "Image.Image", B: "Image.Image") -> "tuple[np.ndarray, np.ndarray]":
    """Convert two images to arrays cropped to their shared top-left region.

    Args:
        A, B: PIL images (anything ``np.array`` accepts also works).

    Returns:
        ``(A_arr, B_arr)``, both sliced to the smaller height and smaller
        width of the two inputs. Any trailing (channel) axis is left as is.
    """
    a_arr = np.array(A)
    b_arr = np.array(B)
    h = min(a_arr.shape[0], b_arr.shape[0])
    w = min(a_arr.shape[1], b_arr.shape[1])
    return a_arr[:h, :w], b_arr[:h, :w]
125 |
126 |
def save_res(self, contents, save_path, filename=None):
    """Append ``contents`` to the text file ``save_path + filename + '.txt'``.

    ``self`` is a leftover from when this was a method and is ignored. The
    in-module caller passes only ``(contents, save_path, filename)``. That
    3-argument form is detected (``filename`` missing) and the arguments are
    shifted, so both call styles work.

    ``save_path`` is concatenated, not joined, so it should end with a path
    separator (e.g. ``"./test_result/"``).
    """
    if filename is None:
        contents, save_path, filename = self, contents, save_path
    with open(save_path + filename + '.txt', 'a') as fh:
        fh.write(contents)
132 |
133 |
def compare(label, Model_img):
    """Score a model output against its ground-truth image.

    Both images are first cropped to their common size. PSNR is computed on
    the full-colour arrays. SSIM is computed on the greyscale ('L') versions
    with an 11x11 window.

    Returns:
        ``(PSNR, SSIM)`` as returned by skimage.
    """
    ref_rgb, out_rgb = to_same_size_ndarray(label, Model_img)
    ref_grey, out_grey = to_same_size_ndarray(label.convert('L'), Model_img.convert('L'))
    psnr_val = compare_psnr(im_true=ref_rgb, im_test=out_rgb)
    ssim_val = compare_ssim(X=ref_grey, Y=out_grey, win_size=11, multichannel=False)
    print("PSNR:", psnr_val, "SSIM:", ssim_val)
    return psnr_val, ssim_val
148 |
149 |
def test(save_path="./test_result/", filename="test_result_PSNR", test_list=range(0, 10)):
    """Evaluate ``run`` on ``imgs[i]`` for each index in ``test_list`` and log results.

    For every case:
    - the PSNR and SSIM lines are appended to ``save_path + filename + '.txt'``;
    - the model output and the label are saved as TIFFs whose names start
      with ``save_path + filename``.
    The averages are appended at the end.

    NOTE(review): ``imgs`` (apparently a sequence of ``(data_path, label_path)``
    pairs) and ``run`` (mono PIL image -> RGB PIL image) are not defined in
    this module. This looks like a leftover from a ``Test`` class and raises
    NameError unless those globals are supplied — TODO confirm.
    """
    psnr_sum = 0
    ssim_sum = 0
    for i in test_list:
        data_path, label_path = imgs[i]
        data_img = Image.open(data_path).convert('L')
        label = Image.open(label_path).convert('RGB')
        Model_img = run(data_img)

        # compare() yields both metrics: PSNR on RGB, SSIM on greyscale.
        psnr_val, ssim_val = compare(label, Model_img)
        psnr_sum += psnr_val
        ssim_sum += ssim_val

        psnr_str = str(psnr_val)
        ssim_str = str(ssim_val)
        str_write = ("No." + str(i + 1) + ": " + psnr_str + "\n"
                     + "No." + str(i + 1) + ": " + ssim_str + "\n")
        save_res(None, str_write, save_path, filename)

        Model_img.save(save_path + filename + str(i + 1) + '_PSNR=' + psnr_str[:7] + ".TIF", "TIFF")
        label.save(save_path + filename + str(i + 1) + "_Real" + ".TIF", "TIFF")

    n = len(test_list)
    if n == 0:
        return
    avg_str = ("PSNR_AVG=" + str(psnr_sum / n)[:7] + "\n"
               + "SSIM_AVG=" + str(ssim_sum / n)[:7] + "\n")
    print("AVG: ", avg_str)
    save_res(None, avg_str, save_path, filename)
191 |
192 |
def PSNR(self, A, B):
    """Return the PSNR in dB between two 8-bit images (peak value 255).

    Both inputs go through ``np.array`` and are cropped to their common
    top-left ``h x w`` region. ``self`` is unused (a leftover method
    parameter). If the images are identical (MSE == 0), the sentinel
    ``10e4`` is returned instead of infinity.
    """
    A = np.array(A)
    B = np.array(B)
    h = min(A.shape[0], B.shape[0])
    w = min(A.shape[1], B.shape[1])
    # Cast before subtracting so uint8 inputs cannot wrap around.
    # ``np.float`` was removed in NumPy 1.24; float64 is what it aliased.
    diff = A[:h, :w].astype(np.float64) - B[:h, :w].astype(np.float64)
    mse = (diff ** 2).mean()
    if mse == 0:
        return 10e4
    print("MSE: ", mse)
    return 10 * np.log10((255.0 ** 2) / mse)
204 |
205 |
def Run_test(Model, model_path, test_data_path, save_to, as_name, _block_size=32, _test_list=range(0, 10),
             device=None):
    """Load saved weights into ``Model`` and run the ``Test`` harness on it.

    Side effect: switches torch's global default tensor type to double.

    Args:
        Model: an already-constructed torch module. It is converted to
            double precision and moved to ``device``.
        model_path: path of a saved ``state_dict``.
        test_data_path: forwarded to ``Test``.
        save_to, as_name: forwarded to ``Test.test`` as ``save_path`` / ``filename``.
        _block_size: tile size forwarded to ``Test``.
        _test_list: indices forwarded to ``Test.test``.
        device: torch device; defaults to CUDA when available, else CPU.

    NOTE(review): ``Test`` is neither defined nor imported in this module, so
    this raises NameError unless it is provided elsewhere — TODO confirm.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.set_default_tensor_type('torch.DoubleTensor')
    print("Creating the model...")
    MyModel = Model.double().to(device)
    print("Loading the model data...")
    MyModel.load_state_dict(torch.load(model_path, map_location=device))
    print("Init...")
    Mytest = Test(test_data_path, MyModel, block_size=_block_size)
    Mytest.test(save_path=save_to, filename=as_name, test_list=_test_list)
215 |
216 |
def To_Bayer(img):
    """Simulate an RGGB Bayer colour-filter array on an RGB image.

    CFA layout, with row 0 and column 0 at the top-left::

        R G R G
        G B G B
        R G R G
        G B G B

    At each pixel only the channel the CFA samples is kept; the other two of
    channels 0..2 are set to 0. Returns a new 3-channel PIL image; ``img`` is
    not modified, because ``np.array`` copies.
    """
    data = np.array(img)
    # Even row, even col: red site -> clear G and B.
    data[0::2, 0::2, 1:3] = 0
    # Even row, odd col / odd row, even col: green sites -> clear R and B.
    data[0::2, 1::2, 0:3:2] = 0
    data[1::2, 0::2, 0:3:2] = 0
    # Odd row, odd col: blue site -> clear R and G.
    data[1::2, 1::2, 0:2] = 0
    return Image.fromarray(data)
265 |
266 |
def LR_to_HR(img, model, block_size=64):
    """Placeholder for low-res -> high-res inference; not implemented, returns None."""
    return None
269 |
--------------------------------------------------------------------------------
/Train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch
4 | import tqdm
5 | import torch.utils.data as data
6 | import torch.nn as nn
7 | from torchvision import transforms
8 | import time
9 | import random
10 | from DataSet import CustomDataset
11 | from Model import Net
12 | from torch.utils.data import DataLoader
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import torch.backends.cudnn as cudnn
16 | from Test_class import *
17 |
# *** Hyper-parameters ***

EPOCH = 1
SUB_EPOCH_SIZE = 200
SUB_EPOCH = 10000
BATCH_COUNTER = 0
LR_HALF = 10000
LR = 0.0001
SEED = 666  # NOTE(review): not used to seed any RNG in this file.
BATCH_BLOCK_SIZE = 64
BATCH_SIZE = 12
DATA_SHUFFLE = True
NUM_WORKERS = 2
# Check whether a GPU is available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("cuda :", torch.cuda.is_available(), "GPU Num : ", torch.cuda.device_count())
if torch.cuda.is_available():
    # cudnn.benchmark = True
    # Ask cuDNN for deterministic algorithms (attribute name must be spelt exactly).
    cudnn.deterministic = True
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')
# *** Paths ***
SAVE_PATH = './Saved_Models/20190226Traned_Model'
MODEL_FILENAME = 'Model.pkl'
MODEL_SAVE_PATH = os.path.join(SAVE_PATH, MODEL_FILENAME)
BEST_MODEL_FILENAME = 'Best_Model.pkl'
BEST_MODEL_SAVE_PATH = os.path.join(SAVE_PATH, BEST_MODEL_FILENAME)
SSIM_BEST_MODEL_FILENAME = 'SSIM_Best_Model.pkl'
SSIM_BEST_MODEL_SAVE_PATH = os.path.join(SAVE_PATH, SSIM_BEST_MODEL_FILENAME)
PARA_FILENAME = 'Para.pkl'
PARA_SAVE_PATH = os.path.join(SAVE_PATH, PARA_FILENAME)
TRAIN_DATA_DIR = os.path.join(os.path.expanduser('~'), 'Dataset/RAISE_8K')
TRAIN_DATA_PATH = os.path.join(TRAIN_DATA_DIR, '8K_TRAIN_DATA/8K_TRAIN_DATA.txt')
TEST_DATA_PATH = "./TEST_DATA/TEST_DATA.txt"


# *** Model, dataset ***
56 |
57 |
def init_para():
    """Create the training-state file ``PARA_SAVE_PATH`` if it does not exist yet.

    The fresh state holds:
    - the initial hyper-parameters;
    - a shuffled copy of the pair list read from ``TRAIN_DATA_PATH``;
    - Adam's ``param_groups[0]``;
    - empty history lists.
    An existing file is left untouched.
    """
    start_time = time.perf_counter()
    try:
        print('===> Find the para_saved file')
        if os.path.isfile(PARA_SAVE_PATH):
            return
        print('===> The para_saved file Not exist, creating new one...')
        train_dataset_list = txt_to_path_list(TRAIN_DATA_PATH)
        random.shuffle(train_dataset_list)
        # Only param_groups[0] (lr, betas, eps, ...) is persisted; a throwaway
        # Adam over a fresh Net supplies it.
        default_group = torch.optim.Adam(Net().to(DEVICE).parameters(),
                                         lr=LR,
                                         betas=(0.9, 0.999),
                                         eps=1e-08,
                                         amsgrad=True).state_dict()['param_groups'][0]
        os.makedirs(SAVE_PATH, exist_ok=True)
        torch.save({
            'epoch': EPOCH,
            'batch_counter': BATCH_COUNTER,
            'lr': LR,
            'optimizer param_groups': default_group,
            'train_dataset_list': train_dataset_list,
            'loss_list': [],
            'result_list': [],
            'hard_cases_list': [],
            'best_result': [0, 0]
        }, PARA_SAVE_PATH)
        print('==> Done with initialization!')
    finally:
        print('===> Init_para used time: ', time.perf_counter() - start_time)
85 |
86 |
def txt_to_path_list(txt_path):
    """Read a list file with one whitespace-separated ``data_path label_path`` pair per line.

    Returns a list of per-line token lists (normally ``[data_path, label_path]``).
    Blank lines are skipped rather than yielding ``['']`` entries, and runs
    of spaces or tabs between tokens do not produce empty tokens.
    """
    with open(txt_path, 'r') as f:
        return [line.split() for line in f if line.strip()]
90 |
91 |
def get_train_dataset():
    """Build the DataLoader for the next sub-epoch.

    Takes up to ``SUB_EPOCH_SIZE * BATCH_SIZE`` pairs from the front of
    ``para['train_dataset_list']`` (module global ``para``). When that list
    is exhausted, the epoch counter is advanced and the list is refilled:
    - from the global ``hard_cases_list`` (which is then reset), if
      ``para['hard_cases_list']`` is non-empty;
    - otherwise from ``TRAIN_DATA_PATH``.

    Returns:
        ``(loader, para_update)``. ``para_update`` holds the remaining
        ``train_dataset_list`` and the (possibly incremented) ``epoch``, for
        the caller to merge into ``para`` at the next checkpoint.

    Raises:
        FileNotFoundError: if no training pairs can be obtained at all, or
            the list file is missing. The original error message is
            preserved.
    """
    global hard_cases_list
    start_time = time.perf_counter()
    try:
        print('===> Try to get train dataset from saved saved file...')
        train_dataset_list = para['train_dataset_list']
        epoch = para['epoch']
        print('===> Pre train_dataset_list : ', len(train_dataset_list))
        print('===> Epoch :', epoch)
        L = min(SUB_EPOCH_SIZE * BATCH_SIZE, len(train_dataset_list))
        if L <= 0:
            # Current list used up: start a new epoch, replaying hard cases first.
            if len(para['hard_cases_list']) > 0:
                print('===> Loading hard_cases_list...')
                train_dataset_list = hard_cases_list
                hard_cases_list = []
            else:
                print('===> Loading TXT...')
                train_dataset_list = txt_to_path_list(TRAIN_DATA_PATH)
            epoch = epoch + 1

            L = min(SUB_EPOCH_SIZE * BATCH_SIZE, len(train_dataset_list))
            if L <= 0:
                raise FileNotFoundError('Train_data_path.txt File not found')

        train_dataset_rest_list = train_dataset_list[L:]
        train_dataset_list = train_dataset_list[:L]
        print('===> train_dataset_list now : ', len(train_dataset_list))
        print('===> train_dataset_rest_list: ', len(train_dataset_rest_list))

        loader = DataLoader(dataset=CustomDataset(data_dir=TRAIN_DATA_DIR, file_path_list=train_dataset_list),
                            batch_size=BATCH_SIZE,
                            shuffle=DATA_SHUFFLE,
                            num_workers=NUM_WORKERS,
                            pin_memory=True)
        return loader, {'train_dataset_list': train_dataset_rest_list, 'epoch': epoch}
    finally:
        print('===> Get_train_dataset used time: ', time.perf_counter() - start_time)
141 |
142 |
def get_test_dataset():
    """Placeholder for a test-set loader; not implemented, returns None."""
    return None
145 |
146 |
def loading_model():
    """Return ``Net(resnet_level=24)`` on DEVICE, restored from MODEL_SAVE_PATH if present.

    Side effect: sets torch's default tensor type to CUDA float when a GPU
    is available, otherwise to CPU float. If the weight file is missing, a
    freshly initialised network is returned.
    """
    tensor_type = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
    torch.set_default_tensor_type(tensor_type)
    net = Net(resnet_level=24).to(DEVICE)
    print('===> Loading the saved model...')
    try:
        weights = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    except FileNotFoundError:
        print('===> Loading the saved model fail, create a new one...')
        return net
    net.load_state_dict(weights)
    return net
162 |
163 |
def train(sub_epoch):
    """Run one pass over the module-global ``train_dataset`` loader.

    Reads and updates the module globals ``model``, ``optimizer``,
    ``criterion``, ``lr``, ``batch_counter``, ``loss_list`` and
    ``hard_cases_list``.
    - The learning rate is halved every ``LR_HALF`` batches, counted across
      sub-epochs.
    - When a batch's loss exceeds the mean of the last (up to) 100 entries
      of ``loss_list``, all its sample paths go into ``hard_cases_list``.

    Returns:
        The mean batch loss over this sub-epoch.
    """
    global lr, batch_counter
    sub_epoch_loss = 0
    sub_epoch_start_time = time.perf_counter()
    for iteration, (data, label, path) in enumerate(train_dataset, 1):
        batch_counter = batch_counter + 1
        if batch_counter != 0 and batch_counter % LR_HALF == 0:
            lr = lr / 2
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('===> Optimizer update : ', optimizer.state_dict()['param_groups'][0])
        start_time = time.perf_counter()
        data, label = data.to(DEVICE), label.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(data), label)
        batch_loss = loss.item()
        sub_epoch_loss += batch_loss
        loss.backward()
        optimizer.step()

        if loss_list:
            # Average over the entries actually present, not a fixed 100.
            recent = loss_list[-100:]
            if batch_loss > sum(recent) / len(recent):
                # path is presumably (data_paths, label_paths). zip also
                # copes with a final batch smaller than BATCH_SIZE.
                hard_cases_list.extend([d, l] for d, l in zip(path[0], path[1]))
                print('hard_cases_list size: ', len(hard_cases_list))

        print("===> Sub_epoch[{}]({}/{}): Loss: {:.12f}".format(sub_epoch, iteration, len(train_dataset), batch_loss))
        print('No.{} batches'.format(batch_counter), 'Time used :', time.perf_counter() - start_time)

    print("===> Sub_epoch {} Complete: Avg. Loss: {:.12f}".format(sub_epoch, sub_epoch_loss / len(train_dataset)))
    print('{} Batches time used :'.format(len(train_dataset)), time.perf_counter() - sub_epoch_start_time)
    return sub_epoch_loss / len(train_dataset)
197 |
198 |
def test(model, show_images=False):
    """Evaluate ``model`` on the first (up to) 10 pairs listed in ``TEST_DATA_PATH``.

    Appends ``[PSNR_AVG, SSIM_AVG]`` to the global ``result_list``. When both
    averages beat the global ``best_result``, it updates that value and saves
    the weights to ``BEST_MODEL_SAVE_PATH``.

    Args:
        model: the network. It is put in eval mode for the evaluation, and
            its previous train/eval mode is restored afterwards.
        show_images: open each generated image in the system viewer. Off by
            default, because this runs after every training sub-epoch.
    """
    global best_result
    print('===> Testing the performance of model...')
    test_list = txt_to_path_list(TEST_DATA_PATH)
    n = min(10, len(test_list))
    if n == 0:
        print('===> No test pairs found in', TEST_DATA_PATH)
        return

    psnr_sum, ssim_sum = 0.0, 0.0
    img_list = []
    was_training = model.training
    model.eval()
    try:
        for pair in test_list[:n]:
            blks, h, w = padding_and_to_blks(bayer_rgb_img=Image.open(pair[0]).convert('RGB'),
                                             block_size=BATCH_BLOCK_SIZE)
            model_img = run_forward(model, DEVICE, blks, h, w, block_size=BATCH_BLOCK_SIZE)
            if show_images:
                img_list.append(model_img)
            psnr, ssim = compare(label=Image.open(pair[1]).convert('RGB'), Model_img=model_img)
            psnr_sum += psnr
            ssim_sum += ssim
    finally:
        model.train(was_training)

    psnr_avg = psnr_sum / n
    ssim_avg = ssim_sum / n
    result_list.append([psnr_avg, ssim_avg])
    print('PSNR_AVG :', psnr_avg, 'SSIM_AVG :', ssim_avg)
    if psnr_avg > best_result[0] and ssim_avg > best_result[1]:
        best_result = [psnr_avg, ssim_avg]
        print('*** Saving the best model...')
        tmp_path = os.path.join(SAVE_PATH, 'Best_Model_Temp.pkl')
        torch.save(model.state_dict(), tmp_path)
        # os.replace overwrites atomically and also works when no best model exists yet.
        os.replace(tmp_path, BEST_MODEL_SAVE_PATH)

    for img in img_list:
        img.show()
237 |
238 |
def check_point():
    """Persist the current weights and training state.

    Two files are written:
    - ``MODEL_SAVE_PATH``: the global ``model``'s state_dict;
    - ``PARA_SAVE_PATH``: the global ``para`` dict, merged with
      ``para_update`` (from get_train_dataset) and the current counters and
      histories.
    Each file goes to a temp name first and is then moved into place with
    ``os.replace``. That overwrites an existing target and also works on the
    first run, when no target exists yet.
    """
    model_tmp = os.path.join(SAVE_PATH, 'Model_Temp.pkl')
    torch.save(model.state_dict(), model_tmp)
    os.replace(model_tmp, MODEL_SAVE_PATH)

    para.update(para_update)
    para.update({
        'batch_counter': batch_counter,
        'lr': lr,
        'optimizer param_groups': optimizer.state_dict()['param_groups'][0],
        'hard_cases_list': hard_cases_list,
        'loss_list': loss_list,
        'result_list': result_list,
        'best_result': best_result
    })
    para_tmp = os.path.join(SAVE_PATH, 'Para_Temp.pkl')
    torch.save(para, para_tmp)
    os.replace(para_tmp, PARA_SAVE_PATH)
    print('Rest list: ', len(para['train_dataset_list']))
    print('Loss list: ', len(para['loss_list']))
    print('Hard_cases_list', len(para['hard_cases_list']))
263 |
264 |
265 |
266 |
# NOTE: this entry point is not wrapped in ``if __name__ == '__main__':``,
# so importing this module starts training.
print('Start training...')
init_para()
model = loading_model()
para = torch.load(PARA_SAVE_PATH)

# Restore the persisted training state into the module globals that
# train(), test() and check_point() read and update.
batch_counter = para['batch_counter']
lr = para['lr']
loss_list = para['loss_list']
result_list = para['result_list']
best_result = para['best_result']
hard_cases_list = para['hard_cases_list']

optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, amsgrad=True)
# Apply the saved hyper-parameters (lr, betas, eps, ...) to the live param
# group. 'params' is skipped because it stores parameter indices, not
# tensors. Adam's moment estimates are not persisted by check_point(), so
# they restart from zero.
for group in optimizer.param_groups:
    group.update({k: v for k, v in para['optimizer param_groups'].items() if k != 'params'})
print('===> Optimizer param_groups state: ', optimizer.state_dict()['param_groups'][0])
criterion = nn.MSELoss()

# Baseline score of the weights just loaded, before training resumes.
test(model)

for i in range(1, SUB_EPOCH + 1):
    print('===> Batch_counter : ', batch_counter)
    train_dataset, para_update = get_train_dataset()
    new_avg_loss = train(i)
    loss_list.append(new_avg_loss)
    check_point()
    test(model)
294 |
--------------------------------------------------------------------------------
/Train_control_pannel.py:
--------------------------------------------------------------------------------
1 | import os
2 | # from Train import *
3 | from PIL import Image
4 | import torch
5 | import tqdm
6 | import torch.utils.data as data
7 | import torch.nn as nn
8 | from torchvision import transforms
9 | import time
10 | import random
11 | from DataSet import CustomDataset
12 | from Model import Net
13 | from torch.utils.data import DataLoader
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | from Test_class import *
17 |
# Check whether a GPU is available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("cuda :", torch.cuda.is_available(), "GPU Num : ", torch.cuda.device_count())

BATCH_BLOCK_SIZE = 512

# *** Paths ***
SAVE_PATH = './Saved_Models/20190226Traned_Model'
MODEL_FILENAME = 'Model.pkl'
MODEL_SAVE_PATH = os.path.join(SAVE_PATH, MODEL_FILENAME)
BEST_MODEL_FILENAME = 'Best_Model.pkl'
BEST_MODEL_SAVE_PATH = os.path.join(SAVE_PATH, BEST_MODEL_FILENAME)
SSIM_BEST_MODEL_FILENAME = 'SSIM_Best_Model.pkl'
SSIM_BEST_MODEL_SAVE_PATH = os.path.join(SAVE_PATH, SSIM_BEST_MODEL_FILENAME)
PARA_FILENAME = 'Para.pkl'
PARA_SAVE_PATH = os.path.join(SAVE_PATH, PARA_FILENAME)
TRAIN_DATA_DIR = os.path.join(os.path.expanduser('~'), 'Dataset/RAISE_8K')
TRAIN_DATA_PATH = os.path.join(TRAIN_DATA_DIR, '8K_TRAIN_DATA/8K_TRAIN_DATA.txt')
TEST_DATA_PATH = "./TEST_DATA/TEST_DATA.txt"
# Output directory used by create_model_imgs() for generated TIFFs and
# Result.txt. './Final_test' is assumed because the repository ships
# Final_test/Result.txt — TODO confirm.
IMG_SAVE_PATH = './Final_test'
37 |
38 |
def txt_to_path_list(txt_path):
    """Read a list file with one whitespace-separated ``data_path label_path`` pair per line.

    Returns a list of per-line token lists (normally ``[data_path, label_path]``).
    Blank lines are skipped rather than yielding ``['']`` entries, and runs
    of spaces or tabs between tokens do not produce empty tokens.
    """
    with open(txt_path, 'r') as f:
        return [line.split() for line in f if line.strip()]
42 |
43 |
def loading_model(model_path=BEST_MODEL_SAVE_PATH):
    """Return ``Net(resnet_level=24)`` on DEVICE, with weights from ``model_path`` if that file exists.

    If the file is missing, a freshly initialised network is returned.
    """
    net = Net(resnet_level=24).to(DEVICE)
    print('===> Loading the saved model...')
    try:
        weights = torch.load(model_path, map_location=DEVICE)
    except FileNotFoundError:
        print('===> Loading the saved model fail, create a new one...')
        return net
    net.load_state_dict(weights)
    return net
55 |
56 |
def create_model_imgs(test_list, model_path=BEST_MODEL_SAVE_PATH, save_img=False, img_save_path=None):
    """Run a saved model over ``test_list`` pairs and log per-image metrics.

    Each ``[bayer_path, label_path]`` pair is scored with ``compare``.
    - The output file name, which embeds PSNR/SSIM, is appended to
      ``<img_save_path>/Result.txt``.
    - When ``save_img`` is true, the output and the label are also written
      there as TIFFs.
    The averages are appended to Result.txt at the end.

    Args:
        test_list: list of ``[bayer_path, label_path]`` pairs.
        model_path: weights to load via ``loading_model``.
        save_img: also save the output and label images.
        img_save_path: output directory; defaults to the module global
            ``IMG_SAVE_PATH``. Created if missing.
    """
    if img_save_path is None:
        img_save_path = IMG_SAVE_PATH
    n = len(test_list)
    if n == 0:
        print('===> Empty test_list, nothing to evaluate.')
        return
    os.makedirs(img_save_path, exist_ok=True)
    result_txt = os.path.join(img_save_path, 'Result.txt')

    test_model = loading_model(model_path=model_path)
    # A freshly built torch module starts in training mode; evaluate in
    # inference mode instead.
    test_model.eval()
    psnr_sum, ssim_sum = 0.0, 0.0
    for i, pair in enumerate(test_list):
        print('===> [{}/{}]'.format(i, n))
        blks, h, w = padding_and_to_blks(bayer_rgb_img=Image.open(pair[0]).convert('RGB'),
                                         block_size=BATCH_BLOCK_SIZE)
        model_img = run_forward(test_model, DEVICE, blks, h, w, block_size=BATCH_BLOCK_SIZE)
        # Open the label once and reuse it for scoring and saving.
        label = Image.open(pair[1]).convert('RGB')
        psnr, ssim = compare(label=label, Model_img=model_img)
        model_img_file_name = 'No.{}_PSNR={:.4f}_SSIM={:.4f}.TIF'.format(i + 1, psnr, ssim)
        real_img_file_name = 'No.{}_Real.TIF'.format(i + 1)
        with open(result_txt, 'a') as f:
            f.write(model_img_file_name + '\n')
        if save_img:
            model_img.save(os.path.join(img_save_path, model_img_file_name), format='TIFF')
            label.save(os.path.join(img_save_path, real_img_file_name), format='TIFF')
        psnr_sum += psnr
        ssim_sum += ssim

    psnr_avg = psnr_sum / n
    ssim_avg = ssim_sum / n
    print('PSNR_AVG :', psnr_avg, 'SSIM_AVG :', ssim_avg)
    with open(result_txt, 'a') as f:
        f.write('{} Pic:\nPSNR_AVG={:.12f}\nSSIM_AVG={:.12f}\n\n'.format(n, psnr_avg, ssim_avg))
86 |
87 |
def show_graph(show_len=200):
    """Plot the last ``show_len`` points of the training history.

    The length is capped by the lengths of the global ``loss_list`` and
    ``result_list``.
    - Bottom panel: per-sub-epoch average loss (``loss_list``).
    - Top-left / top-right: test PSNR / SSIM (``result_list``), each with
      the best value from the global ``best_result`` as a flat reference
      line.
    If there is nothing to plot, a message is printed and no figure is shown.
    """
    show_len = min(show_len, len(loss_list), len(result_list))
    if show_len <= 0:
        # loss_list[-0:] would return the whole list and break the plot.
        print('No history to plot.')
        return
    xs = range(show_len)
    recent_loss = loss_list[-show_len:]
    recent_results = result_list[-show_len:]

    # suptitle survives the subplot() calls below; plt.title would attach
    # to a throwaway full-figure axes.
    plt.suptitle('Result Analysis')

    plt.subplot(212)
    plt.plot(xs, recent_loss, color='green', label='Loss')
    plt.xlabel('iteration times')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(221)
    plt.plot(xs, [r[0] for r in recent_results], color='red', label='PSNR')
    plt.plot(xs, [best_result[0]] * show_len, color='black', label='PSNR_MAX')
    plt.xlabel('iteration times')
    plt.ylabel('PSNR')
    plt.legend()

    plt.subplot(222)
    plt.plot(xs, [best_result[1]] * show_len, color='black', label='SSIM_MAX')
    plt.plot(xs, [r[1] for r in recent_results], color='blue', label='SSIM')
    plt.legend()
    plt.xlabel('iteration times')
    plt.ylabel('SSIM')
    plt.show()
114 |
115 |
# Load the saved training state and expose its fields as module globals
# (show_graph() reads loss_list, result_list and best_result).
para = torch.load(PARA_SAVE_PATH)
batch_counter, lr = para['batch_counter'], para['lr']
loss_list, result_list = para['loss_list'], para['result_list']
best_result, hard_cases_list = para['best_result'], para['hard_cases_list']

# ---- One-off maintenance snippets: uncomment a group to apply it. ----
# Override the stored learning rate
# (note: optimizer.state_dict().update(...) only edits a temporary copy):
# new_lr = 1e-7
# optimizer = torch.optim.Adam(loading_model().parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-08, amsgrad=True)
# optimizer.state_dict().update(para['optimizer param_groups'])
# para.update({'lr': new_lr})
# para.update({'optimizer param_groups': optimizer.state_dict()['param_groups'][0]})
# torch.save(para, PARA_SAVE_PATH)

# Clear the pending training list and the hard cases:
# best_result=[32.4,0.93]
# para.update({'train_dataset_list': []})
# para.update({'hard_cases_list': []})
# torch.save(para, PARA_SAVE_PATH)

# Overwrite the recorded best result:
# best_result=[32.4,0.93]
# para.update({'best_result': best_result})
# torch.save(para, PARA_SAVE_PATH)

for caption, value in (('batch_counter', batch_counter),
                       ('lr:', lr),
                       ('loss_list', loss_list[-3:]),
                       ('result_list', result_list[-3:]),
                       ('best_result', best_result)):
    print(caption, value)
# print('hard_cases_list', hard_cases_list)
show_graph(show_len=350)

# Evaluate saved models on the test list (results go to IMG_SAVE_PATH):
# for i in range(42):
#     create_model_imgs(test_list=txt_to_path_list(TEST_DATA_PATH)[i * 50:(i + 1) * 50], model_path=BEST_MODEL_SAVE_PATH)
# create_model_imgs(test_list=txt_to_path_list(TEST_DATA_PATH)[:50], model_path=MODEL_SAVE_PATH)
# create_model_imgs(test_list=txt_to_path_list(TEST_DATA_PATH)[:50], model_path=SSIM_BEST_MODEL_SAVE_PATH)
# create_model_imgs(test_list=txt_to_path_list(FAST_TEST_PATH)[:1000], model_path=MODEL_SAVE_PATH)
# print(os.path.join(IMG_SAVE_PATH, 'No.{}_PSNR={:.4f}_SSIM={:.4f}.TIF'.format(1, 40, 1)))
153 |
--------------------------------------------------------------------------------
/__pycache__/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/__pycache__/.DS_Store
--------------------------------------------------------------------------------
/__pycache__/DataSet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/__pycache__/DataSet.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/Model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/__pycache__/Model.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/Test_class.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/__pycache__/Test_class.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/Train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/__pycache__/Train.cpython-37.pyc
--------------------------------------------------------------------------------
/img_to_imgblk.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import math
5 | import numpy as np
6 | from PIL import Image
7 | import time
8 |
# Root folder scanned (recursively) for source .TIF images.
# Alternative source used previously: "/Volumes/750GB/RAISE".
SOURCE_PATH = './8K_CROSS_DATA'

# Side length, in pixels, of every ground-truth (label) block.
BLOCK_SIZE = 128

# Output folders for the generated train / cross-validation / test pairs.
TRAIN_DATA_SAVE_PATH, CROSS_DATA_SAVE_PATH, TEST_DATA_SAVE_PATH = (
    "./8K_TRAIN_DATA",
    "./8K_CROSS_DATA",
    "./8K_TEST_DATA",
)
14 |
15 |
def DownSize(img, times=3, factor=1.25):
    """Shrink ``img`` by ``factor`` per axis, ``times`` times in a row.

    With the defaults, the result is about 1 / 1.25**3 (~0.512) of the original
    width and height. Each step truncates to whole pixels.

    Args:
        img: PIL image to shrink.
        times: number of successive resize steps (default 3).
        factor: per-step divisor for width and height (default 1.25).

    Returns:
        The resized PIL image; ``img`` itself is not modified.
    """
    for _ in range(times):
        width, height = img.size
        # Never request a zero-sized image; Pillow rejects that.
        new_size = (max(1, int(width / factor)), max(1, int(height / factor)))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        img = img.resize(new_size, Image.LANCZOS)
    return img
22 |
23 |
def To_Bayer(img):
    """Halve ``img`` and sample it through an RGGB Bayer colour-filter mosaic.

    Colour-filter layout by (row, column) parity::

        R G R G
        G B G B
        R G R G
        G B G B

    Each pixel keeps only the channel its filter passes; the other two of the
    first three channels are set to 0. The result therefore stays a
    three-channel ("3-channel Bayer") image rather than a single-plane mosaic.

    Args:
        img: PIL image, expected to be RGB (its array form must have at least
            three channels along the last axis).

    Returns:
        A new PIL image with half the width and height of ``img`` (truncated).
    """
    width, height = img.size
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img = img.resize((int(width / 2), int(height / 2)), Image.LANCZOS)
    data = np.array(img)  # writable copy, indexed [row, col, channel]

    # Vectorised form of the per-pixel loop: zero every blocked channel.
    data[0::2, 0::2, 1:3] = 0    # even row, even col -> keep R (drop G, B)
    data[0::2, 1::2, 0:3:2] = 0  # even row, odd col  -> keep G (drop R, B)
    data[1::2, 0::2, 0:3:2] = 0  # odd row,  even col -> keep G (drop R, B)
    data[1::2, 1::2, 0:2] = 0    # odd row,  odd col  -> keep B (drop R, G)

    return Image.fromarray(data)
67 |
68 |
# Listing files under a folder with os.walk, see:
#   https://blog.csdn.net/LZGS_4/article/details/50371030
#   https://blog.csdn.net/lsq2902101015/article/details/51305825

def file_name(path, ext='.TIF', verbose=True):
    """Recursively collect the paths of all ``ext`` files below ``path``.

    macOS AppleDouble sidecars (``._name.TIF``) are skipped: they hold
    metadata, not pixels. Mapping them back to ``name.TIF`` could list a real
    image twice (or list a path that does not exist), and duplicates then leak
    across the train/test split.

    Args:
        path: root directory to search; a missing directory yields ``[]``.
        ext: extension to match exactly (case-sensitive, including the dot).
        verbose: print every visited directory and its entries.

    Returns:
        List of matching file paths in ``os.walk`` order.
    """
    found = []
    for root, dirs, files in os.walk(path):
        if verbose:
            print("root", root)
            print("dirs", dirs)
            print("files", type(files), files)
        for file in files:
            if file.startswith('._'):
                continue
            if os.path.splitext(file)[1] == ext:
                found.append(os.path.join(root, file))
    return found
88 |
89 |
img_list = file_name(SOURCE_PATH)

# 90 % / 10 % train/test split. Both halves share ONE cut index: using ceil()
# for the train end and floor() for the test start put one image in both sets
# whenever 0.9 * len(img_list) was not an integer.
img_list_size = len(img_list)
split_index = math.ceil(img_list_size * 0.9)
train_data = img_list[:split_index]
test_data = img_list[split_index:]

print('train data size:', len(train_data))
print('test data size:', len(test_data))
98 |
99 |
def get_name(path):
    """Return the file name of ``path`` without its directory or extension.

    Used to name the per-image output folder. The previous fixed slice
    ``path[len - 14:len - 4]`` was only correct for 10-character stems with a
    4-character extension; other lengths yielded a truncated stem, pieces of
    the directory name, or an empty string.

    Args:
        path: file path string.

    Returns:
        The base name with its last extension removed.
    """
    stem, _ext = os.path.splitext(os.path.basename(path))
    return stem
104 |
105 |
def mkdir(path):
    """Create directory ``path`` (including missing parents) unless it exists.

    ``exist_ok=True`` replaces a separate exists() check, which could race with
    another process creating the same folder between the check and the create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
110 |
111 |
def train_process(img_path, save_path):
    """Cut source images into BLOCK_SIZE label blocks plus Bayer input blocks.

    Each path in ``img_path`` is opened, shrunk with ``DownSize`` and tiled
    into non-overlapping BLOCK_SIZE x BLOCK_SIZE crops; right/bottom remainders
    smaller than a block are dropped. For crop number ``k`` (row by row) it saves:

    * ``<save_path>/<name>/<name>_<k>label.TIF``: the RGB crop;
    * ``<save_path>/<name>/<name>_<k>data.TIF``: ``To_Bayer(crop)``.

    Here ``name = get_name(path)``. Each "data label" path pair is appended as
    one line to the index file ``<save_path>/<folder name>.txt``.

    Args:
        img_path: sequence of source image paths (``len()`` is used for ETA).
        save_path: output root directory; created if missing.
    """
    mkdir(save_path)
    # Index file named after the output folder, e.g.
    # ./8K_TRAIN_DATA/8K_TRAIN_DATA.txt. Unlike save_path + '/' + save_path,
    # this also works for absolute or nested save paths.
    index_path = os.path.join(
        save_path, os.path.basename(os.path.normpath(save_path)) + '.txt')
    time_start = time.perf_counter()
    with open(index_path, 'a') as txt:
        for No, i in enumerate(img_path):
            print(i)
            img = DownSize(Image.open(i))
            w, h = img.size
            rows = h // BLOCK_SIZE
            cols = w // BLOCK_SIZE
            print("No.", No, "WxH=", w, 'x', h)

            name = get_name(i)
            block_dir = save_path + '/' + name
            mkdir(block_dir)  # once per image, not once per block

            img_ID = 0
            for r in range(rows):
                for c in range(cols):
                    # PIL crop boxes are (left, upper, right, lower): the column
                    # index drives x and the row index drives y.
                    box = (c * BLOCK_SIZE, r * BLOCK_SIZE,
                           (c + 1) * BLOCK_SIZE, (r + 1) * BLOCK_SIZE)
                    temp = img.crop(box).convert('RGB')
                    temp_bayer = To_Bayer(temp)

                    # Record the data/label paths in the index, then save both.
                    data_str = block_dir + '/' + name + '_' + str(img_ID) + 'data.TIF'
                    label_str = block_dir + '/' + name + '_' + str(img_ID) + 'label.TIF'
                    txt.write(data_str + ' ' + label_str + '\n')

                    temp.save(label_str, 'TIFF')
                    temp_bayer.save(data_str, 'TIFF')
                    img_ID += 1

            time_used = time.perf_counter() - time_start
            print('Time used:', time_used)
            print('Time remain:', time_used / (No + 1) * len(img_path) - time_used)
151 |
152 |
# 1-based image positions (``No + 1``) for the optional subset filter that is
# commented out inside test_process(); grouped by tens for readability.
L = [
    2, 4, 5, 7, 9,
    10, 12, 15, 17, 18, 19,
    20, 21, 25, 29,
    30, 32, 33, 34, 36, 37, 38,
    42, 43, 45, 48,
    52, 53, 54, 56,
    60, 61, 62, 65, 66, 67,
    70, 71, 72, 75, 77, 79,
    81, 83, 86, 87,
    90, 91, 96, 99,
]
155 |
156 |
def test_process(img_path, save_path, img_ID=870):
    """Convert whole images into (Bayer data, RGB label) test pairs.

    Each path in ``img_path`` is opened, shrunk once with ``DownSize`` and
    converted to RGB. It is saved as ``<save_path>/<id>label.TIF``, and its
    ``To_Bayer`` version as ``<save_path>/<id>data.TIF``. Ids start at
    ``img_ID`` and increase by one per image. Each "data label" path pair is
    appended to the index file ``<save_path>/<folder name>.txt``.

    Args:
        img_path: sequence of source image paths (``len()`` is used for ETA).
        save_path: output directory; created if missing.
        img_ID: id assigned to the first image.
    """
    mkdir(save_path)
    # Index file named after the output folder, e.g.
    # ./8K_TEST_DATA/8K_TEST_DATA.txt. Unlike save_path + '/' + save_path,
    # this also works for absolute or nested save paths.
    index_path = os.path.join(
        save_path, os.path.basename(os.path.normpath(save_path)) + '.txt')
    time_start = time.perf_counter()
    with open(index_path, 'a') as txt:
        for No, i in enumerate(img_path):
            # To process only a subset, uncomment: if No + 1 not in L: continue
            print("No.", No)

            data_str = save_path + '/' + str(img_ID) + 'data.TIF'
            label_str = save_path + '/' + str(img_ID) + 'label.TIF'

            # To_Bayer indexes three colour channels; convert like
            # train_process does for its crops.
            temp = DownSize(Image.open(i)).convert('RGB')
            temp.save(label_str, 'TIFF')
            To_Bayer(temp).save(data_str, 'TIFF')
            # Index the pair only after both files were written successfully.
            txt.write(data_str + ' ' + label_str + '\n')

            img_ID += 1
            time_used = time.perf_counter() - time_start
            print('Time used:', time_used)
            print('Time remain:', time_used / (No + 1) * len(img_path) - time_used)
184 |
185 |
def resave(img_list, save_path):
    """Re-save every image in ``img_list`` as an RGB TIFF inside ``save_path``.

    Each file keeps its base name; only the directory changes. Elapsed time and
    an estimated remaining time are printed after each image.

    Args:
        img_list: sequence of image paths.
        save_path: destination directory; created if missing.
    """
    mkdir(save_path)
    time_start = time.perf_counter()
    for i, f in enumerate(img_list, 1):
        target = os.path.join(save_path, os.path.basename(f))
        # Close each source promptly: Pillow's TIFF reader may keep the file
        # handle open after loading, leaking descriptors over long lists.
        with Image.open(f) as source_img:
            source_img.convert('RGB').save(target, format='TIFF')
        time_used = time.perf_counter() - time_start
        print('===> No.{} Pic used:{}'.format(i, time_used))
        print('Time remain:', time_used / i * len(img_list) - time_used)
193 |
194 |
# Job selection: exactly one call below is active; the others are kept as
# ready-to-use alternatives.
#   train_process(train_data[:1], TRAIN_DATA_SAVE_PATH)    # quick check, one training image
#   test_process(test_data[:1], TEST_DATA_SAVE_PATH)       # quick check, one test image
#   train_process(img_list[:6000], TRAIN_DATA_SAVE_PATH)
#   test_process(img_list[:], CROSS_DATA_SAVE_PATH)
#   resave(img_list=img_list[6000:], save_path=CROSS_DATA_SAVE_PATH)
test_process(img_path=img_list[:], save_path=TEST_DATA_SAVE_PATH, img_ID=0)
203 |
--------------------------------------------------------------------------------
/测试图片名单.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linwei-Chen/Deep-Residual-Network-for-JointDemosaicing-and-Super-Resolution/a224f0ea673d70c26ee17aec9f27e1a7c31cbe8e/测试图片名单.xlsx
--------------------------------------------------------------------------------