├── imgs
│   ├── readme
│   ├── KD.png
│   ├── real_example.png
│   ├── distillationIQA.png
│   └── synthetic_example.png
├── model_zoo
│   └── readme.txt
├── tools.py
├── models
│   ├── CNNIQA.py
│   ├── DCNN_NARIQA.py
│   ├── LinearityIQA.py
│   ├── WaDIQaM.py
│   ├── TRIQ.py
│   ├── HyperIQA.py
│   ├── DistillationIQA.py
│   └── IQT.py
├── LICENSE
├── dataloaders
│   ├── dataloader_LQ_HQ.py
│   ├── dataloader_LQ.py
│   └── dataloader_LQ_HQ_diff_content_HQ.py
├── test_DistillationIQA_single.py
├── README.md
├── option_train_DistillationIQA_FR.py
├── option_train_DistillationIQA.py
├── test_DistillationIQA.py
├── train_DistillationIQA_FR.py
├── train_DistillationIQA.py
└── folders
    ├── folders_LQ_HQ.py
    ├── folders_LQ.py
    └── folders_LQ_HQ_diff_content_HQ.py
/imgs/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model_zoo/readme.txt:
--------------------------------------------------------------------------------
1 | Put trained models here!
2 |
--------------------------------------------------------------------------------
/imgs/KD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/KD.png
--------------------------------------------------------------------------------
/imgs/real_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/real_example.png
--------------------------------------------------------------------------------
/imgs/distillationIQA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/distillationIQA.png
--------------------------------------------------------------------------------
/imgs/synthetic_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/synthetic_example.png
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import curve_fit
3 |
4 | def convert_obj_score(ori_obj_score, MOS):
5 | """
6 | func:
7 |         fitting the objective score to the MOS scale.
8 | nonlinear regression fit
9 | """
10 | def logistic_fun(x, a, b, c, d):
11 | return (a - b)/(1 + np.exp(-(x - c)/abs(d))) + b
12 |     # nonlinear fit to obtain MOSp
13 | param_init = [np.max(MOS), np.min(MOS), np.mean(ori_obj_score), 1]
14 | popt, pcov = curve_fit(logistic_fun, ori_obj_score, MOS,
15 | p0 = param_init, ftol =1e-8, maxfev=500000)
16 | #a, b, c, d = popt[0], popt[1], popt[2], popt[3]
17 |
18 | obj_fit_score = logistic_fun(ori_obj_score, popt[0], popt[1], popt[2], popt[3])
19 |
20 | return obj_fit_score
21 |
--------------------------------------------------------------------------------
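A usage sketch for convert_obj_score (not part of the repository; the score arrays below are made-up placeholders): the raw objective predictions are mapped onto the MOS scale through the 4-parameter logistic above, and the fitted PLCC is then computed on the mapped scores, mirroring how the test scripts use this helper.

import numpy as np
from scipy import stats
from tools import convert_obj_score   # assumes the repository root is on PYTHONPATH

# made-up placeholder scores; in practice these come from a test loop
pred_scores = np.array([0.31, 0.58, 0.12, 0.77, 0.45])   # raw objective predictions
mos_scores  = np.array([2.1, 3.9, 1.4, 4.6, 3.0])        # ground-truth MOS

raw_plcc, _ = stats.pearsonr(pred_scores, mos_scores)     # PLCC without nonlinear fitting
fitted_pred = convert_obj_score(pred_scores, mos_scores)  # logistic mapping to the MOS scale
fitted_plcc, _ = stats.pearsonr(fitted_pred, mos_scores)  # PLCC after fitting
print(raw_plcc, fitted_plcc)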
/models/CNNIQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 | class CNNIQAnet(nn.Module):
7 | def __init__(self):
8 | super(CNNIQAnet, self).__init__()
9 |
10 | self.conv = nn.Conv2d(in_channels=1, out_channels=50, kernel_size=7)
11 | self.fc1 = nn.Linear(100, 800)
12 | self.fc2 = nn.Linear(800, 800)
13 | self.fc3 = nn.Linear(800, 1)
14 |
15 | def forward(self, input):
16 | x = input.view(-1, input.size(-3), input.size(-2), input.size(-1))
17 |
18 | x = self.conv(x)
19 |
20 | x1 = F.max_pool2d(x, (x.size(-2), x.size(-1)))
21 | x2 = -F.max_pool2d(-x, (x.size(-2), x.size(-1)))
22 |
23 | h = torch.cat((x1, x2), 1)
24 | h = h.squeeze(3).squeeze(2)
25 |
26 | h = F.relu(self.fc1(h))
27 |         h = F.dropout(h, training=self.training)  # F.dropout defaults to training=True, so pass the module mode explicitly
28 | h = F.relu(self.fc2(h))
29 |
30 | q = self.fc3(h)
31 |
32 | return q
33 |
--------------------------------------------------------------------------------
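A minimal forward-pass sketch (not part of the repository). CNNIQAnet follows the original CNNIQA design: it takes single-channel (grayscale) patches, and the concatenated max/min pooled responses of the 50 conv filters form the 100-dim vector expected by fc1.

import torch
from models.CNNIQA import CNNIQAnet   # assumes the repository root is on PYTHONPATH

net = CNNIQAnet()
net.train(False)
patches = torch.rand(8, 1, 32, 32)    # 8 grayscale patches; 32x32 is the patch size used by CNNIQA
with torch.no_grad():
    q = net(patches)                  # one quality score per patch
print(q.shape)                        # torch.Size([8, 1])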
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 guanghaoyin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/dataloaders/dataloader_LQ_HQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import folders.folders_LQ_HQ as folders
3 |
4 | class DataLoader(object):
5 | """Dataset class for IQA databases"""
6 |
7 | def __init__(self, dataset, path, img_indx, patch_size, patch_num, batch_size=1, istrain=True, self_patch_num=10):
8 |
9 | self.batch_size = batch_size
10 | self.istrain = istrain
11 |
12 | if dataset == 'live':
13 | self.data = folders.LIVEFolder(
14 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
15 | elif dataset == 'csiq':
16 | self.data = folders.CSIQFolder(
17 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
18 | elif dataset == 'kadid10k':
19 | self.data = folders.Kadid10kFolder(
20 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
21 | elif dataset == 'tid2013':
22 | self.data = folders.TID2013Folder(
23 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
24 |
25 | def get_dataloader(self):
26 | if self.istrain:
27 | dataloader = torch.utils.data.DataLoader(
28 | self.data, batch_size=self.batch_size, shuffle=True)
29 | else:
30 | dataloader = torch.utils.data.DataLoader(
31 | self.data, batch_size=1, shuffle=False)
32 | return dataloader
33 |
--------------------------------------------------------------------------------
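A construction sketch (not part of the repository). The dataset name, paths and indices below are assumptions mirroring the img_num/folder_path dictionaries in train_DistillationIQA_FR.py; the exact fields of each batch are defined in folders/folders_LQ_HQ.py, which is not reproduced here.

from dataloaders.dataloader_LQ_HQ import DataLoader

train_index = list(range(0, 10125))   # kadid10k indices, as in train_DistillationIQA_FR.py
loader = DataLoader('kadid10k', './dataset/kadid10k/', train_index,
                    patch_size=224, patch_num=1, batch_size=32,
                    istrain=True, self_patch_num=10)
train_data = loader.get_dataloader()  # torch DataLoader over folders_LQ_HQ.Kadid10kFolder
for batch in train_data:
    # each batch bundles LQ patches, aligned HQ patches and labels as defined by the folder class
    break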
/models/DCNN_NARIQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 | class DCNN_NARIQA(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 | #ref path
10 | self.block1_ref = nn.Sequential(
11 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=4),
12 | nn.ReLU(),
13 | nn.MaxPool2d(kernel_size=5, stride=1))
14 | self.block2_ref = nn.Sequential(
15 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=9, stride=1),
16 | nn.ReLU())
17 | self.fc3_ref = nn.Linear(in_features=59168, out_features=1024)
18 |
19 | #LQ path
20 | self.block1_lq = nn.Sequential(
21 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=4),
22 | nn.ReLU(),
23 | nn.MaxPool2d(kernel_size=5, stride=1))
24 | self.block2_lq = nn.Sequential(
25 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=9, stride=1),
26 | nn.ReLU())
27 | self.fc3_lq = nn.Linear(in_features=59168, out_features=1024)
28 |
29 | self.fc4 = nn.Linear(in_features=2048, out_features=1024)
30 | self.fc5 = nn.Linear(in_features=1024, out_features=1)
31 |
32 | def forward(self, lq_patches, ref_patches):
33 | feature_lq = self.block1_lq(lq_patches)
34 | feature_lq = self.block2_lq(feature_lq)
35 | feature_lq = self.fc3_lq(feature_lq.view(feature_lq.size(0), -1))
36 |
37 | feature_ref = self.block1_ref(ref_patches)
38 | feature_ref = self.block2_ref(feature_ref)
39 | feature_ref = self.fc3_ref(feature_ref.view(feature_ref.size(0), -1))
40 |
41 | concat_feature = torch.cat((feature_ref, feature_lq), 1)
42 | concat_feature = self.fc4(concat_feature)
43 | pred = self.fc5(concat_feature)
44 | return pred
45 |
46 | if __name__ == "__main__":
47 | x = torch.rand((1,3,224,224))
48 | y = torch.rand((1,3,224,224))
49 | net = DCNN_NARIQA()
50 | pred = net(x, y)
51 |
--------------------------------------------------------------------------------
/test_DistillationIQA_single.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from option_train_DistillationIQA import set_args, check_args
4 | import numpy as np
5 | from models.DistillationIQA import DistillationIQANet
6 | from PIL import Image
7 | import torchvision
8 |
9 | img_num = {
10 | 'kadid10k': list(range(0,10125)),
11 | 'live': list(range(0, 29)),#ref HR image
12 | 'csiq': list(range(0, 30)),#ref HR image
13 | 'tid2013': list(range(0, 25)),
14 | 'livec': list(range(0, 1162)),# no-ref image
15 | 'koniq-10k': list(range(0, 10073)),# no-ref image
16 | 'bid': list(range(0, 586)),# no-ref image
17 | }
18 | folder_path = {
19 | 'pipal':'./dataset/PIPAL',
20 | 'live': './dataset/LIVE/',
21 | 'csiq': './dataset/CSIQ/',
22 | 'tid2013': './dataset/TID2013/',
23 | 'livec': './dataset/LIVEC/',
24 | 'koniq-10k': './dataset/koniq-10k/',
25 | 'bid': './dataset/BID/',
26 | 'kadid10k':'./dataset/kadid10k/'
27 | }
28 |
29 |
30 | class DistillationIQASolver(object):
31 | def __init__(self, config, lq_path, ref_path):
32 | self.config = config
33 | self.config.teacherNet_model_path = './model_zoo/FR_teacher_cross_dataset.pth'
34 | self.config.studentNet_model_path = './model_zoo/NAR_student_cross_dataset.pth'
35 |
36 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu')
37 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt')
38 | with open(self.txt_log_path,"w+") as f:
39 | f.close()
40 |
41 | #model
42 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
43 | if config.teacherNet_model_path:
44 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path))
45 | self.teacherNet = self.teacherNet.to(self.device)
46 | self.teacherNet.train(False)
47 | self.studentNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
48 | if config.studentNet_model_path:
49 | self.studentNet._load_state_dict(torch.load(config.studentNet_model_path))
50 | self.studentNet = self.studentNet.to(self.device)
51 | self.studentNet.train(True)
52 |
53 | self.transform = torchvision.transforms.Compose([
54 | torchvision.transforms.RandomCrop(size=self.config.patch_size),
55 | torchvision.transforms.ToTensor(),
56 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
57 | std=(0.229, 0.224, 0.225))
58 | ])
59 | #data
60 | self.LQ_patches = self.preprocess(lq_path)
61 | self.ref_patches = self.preprocess(ref_path)
62 |
63 | def preprocess(self, path):
64 | with open(path, 'rb') as f:
65 | img = Image.open(f)
66 | img= img.convert('RGB')
67 | patches = []
68 | for _ in range(self.config.self_patch_num):
69 | patch = self.transform(img)
70 | patches.append(patch.unsqueeze(0))
71 | patches = torch.cat(patches, 0)
72 | return patches.unsqueeze(0)
73 |
74 | def test(self):
75 | self.studentNet.train(False)
76 | LQ_patches, ref_patches = self.LQ_patches.to(self.device), self.ref_patches.to(self.device)
77 | with torch.no_grad():
78 | _, _, pred = self.studentNet(LQ_patches, ref_patches)
79 | return float(pred.item())
80 |
81 | if __name__ == "__main__":
82 | config = set_args()
83 | config = check_args(config)
84 |
85 | lq_path = './dataset/koniq-10k/1024x768/28311109.jpg'
86 | ref_path = './dataset/DIV2K_ref/val_HR/0801.png'
87 | label = 1.15686274509804
88 | solver = DistillationIQASolver(config=config, lq_path=lq_path, ref_path=ref_path)
89 | scores = []
90 | for _ in range(10):
91 | scores.append(solver.test())
92 | print(np.mean(scores))
93 | # result 1.2577123641967773
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CVRKD-IQA (DistillationIQA)
2 | This repository is for CVRKD-IQA, introduced in the following paper:
3 |
4 | Guanghao Yin, Wei Wang, Zehuan Yuan, Chuchu Han, Wei Ji, Shouqian Sun and Changhu Wang, "Content-Variant Reference Image Quality Assessment via Knowledge Distillation", AAAI Oral, 2022 [arxiv](https://arxiv.org/abs/2202.13123)
5 |
6 | ## Introduction
7 | CVRKD-IQA is the first content-variant reference IQA method based on knowledge distillation. The practicability of previous FR-IQA methods is limited by their requirement for pixel-level aligned reference images, while NR-IQA methods still leave room for better performance because HQ image information is not fully exploited. Hence, we use non-aligned reference (NAR) images to introduce various prior distributions of high-quality images. Moreover, comparing the distribution differences between HQ and LQ images helps our model better assess image quality. Further, knowledge distillation transfers more HQ-LQ distribution difference information from the FR-teacher to the NAR-student and stabilizes CVRKD-IQA performance. Since content-variant, non-aligned reference HQ images are easy to obtain, our model can support more IQA applications thanks to its relative robustness to content variations.
8 |
9 | 
10 |
11 | ## Prepare data
12 | ### Training datasets
13 | Download the synthetic [KADID-10k](http://database.mmsp-kn.de/kadid-10k-database.html) dataset, and download the training HQ images of [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) as the reference dataset.
14 |
15 | ### Testing datasets
16 | Download the synthetic [LIVE](http://live.ece.utexas.edu/index.php), [CSIQ](https://qualinet.github.io/databases/image/categorical_image_quality_csiq_database/), [TID2013](http://www.ponomarenko.info/tid2013.htm) and authentic [KonIQ-10k](http://database.mmsp-kn.de/koniq-10k-database.html) datasets, and download the testing HQ images of [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) as the reference dataset.
17 |
18 | Place the unzipped data in the `./dataset` folder.
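The scripts expect a layout roughly like the one below (a sketch inferred from the `folder_path` and `DIV2K_ref` paths hard-coded in the option and test scripts; adjust those paths if your layout differs):
```
dataset/
├── kadid10k/
├── LIVE/
├── CSIQ/
├── TID2013/
├── koniq-10k/
└── DIV2K_ref/
    ├── train_HR/
    └── val_HR/
```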
19 | ## Train
20 | ### 1. Train the FR-teacher
21 | (1) (Optional) Download the models for our paper and place them in './model_zoo/'.
22 | The models for FR-teacher can be downloaded from [Google Cloud](https://drive.google.com/file/d/1niFBV-ysJeVoaXUPQp-08ovrjS9tPhGW/view?usp=sharing)
23 |
24 | (2) Quick start (you can change the options in option_train_DistillationIQA_FR.py)
25 | ```
26 | python train_DistillationIQA_FR.py --self_patch_num 10 --patch_size 224
27 | ```
28 | ### 2. Fix the pretrained FR-teacher and train the NAR-student
29 | (1) (Optional) Download the models for our paper and place them in './model_zoo/'.
30 | The models for FR-teacher and NAR-student can be downloaded from [Google Cloud](https://drive.google.com/file/d/107TI1pa0TDxs3V8tO2KhmhKJmfc9ZOl4/view?usp=sharing)
31 |
32 | (2) Quick start (you can change the options in option_train_DistillationIQA.py)
33 | ```
34 | python train_DistillationIQA.py --self_patch_num 10 --patch_size 224
35 | ```
36 |
37 | ## Test
38 | (1) Make sure the trained models are placed at './model_zoo/FR_teacher_cross_dataset.pth' and './model_zoo/NAR_student_cross_dataset.pth'.
39 |
40 | (2) Quick start (you can change the options in option_train_DistillationIQA.py)
41 | ```
42 | python test_DistillationIQA.py
43 | ```
44 | (3) Test a single image
45 | ```
46 | python test_DistillationIQA_single.py
47 | ```
48 | ## More visual results
49 | Synthetic examples of IQA scores predicted by our NAR-student when gradually increasing the distortion levels.
50 | 
51 |
52 | Real-data examples of IQA scores predicted by our NAR-student.
53 | 
54 |
55 | ## t-SNE visualization of the HQ-LQ difference-aware features of the NAR-student w/ and w/o KD
56 | After being distilled by the FR-teacher, the HQ-LQ features in Fig. (b) are clustered. This shows that the HQ-LQ distribution difference prior from the FR-teacher can indeed
57 | help the NAR-student extract quality-sensitive, discriminative features for more accurate and consistent performance.
58 |
59 | 
60 |
61 | ## Citation
62 |
63 | If you find the code helpful in your research or work, please cite the following paper.
64 |
65 | ```
66 | @article{yin2022content,
67 | title={Content-Variant Reference Image Quality Assessment via Knowledge Distillation},
68 | author={Yin, Guanghao and Wang, Wei and Yuan, Zehuan and Han, Chuchu and Ji, Wei and Sun, Shouqian and Wang, Changhu},
69 | journal={arXiv preprint arXiv:2202.13123},
70 | year={2022}
71 | }
72 | ```
73 | ## Acknowledgements
74 | Part of our code is built on [HyperIQA](https://github.com/SSL92/hyperIQA). We thank the authors for sharing their code. We also thank [ByteDance Inc.](https://github.com/bytedance) for their support.
75 |
--------------------------------------------------------------------------------
/models/LinearityIQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | import numpy as np
6 |
7 |
8 | def SPSP(x, P=1, method='avg'):
9 | batch_size = x.size(0)
10 | map_size = x.size()[-2:]
11 | pool_features = []
12 | for p in range(1, P+1):
13 |         pool_size = [int(d / p) for d in map_size]  # np.int was removed in NumPy 1.24; use the builtin int
14 | if method == 'maxmin':
15 | M = F.max_pool2d(x, pool_size)
16 | m = -F.max_pool2d(-x, pool_size)
17 | pool_features.append(torch.cat((M, m), 1).view(batch_size, -1)) # max & min pooling
18 | elif method == 'max':
19 | M = F.max_pool2d(x, pool_size)
20 | pool_features.append(M.view(batch_size, -1)) # max pooling
21 | elif method == 'min':
22 | m = -F.max_pool2d(-x, pool_size)
23 | pool_features.append(m.view(batch_size, -1)) # min pooling
24 | elif method == 'avg':
25 | a = F.avg_pool2d(x, pool_size)
26 | pool_features.append(a.view(batch_size, -1)) # average pooling
27 | else:
28 | m1 = F.avg_pool2d(x, pool_size)
29 | rm2 = torch.sqrt(F.relu(F.avg_pool2d(torch.pow(x, 2), pool_size) - torch.pow(m1, 2)))
30 | if method == 'std':
31 | pool_features.append(rm2.view(batch_size, -1)) # std pooling
32 | else:
33 | pool_features.append(torch.cat((m1, rm2), 1).view(batch_size, -1)) # statistical pooling: mean & std
34 | return torch.cat(pool_features, dim=1)
35 |
36 |
37 | class LinearityIQA(nn.Module):
38 | def __init__(self, arch='resnext101_32x8d', pool='avg', use_bn_end=False, P6=1, P7=1):
39 | super(LinearityIQA, self).__init__()
40 | self.pool = pool
41 | self.use_bn_end = use_bn_end
42 | if pool in ['max', 'min', 'avg', 'std']:
43 | c = 1
44 | else:
45 | c = 2
46 | self.P6 = P6 #
47 | self.P7 = P7 #
48 | features = list(models.__dict__[arch](pretrained=True).children())[:-2]
49 | if arch == 'alexnet':
50 | in_features = [256, 256]
51 | self.id1 = 9
52 | self.id2 = 12
53 | features = features[0]
54 | elif arch == 'vgg16':
55 | in_features = [512, 512]
56 | self.id1 = 23
57 | self.id2 = 30
58 | features = features[0]
59 | elif 'res' in arch:
60 | self.id1 = 6
61 | self.id2 = 7
62 | if arch == 'resnet18' or arch == 'resnet34':
63 | in_features = [256, 512]
64 | else:
65 | in_features = [1024, 2048]
66 | else:
67 | print('The arch is not implemented!')
68 | self.features = nn.Sequential(*features)
69 | self.dr6 = nn.Sequential(nn.Linear(in_features[0] * c * sum([p * p for p in range(1, self.P6+1)]), 1024),
70 | nn.BatchNorm1d(1024),
71 | nn.Linear(1024, 256),
72 | nn.BatchNorm1d(256),
73 | nn.Linear(256, 64),
74 | nn.BatchNorm1d(64), nn.ReLU())
75 | self.dr7 = nn.Sequential(nn.Linear(in_features[1] * c * sum([p * p for p in range(1, self.P7+1)]), 1024),
76 | nn.BatchNorm1d(1024),
77 | nn.Linear(1024, 256),
78 | nn.BatchNorm1d(256),
79 | nn.Linear(256, 64),
80 | nn.BatchNorm1d(64), nn.ReLU())
81 |
82 | if self.use_bn_end:
83 | self.regr6 = nn.Sequential(nn.Linear(64, 1), nn.BatchNorm1d(1))
84 | self.regr7 = nn.Sequential(nn.Linear(64, 1), nn.BatchNorm1d(1))
85 | self.regression = nn.Sequential(nn.Linear(64 * 2, 1), nn.BatchNorm1d(1))
86 | else:
87 | self.regr6 = nn.Linear(64, 1)
88 | self.regr7 = nn.Linear(64, 1)
89 | self.regression = nn.Linear(64 * 2, 1)
90 |
91 | def extract_features(self, x):
92 | f, pq = [], []
93 | for ii, model in enumerate(self.features):
94 | x = model(x)
95 | if ii == self.id1:
96 | x6 = SPSP(x, P=self.P6, method=self.pool)
97 | x6 = self.dr6(x6)
98 | f.append(x6)
99 | pq.append(self.regr6(x6))
100 | if ii == self.id2:
101 | x7 = SPSP(x, P=self.P7, method=self.pool)
102 | x7 = self.dr7(x7)
103 | f.append(x7)
104 | pq.append(self.regr7(x7))
105 |
106 | f = torch.cat(f, dim=1)
107 |
108 | return f, pq
109 |
110 | def forward(self, x):
111 | f, pq = self.extract_features(x)
112 | s = self.regression(f)
113 | pq.append(s)
114 |
115 | return pq, s
116 |
117 | if __name__ == "__main__":
118 | x = torch.rand((1,3,224,224))
119 | net = LinearityIQA()
120 | net.train(False)
121 | # print(net.dr6)
122 | # print(net.dr7)
123 | y, pred = net(x)
124 | print(pred)
125 |
--------------------------------------------------------------------------------
/dataloaders/dataloader_LQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import folders.folders_LQ as folders
4 |
5 | class DataLoader(object):
6 | """Dataset class for IQA databases"""
7 |
8 | def __init__(self, dataset, path, img_indx, patch_size, patch_num, batch_size=1, istrain=True, self_patch_num=1):
9 |
10 | self.batch_size = batch_size
11 | self.istrain = istrain
12 |
13 | if (dataset == 'live') | (dataset == 'csiq') | (dataset == 'tid2013') | (dataset == 'livec') | (dataset == 'kadid10k'):
14 | # Train transforms
15 | if istrain:
16 | transforms = torchvision.transforms.Compose([
17 | torchvision.transforms.RandomCrop(size=patch_size),
18 | torchvision.transforms.RandomHorizontalFlip(),
19 | torchvision.transforms.RandomVerticalFlip(),
20 | torchvision.transforms.RandomRotation(degrees=180),
21 | torchvision.transforms.ToTensor(),
22 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
23 | std=(0.229, 0.224, 0.225))
24 | ])
25 | # Test transforms
26 | else:
27 | transforms = torchvision.transforms.Compose([
28 | torchvision.transforms.RandomCrop(size=patch_size),
29 | torchvision.transforms.ToTensor(),
30 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
31 | std=(0.229, 0.224, 0.225))
32 | ])
33 | elif dataset == 'koniq-10k':
34 | if istrain:
35 | transforms = torchvision.transforms.Compose([
36 | torchvision.transforms.RandomCrop(size=patch_size),
37 | torchvision.transforms.RandomHorizontalFlip(),
38 | torchvision.transforms.RandomVerticalFlip(),
39 | torchvision.transforms.RandomRotation(degrees=180),
40 | torchvision.transforms.ToTensor(),
41 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
42 | std=(0.229, 0.224, 0.225))])
43 | else:
44 | transforms = torchvision.transforms.Compose([
45 | torchvision.transforms.RandomCrop(size=patch_size),
46 | torchvision.transforms.ToTensor(),
47 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
48 | std=(0.229, 0.224, 0.225))])
49 | elif dataset == 'bid':
50 | if istrain:
51 | transforms = torchvision.transforms.Compose([
52 | torchvision.transforms.Resize((512, 512)),
53 | torchvision.transforms.RandomCrop(size=patch_size),
54 | torchvision.transforms.RandomHorizontalFlip(),
55 | torchvision.transforms.RandomVerticalFlip(),
56 | torchvision.transforms.RandomRotation(degrees=180),
57 | torchvision.transforms.ToTensor(),
58 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
59 | std=(0.229, 0.224, 0.225))])
60 | else:
61 | transforms = torchvision.transforms.Compose([
62 | torchvision.transforms.Resize((512, 512)),
63 | torchvision.transforms.RandomCrop(size=patch_size),
64 | torchvision.transforms.ToTensor(),
65 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
66 | std=(0.229, 0.224, 0.225))])
67 | else:
68 | transforms = torchvision.transforms.Compose([
69 | torchvision.transforms.ToTensor(),
70 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
71 | std=(0.229, 0.224, 0.225))])
72 |
73 | if dataset == 'live':
74 | self.data = folders.LIVEFolder(
75 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
76 | elif dataset == 'csiq':
77 | self.data = folders.CSIQFolder(
78 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
79 | elif dataset == 'kadid10k':
80 | self.data = folders.Kadid10kFolder(
81 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
82 | elif dataset == 'tid2013':
83 | self.data = folders.TID2013Folder(
84 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
85 | elif dataset == 'koniq-10k':
86 | self.data = folders.Koniq_10kFolder(
87 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
88 | elif dataset == 'livec':
89 | self.data = folders.LIVEChallengeFolder(
90 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
91 |
92 | def get_dataloader(self):
93 | if self.istrain:
94 | dataloader = torch.utils.data.DataLoader(
95 | self.data, batch_size=self.batch_size, shuffle=True)
96 | else:
97 | dataloader = torch.utils.data.DataLoader(
98 | self.data, batch_size=1, shuffle=False)
99 | return dataloader
100 |
--------------------------------------------------------------------------------
/option_train_DistillationIQA_FR.py:
--------------------------------------------------------------------------------
1 | # import template
2 | import argparse
3 | import os
4 |
5 | """
6 | Configuration file
7 | """
8 | def check_args(args, rank=0):
9 | if rank == 0:
10 | with open(args.setting_file, 'w') as opt_file:
11 | opt_file.write('------------ Options -------------\n')
12 | print('------------ Options -------------')
13 | for k in args.__dict__:
14 | v = args.__dict__[k]
15 | opt_file.write('%s: %s\n' % (str(k), str(v)))
16 | print('%s: %s' % (str(k), str(v)))
17 | opt_file.write('-------------- End ----------------\n')
18 | print('------------ End -------------')
19 |
20 | return args
21 |
22 | def str2bool(v):
23 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
24 | return True
25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
26 | return False
27 | else:
28 | raise argparse.ArgumentTypeError('Unsupported value encountered.')
29 |
30 | def set_args():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--gpu_ids', type=str, default='0')
33 | parser.add_argument('--test_dataset', type=str, default='live', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k')
34 | parser.add_argument('--train_dataset', type=str, default='kadid10k', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k')
35 | parser.add_argument('--train_patch_num', type=int, default=1, help='Number of sample patches from training image')
36 | parser.add_argument('--test_patch_num', type=int, default=1, help='Number of sample patches from testing image')
37 | parser.add_argument('--lr', dest='lr', type=float, default=2e-5, help='Learning rate')
38 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay')
39 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
40 | parser.add_argument('--epochs', type=int, default=100, help='Epochs for training')
41 | parser.add_argument('--patch_size', type=int, default=224, help='Crop size for training & testing image patches')
42 | parser.add_argument('--self_patch_num', type=int, default=10, help='number of training & testing image self patches')
43 | parser.add_argument('--train_test_num', type=int, default=1, help='Train-test times')
44 | parser.add_argument('--update_opt_epoch', type=int, default=300)
45 |
46 | #Ref Img
47 | parser.add_argument('--use_refHQ', type=str2bool, default=True)
48 | parser.add_argument('--ref_folder_paths', type=str, default='./dataset/DIV2K_ref.tar.gz')
49 | parser.add_argument('--distillation_layer', type=int, default=18, help='last xth layers of HQ-MLP for distillation')
50 |
51 | parser.add_argument('--net_print', type=int, default=2000)
52 | parser.add_argument('--setting_file', type=str, default='setting.txt')
53 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_FRIQA_teacher/')
54 |
55 | parser.add_argument('--use_fitting_prcc_srcc', type=str2bool, default=True)
56 | parser.add_argument('--print_netC', type=str2bool, default=False)
57 |
58 | parser.add_argument('--teacherNet_model_path', type=str, default=None, help='./model_zoo/FR_teacher_cross_dataset.pth')
59 |
60 | args = parser.parse_args()
61 | #Dataset
62 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file)
63 | if not os.path.exists('./dataset/'):
64 | os.mkdir('./dataset/')
65 |
66 | folder_path = {
67 | 'live': './dataset/LIVE/',
68 | 'csiq': './dataset/CSIQ/',
69 | 'tid2013': './dataset/TID2013/',
70 | 'koniq-10k': './dataset/koniq-10k/',
71 | }
72 | ref_dataset_path = './dataset/DIV2K_ref/'
73 | args.ref_train_dataset_path = ref_dataset_path + 'train_HR/'
74 | args.ref_test_dataset_path = ref_dataset_path + 'val_HR/'
75 |
76 | #checkpoint files
77 | args.model_checkpoint_dir = args.checkpoint_dir + 'models/'
78 | args.result_checkpoint_dir = args.checkpoint_dir + 'results/'
79 | args.log_checkpoint_dir = args.checkpoint_dir + 'log/'
80 |
81 | if os.path.exists(args.checkpoint_dir) and os.path.isfile(args.checkpoint_dir):
82 | raise IOError('Required dst path {} as a directory for checkpoint saving, got a file'.format(
83 | args.checkpoint_dir))
84 | elif not os.path.exists(args.checkpoint_dir):
85 | os.makedirs(args.checkpoint_dir)
86 | print('%s created successfully!'%args.checkpoint_dir)
87 |
88 | if os.path.exists(args.model_checkpoint_dir) and os.path.isfile(args.model_checkpoint_dir):
89 | raise IOError('Required dst path {} as a directory for checkpoint model saving, got a file'.format(
90 | args.model_checkpoint_dir))
91 | elif not os.path.exists(args.model_checkpoint_dir):
92 | os.makedirs(args.model_checkpoint_dir)
93 | print('%s created successfully!'%args.model_checkpoint_dir)
94 |
95 | if os.path.exists(args.result_checkpoint_dir) and os.path.isfile(args.result_checkpoint_dir):
96 | raise IOError('Required dst path {} as a directory for checkpoint results saving, got a file'.format(
97 | args.result_checkpoint_dir))
98 | elif not os.path.exists(args.result_checkpoint_dir):
99 | os.makedirs(args.result_checkpoint_dir)
100 | print('%s created successfully!'%args.result_checkpoint_dir)
101 |
102 | if os.path.exists(args.log_checkpoint_dir) and os.path.isfile(args.log_checkpoint_dir):
103 | raise IOError('Required dst path {} as a directory for checkpoint log saving, got a file'.format(
104 | args.log_checkpoint_dir))
105 | elif not os.path.exists(args.log_checkpoint_dir):
106 | os.makedirs(args.log_checkpoint_dir)
107 | print('%s created successfully!'%args.log_checkpoint_dir)
108 |
109 | return args
110 |
111 | if __name__ == "__main__":
112 | args = set_args()
113 |
114 |
--------------------------------------------------------------------------------
/option_train_DistillationIQA.py:
--------------------------------------------------------------------------------
1 | # import template
2 | import argparse
3 | import os
4 | """
5 | Configuration file
6 | """
7 | def check_args(args, rank=0):
8 | if rank == 0:
9 | with open(args.setting_file, 'w') as opt_file:
10 | opt_file.write('------------ Options -------------\n')
11 | print('------------ Options -------------')
12 | for k in args.__dict__:
13 | v = args.__dict__[k]
14 | opt_file.write('%s: %s\n' % (str(k), str(v)))
15 | print('%s: %s' % (str(k), str(v)))
16 | opt_file.write('-------------- End ----------------\n')
17 | print('------------ End -------------')
18 |
19 | return args
20 |
21 | def str2bool(v):
22 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
23 | return True
24 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
25 | return False
26 | else:
27 | raise argparse.ArgumentTypeError('Unsupported value encountered.')
28 |
29 | def set_args():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--gpu_ids', type=str, default='0')
32 | parser.add_argument('--test_dataset', type=str, default='live', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k')
33 | parser.add_argument('--train_dataset', type=str, default='kadid10k', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k')
34 | parser.add_argument('--train_patch_num', type=int, default=1, help='Number of sample patches from training image')
35 | parser.add_argument('--test_patch_num', type=int, default=1, help='Number of sample patches from testing image')
36 | parser.add_argument('--lr', dest='lr', type=float, default=2e-5, help='Learning rate')
37 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay')
38 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
39 | parser.add_argument('--epochs', type=int, default=100, help='Epochs for training')
40 | parser.add_argument('--patch_size', type=int, default=224, help='Crop size for training & testing image patches')
41 | parser.add_argument('--self_patch_num', type=int, default=10, help='number of training & testing image self patches')
42 | parser.add_argument('--train_test_num', type=int, default=1, help='Train-test times')
43 | parser.add_argument('--update_opt_epoch', type=int, default=30)
44 |
45 | #Ref Img
46 | parser.add_argument('--use_refHQ', type=str2bool, default=True)
47 | parser.add_argument('--distillation_layer', type=int, default=18, help='last xth layers of HQ-MLP for distillation')
48 |
49 | parser.add_argument('--net_print', type=int, default=2000)
50 | parser.add_argument('--setting_file', type=str, default='setting.txt')
51 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_DistillationIQA/')
52 |
53 | parser.add_argument('--use_fitting_prcc_srcc', type=str2bool, default=True)
54 | parser.add_argument('--print_netC', type=str2bool, default=False)
55 |
56 | parser.add_argument('--teacherNet_model_path', type=str, default='./model_zoo/FR_teacher_cross_dataset.pth')
57 | parser.add_argument('--studentNet_model_path', type=str, default=None, help='./model_zoo/NAR_student_cross_dataset.pth')
58 |
59 | #distillation
60 | parser.add_argument('--distillation_loss', type=str, default='l1', help='mse|l1|kldiv')
61 |
62 | args = parser.parse_args()
63 | #Dataset
64 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file)
65 | if not os.path.exists('./dataset/'):
66 | os.mkdir('./dataset/')
67 |
68 | folder_path = {
69 | 'live': './dataset/LIVE/',
70 | 'csiq': './dataset/CSIQ/',
71 | 'tid2013': './dataset/TID2013/',
72 | 'koniq-10k': './dataset/koniq-10k/',
73 | }
74 |
75 | ref_dataset_path = './dataset/DIV2K_ref/'
76 | args.ref_train_dataset_path = ref_dataset_path + 'train_HR/'
77 | args.ref_test_dataset_path = ref_dataset_path + 'val_HR/'
78 |
79 | #checkpoint files
80 | args.model_checkpoint_dir = args.checkpoint_dir + 'models/'
81 | args.result_checkpoint_dir = args.checkpoint_dir + 'results/'
82 | args.log_checkpoint_dir = args.checkpoint_dir + 'log/'
83 |
84 | if os.path.exists(args.checkpoint_dir) and os.path.isfile(args.checkpoint_dir):
85 | raise IOError('Required dst path {} as a directory for checkpoint saving, got a file'.format(
86 | args.checkpoint_dir))
87 | elif not os.path.exists(args.checkpoint_dir):
88 | os.makedirs(args.checkpoint_dir)
89 | print('%s created successfully!'%args.checkpoint_dir)
90 |
91 | if os.path.exists(args.model_checkpoint_dir) and os.path.isfile(args.model_checkpoint_dir):
92 | raise IOError('Required dst path {} as a directory for checkpoint model saving, got a file'.format(
93 | args.model_checkpoint_dir))
94 | elif not os.path.exists(args.model_checkpoint_dir):
95 | os.makedirs(args.model_checkpoint_dir)
96 | print('%s created successfully!'%args.model_checkpoint_dir)
97 |
98 | if os.path.exists(args.result_checkpoint_dir) and os.path.isfile(args.result_checkpoint_dir):
99 | raise IOError('Required dst path {} as a directory for checkpoint results saving, got a file'.format(
100 | args.result_checkpoint_dir))
101 | elif not os.path.exists(args.result_checkpoint_dir):
102 | os.makedirs(args.result_checkpoint_dir)
103 | print('%s created successfully!'%args.result_checkpoint_dir)
104 |
105 | if os.path.exists(args.log_checkpoint_dir) and os.path.isfile(args.log_checkpoint_dir):
106 | raise IOError('Required dst path {} as a directory for checkpoint log saving, got a file'.format(
107 | args.log_checkpoint_dir))
108 | elif not os.path.exists(args.log_checkpoint_dir):
109 | os.makedirs(args.log_checkpoint_dir)
110 | print('%s created successfully!'%args.log_checkpoint_dir)
111 |
112 | return args
113 |
114 | if __name__ == "__main__":
115 | args = set_args()
116 |
117 |
--------------------------------------------------------------------------------
/models/WaDIQaM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | class WaDIQaM_FR(nn.Module):
6 | """
7 | (Wa)DIQaM-FR Model
8 | """
9 | def __init__(self, weighted_average=True):
10 | """
11 | :param weighted_average: weighted average or not?
12 | """
13 | super(WaDIQaM_FR, self).__init__()
14 | self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
15 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
16 | self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
17 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
18 | self.conv5 = nn.Conv2d(64, 128, 3, padding=1)
19 | self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
20 | self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
21 | self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
22 | self.conv9 = nn.Conv2d(256, 512, 3, padding=1)
23 | self.conv10 = nn.Conv2d(512, 512, 3, padding=1)
24 | self.fc1_q = nn.Linear(512*3, 512)
25 | self.fc2_q = nn.Linear(512, 1)
26 | self.fc1_w = nn.Linear(512*3, 512)
27 | self.fc2_w = nn.Linear(512, 1)
28 | self.dropout = nn.Dropout()
29 | self.weighted_average = weighted_average
30 |
31 | def extract_features(self, x):
32 | """
33 | feature extraction
34 | :param x: the input image
35 | :return: the output feature
36 | """
37 | h = F.relu(self.conv1(x))
38 | h = F.relu(self.conv2(h))
39 | h = F.max_pool2d(h, 2)
40 |
41 | h = F.relu(self.conv3(h))
42 | h = F.relu(self.conv4(h))
43 | h = F.max_pool2d(h, 2)
44 |
45 | h = F.relu(self.conv5(h))
46 | h = F.relu(self.conv6(h))
47 | h = F.max_pool2d(h, 2)
48 |
49 | h = F.relu(self.conv7(h))
50 | h = F.relu(self.conv8(h))
51 | h = F.max_pool2d(h, 2)
52 |
53 | h = F.relu(self.conv9(h))
54 | h = F.relu(self.conv10(h))
55 | h = F.max_pool2d(h, 2)
56 |
57 | h = h.view(-1, 512)
58 |
59 | return h
60 |
61 | def forward(self, x, x_ref):
62 | """
63 | :param x: distorted patches of images
64 | :param x_ref: reference patches of images
65 | :return: quality of images/patches
66 | """
67 | batch_size = x.size(0)
68 | n_patches = x.size(1)
69 | if self.weighted_average:
70 | q = torch.ones((batch_size, 1), device=x.device)
71 | else:
72 | q = torch.ones((batch_size * n_patches, 1), device=x.device)
73 |
74 | for i in range(batch_size):
75 |
76 | h = self.extract_features(x[i])
77 | h_ref = self.extract_features(x_ref[i])
78 | h = torch.cat((h - h_ref, h, h_ref), 1)
79 |
80 | h_ = h # save intermediate features
81 |
82 | h = F.relu(self.fc1_q(h_))
83 | h = self.dropout(h)
84 | h = self.fc2_q(h)
85 |
86 | if self.weighted_average:
87 | w = F.relu(self.fc1_w(h_))
88 | w = self.dropout(w)
89 | w = F.relu(self.fc2_w(w)) + 0.000001 # small constant
90 | q[i] = torch.sum(h * w) / torch.sum(w)
91 | else:
92 | q[i*n_patches:(i+1)*n_patches] = h
93 |
94 | return q
95 |
96 | class WaDIQaM_NR(nn.Module):
97 | """
98 |     (Wa)DIQaM-NR Model
99 | """
100 | def __init__(self, weighted_average=True):
101 | """
102 | :param weighted_average: weighted average or not?
103 | """
104 | super(WaDIQaM_NR, self).__init__()
105 | self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
106 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
107 | self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
108 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
109 | self.conv5 = nn.Conv2d(64, 128, 3, padding=1)
110 | self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
111 | self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
112 | self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
113 | self.conv9 = nn.Conv2d(256, 512, 3, padding=1)
114 | self.conv10 = nn.Conv2d(512, 512, 3, padding=1)
115 | self.fc1q_nr = nn.Linear(512, 512)
116 | self.fc2q_nr = nn.Linear(512, 1)
117 | self.fc1w_nr = nn.Linear(512, 512)
118 | self.fc2w_nr = nn.Linear(512, 1)
119 | self.dropout = nn.Dropout()
120 | self.weighted_average = weighted_average
121 |
122 | def extract_features(self, x):
123 | """
124 | feature extraction
125 | :param x: the input image
126 | :return: the output feature
127 | """
128 | h = F.relu(self.conv1(x))
129 | h = F.relu(self.conv2(h))
130 | h = F.max_pool2d(h, 2)
131 |
132 | h = F.relu(self.conv3(h))
133 | h = F.relu(self.conv4(h))
134 | h = F.max_pool2d(h, 2)
135 |
136 | h = F.relu(self.conv5(h))
137 | h = F.relu(self.conv6(h))
138 | h = F.max_pool2d(h, 2)
139 |
140 | h = F.relu(self.conv7(h))
141 | h = F.relu(self.conv8(h))
142 | h = F.max_pool2d(h, 2)
143 |
144 | h = F.relu(self.conv9(h))
145 | h = F.relu(self.conv10(h))
146 | h = F.max_pool2d(h, 2)
147 |
148 | h = h.view(-1,512)
149 |
150 | return h
151 |
152 | def forward(self, x):
153 | """
154 | :param x: distorted patches of images
155 | :return: quality of images/patches
156 | """
157 | batch_size = x.size(0)
158 | n_patches = x.size(1)
159 | if self.weighted_average:
160 | q = torch.ones((batch_size, 1), device=x.device)
161 | else:
162 | q = torch.ones((batch_size * n_patches, 1), device=x.device)
163 |
164 | for i in range(batch_size):
165 |
166 | h = self.extract_features(x[i])
167 |
168 | h_ = h # save intermediate features
169 |
170 | h = F.relu(self.fc1q_nr(h_))
171 | h = self.dropout(h)
172 | h = self.fc2q_nr(h)
173 |
174 | if self.weighted_average:
175 | w = F.relu(self.fc1w_nr(h_))
176 | w = self.dropout(w)
177 | w = F.relu(self.fc2w_nr(w)) + 0.000001 # small constant
178 | q[i] = torch.sum(h * w) / torch.sum(w)
179 | else:
180 | q[i * n_patches:(i + 1) * n_patches] = h
181 |
182 | return q
183 |
--------------------------------------------------------------------------------
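A shape sketch (not part of the repository) showing the 5-D patch layout both models expect: a batch dimension followed by a per-image patch dimension of 32x32 RGB patches, as in the original (Wa)DIQaM paper.

import torch
from models.WaDIQaM import WaDIQaM_FR, WaDIQaM_NR   # assumes the repository root is on PYTHONPATH

batch_size, n_patches = 2, 8
x     = torch.rand(batch_size, n_patches, 3, 32, 32)   # distorted patches
x_ref = torch.rand(batch_size, n_patches, 3, 32, 32)   # aligned reference patches

fr = WaDIQaM_FR(weighted_average=True)
nr = WaDIQaM_NR(weighted_average=True)
fr.train(False)
nr.train(False)
with torch.no_grad():
    q_fr = fr(x, x_ref)   # (batch_size, 1): weighted-average score per image
    q_nr = nr(x)          # (batch_size, 1)
print(q_fr.shape, q_nr.shape)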
/dataloaders/dataloader_LQ_HQ_diff_content_HQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import folders.folders_LQ_HQ_diff_content_HQ as folders
4 |
5 | class DataLoader(object):
6 | """Dataset class for IQA databases"""
7 |
8 | def __init__(self, dataset, path, ref_path, img_indx, patch_size, patch_num, batch_size=1, istrain=True, self_patch_num=10, use_HQref = True):
9 |
10 | self.batch_size = batch_size
11 | self.istrain = istrain
12 |
13 | if (dataset == 'live') | (dataset == 'csiq') | (dataset == 'tid2013') | (dataset == 'livec') | (dataset == 'kadid10k'):
14 | # Train transforms
15 | if istrain:
16 | HQ_diff_content_transform = torchvision.transforms.Compose([
17 | torchvision.transforms.RandomCrop(size=patch_size),
18 | torchvision.transforms.RandomHorizontalFlip(),
19 | torchvision.transforms.RandomVerticalFlip(),
20 | torchvision.transforms.RandomRotation(degrees=180),
21 | torchvision.transforms.ToTensor(),
22 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
23 | std=(0.229, 0.224, 0.225))
24 | ])
25 | # Test transforms
26 | else:
27 | HQ_diff_content_transform = torchvision.transforms.Compose([
28 | torchvision.transforms.RandomCrop(size=patch_size),
29 | torchvision.transforms.ToTensor(),
30 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
31 | std=(0.229, 0.224, 0.225))
32 | ])
33 | elif dataset == 'koniq-10k':
34 | if istrain:
35 | HQ_diff_content_transform = torchvision.transforms.Compose([
36 | torchvision.transforms.RandomCrop(size=patch_size),
37 | torchvision.transforms.RandomHorizontalFlip(),
38 | torchvision.transforms.RandomVerticalFlip(),
39 | torchvision.transforms.RandomRotation(degrees=180),
40 | torchvision.transforms.ToTensor(),
41 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
42 | std=(0.229, 0.224, 0.225))])
43 | else:
44 | HQ_diff_content_transform = torchvision.transforms.Compose([
45 | torchvision.transforms.RandomCrop(size=patch_size),
46 | torchvision.transforms.ToTensor(),
47 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
48 | std=(0.229, 0.224, 0.225))])
49 | elif dataset == 'bid':
50 | if istrain:
51 | HQ_diff_content_transform = torchvision.transforms.Compose([
52 | torchvision.transforms.Resize((512, 512)),
53 | torchvision.transforms.RandomCrop(size=patch_size),
54 | torchvision.transforms.RandomHorizontalFlip(),
55 | torchvision.transforms.RandomVerticalFlip(),
56 | torchvision.transforms.RandomRotation(degrees=180),
57 | torchvision.transforms.ToTensor(),
58 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
59 | std=(0.229, 0.224, 0.225))])
60 | else:
61 | HQ_diff_content_transform = torchvision.transforms.Compose([
62 | torchvision.transforms.Resize((512, 512)),
63 | torchvision.transforms.RandomCrop(size=patch_size),
64 | torchvision.transforms.ToTensor(),
65 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
66 | std=(0.229, 0.224, 0.225))])
67 | else:
68 | HQ_diff_content_transform = torchvision.transforms.Compose([
69 | torchvision.transforms.ToTensor(),
70 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
71 | std=(0.229, 0.224, 0.225))])
72 | transforms = torchvision.transforms.Compose([
73 | torchvision.transforms.ToTensor(),
74 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
75 | std=(0.229, 0.224, 0.225))])
76 |
77 | if dataset == 'live':
78 | self.data = folders.LIVEFolder(
79 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
80 | elif dataset == 'csiq':
81 | self.data = folders.CSIQFolder(
82 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
83 | elif dataset == 'kadid10k':
84 | self.data = folders.Kadid10kFolder(
85 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
86 | elif dataset == 'tid2013':
87 | self.data = folders.TID2013Folder(
88 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
89 | elif dataset == 'koniq-10k':
90 | self.data = folders.Koniq_10kFolder(
91 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
92 | elif dataset == 'livec':
93 | self.data = folders.LIVEChallengeFolder(
94 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num)
95 |
96 | def get_dataloader(self):
97 | if self.istrain:
98 | dataloader = torch.utils.data.DataLoader(
99 | self.data, batch_size=self.batch_size, shuffle=True)
100 | else:
101 | dataloader = torch.utils.data.DataLoader(
102 | self.data, batch_size=self.batch_size, shuffle=False)
103 | return dataloader
104 |
--------------------------------------------------------------------------------
/test_DistillationIQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from dataloaders.dataloader_LQ_HQ_diff_content_HQ import DataLoader
4 | from option_train_DistillationIQA import set_args, check_args
5 | from scipy import stats
6 | import numpy as np
7 | from tools import convert_obj_score
8 | from models.DistillationIQA import DistillationIQANet
9 |
10 | img_num = {
11 | 'kadid10k': list(range(0,10125)),
12 | 'live': list(range(0, 29)),#ref HR image
13 | 'csiq': list(range(0, 30)),#ref HR image
14 | 'tid2013': list(range(0, 25)),
15 | 'livec': list(range(0, 1162)),# no-ref image
16 | 'koniq-10k': list(range(0, 10073)),# no-ref image
17 | 'bid': list(range(0, 586)),# no-ref image
18 | }
19 | folder_path = {
20 | 'pipal':'./dataset/PIPAL',
21 | 'live': './dataset/LIVE/',
22 | 'csiq': './dataset/CSIQ/',
23 | 'tid2013': './dataset/TID2013/',
24 | 'livec': './dataset/LIVEC/',
25 | 'koniq-10k': './dataset/koniq-10k/',
26 | 'bid': './dataset/BID/',
27 | 'kadid10k':'./dataset/kadid10k/'
28 | }
29 |
30 |
31 | class DistillationIQASolver(object):
32 | def __init__(self, config):
33 | self.config = config
34 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu')
35 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt')
36 | self.config.teacherNet_model_path = './model_zoo/FR_teacher_cross_dataset.pth'
37 | self.config.studentNet_model_path = './model_zoo/NAR_student_cross_dataset.pth'
38 | with open(self.txt_log_path,"w+") as f:
39 | f.close()
40 |
41 | #model
42 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
43 | if config.teacherNet_model_path:
44 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path))
45 | self.teacherNet = self.teacherNet.to(self.device)
46 | self.teacherNet.train(False)
47 | self.studentNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
48 | if config.studentNet_model_path:
49 | self.studentNet._load_state_dict(torch.load(config.studentNet_model_path))
50 | self.studentNet = self.studentNet.to(self.device)
51 | self.studentNet.train(True)
52 |
53 | #data
54 | test_loader_LIVE = DataLoader('live', folder_path['live'], config.ref_test_dataset_path, img_num['live'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
55 | test_loader_CSIQ = DataLoader('csiq', folder_path['csiq'], config.ref_test_dataset_path, img_num['csiq'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
56 | test_loader_TID = DataLoader('tid2013', folder_path['tid2013'], config.ref_test_dataset_path, img_num['tid2013'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
57 | test_loader_Koniq = DataLoader('koniq-10k', folder_path['koniq-10k'], config.ref_test_dataset_path, img_num['koniq-10k'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
58 |
59 | self.test_data_LIVE = test_loader_LIVE.get_dataloader()
60 | self.test_data_CSIQ = test_loader_CSIQ.get_dataloader()
61 | self.test_data_TID = test_loader_TID.get_dataloader()
62 | self.test_data_Koniq = test_loader_Koniq.get_dataloader()
63 |
64 |
65 | def test(self, test_data):
66 | self.studentNet.train(False)
67 | test_pred_scores, test_gt_scores = [], []
68 | for LQ_patches, _, ref_patches, label in test_data:
69 | LQ_patches, ref_patches, label = LQ_patches.to(self.device), ref_patches.to(self.device), label.to(self.device)
70 | with torch.no_grad():
71 | _, _, pred = self.studentNet(LQ_patches, ref_patches)
72 | test_pred_scores.append(float(pred.item()))
73 | test_gt_scores = test_gt_scores + label.cpu().tolist()
74 | if self.config.use_fitting_prcc_srcc:
75 | fitting_pred_scores = convert_obj_score(test_pred_scores, test_gt_scores)
76 | test_pred_scores = np.mean(np.reshape(np.array(test_pred_scores), (-1, self.config.test_patch_num)), axis=1)
77 | test_gt_scores = np.mean(np.reshape(np.array(test_gt_scores), (-1, self.config.test_patch_num)), axis=1)
78 | test_srcc, _ = stats.spearmanr(test_pred_scores, test_gt_scores)
79 | if self.config.use_fitting_prcc_srcc:
80 | test_plcc, _ = stats.pearsonr(fitting_pred_scores, test_gt_scores)
81 | else:
82 | test_plcc, _ = stats.pearsonr(test_pred_scores, test_gt_scores)
83 |         test_krcc, _ = stats.kendalltau(test_pred_scores, test_gt_scores)
84 | test_srcc, test_plcc, test_krcc = abs(test_srcc), abs(test_plcc), abs(test_krcc)
85 | self.studentNet.train(True)
86 | return test_srcc, test_plcc, test_krcc
87 |
88 | if __name__ == "__main__":
89 | config = set_args()
90 | config = check_args(config)
91 | solver = DistillationIQASolver(config=config)
92 | fold_10_test_LIVE_srcc, fold_10_test_LIVE_plcc, fold_10_test_LIVE_krcc = [], [], []
93 | fold_10_test_CSIQ_srcc, fold_10_test_CSIQ_plcc, fold_10_test_CSIQ_krcc = [], [], []
94 | fold_10_test_TID_srcc, fold_10_test_TID_plcc, fold_10_test_TID_krcc = [], [], []
95 | fold_10_test_Koniq_srcc, fold_10_test_Koniq_plcc, fold_10_test_Koniq_krcc = [], [], []
96 |
97 | for i in range(10):
98 |
99 | test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc = solver.test(solver.test_data_LIVE)
100 | print('round{} Dataset:LIVE Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc))
101 | fold_10_test_LIVE_srcc.append(test_LIVE_srcc)
102 | fold_10_test_LIVE_plcc.append(test_LIVE_plcc)
103 | fold_10_test_LIVE_krcc.append(test_LIVE_krcc)
104 |
105 | test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc = solver.test(solver.test_data_CSIQ)
106 | print('round{} Dataset:CSIQ Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc))
107 | fold_10_test_CSIQ_srcc.append(test_CSIQ_srcc)
108 | fold_10_test_CSIQ_plcc.append(test_CSIQ_plcc)
109 | fold_10_test_CSIQ_krcc.append(test_CSIQ_krcc)
110 |
111 | test_TID_srcc, test_TID_plcc, test_TID_krcc = solver.test(solver.test_data_TID)
112 | print('round{} Dataset:TID Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_TID_srcc, test_TID_plcc, test_TID_krcc))
113 | fold_10_test_TID_srcc.append(test_TID_srcc)
114 | fold_10_test_TID_plcc.append(test_TID_plcc)
115 | fold_10_test_TID_krcc.append(test_TID_krcc)
116 |
117 | test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc = solver.test(solver.test_data_Koniq)
118 | print('round{} Dataset:Koniq Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc))
119 | fold_10_test_Koniq_srcc.append(test_Koniq_srcc)
120 | fold_10_test_Koniq_plcc.append(test_Koniq_plcc)
121 | fold_10_test_Koniq_krcc.append(test_Koniq_krcc)
122 |
123 | print('Dataset:LIVE Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_LIVE_srcc), np.mean(fold_10_test_LIVE_plcc), np.mean(fold_10_test_LIVE_krcc)))
124 | print('Dataset:CSIQ Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_CSIQ_srcc), np.mean(fold_10_test_CSIQ_plcc), np.mean(fold_10_test_CSIQ_krcc)))
125 | print('Dataset:TID Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_TID_srcc), np.mean(fold_10_test_TID_plcc), np.mean(fold_10_test_TID_krcc)))
126 | print('Dataset:Koniq Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_Koniq_srcc), np.mean(fold_10_test_Koniq_plcc), np.mean(fold_10_test_Koniq_krcc)))
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/train_DistillationIQA_FR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import random
4 | from dataloaders.dataloader_LQ_HQ import DataLoader
5 | from option_train_DistillationIQA_FR import set_args, check_args
6 | from scipy import stats
7 | import numpy as np
8 | from tools.nonlinear_convert import convert_obj_score
9 | from models.DistillationIQA import DistillationIQANet
10 |
11 | img_num = {
12 | 'kadid10k': list(range(0,10125)),
13 | 'live': list(range(0, 29)),#ref HR image
14 | 'csiq': list(range(0, 30)),#ref HR image
15 | 'tid2013': list(range(0, 25)),
16 | 'livec': list(range(0, 1162)),# no-ref image
17 | 'koniq-10k': list(range(0, 10073)),# no-ref image
18 | 'bid': list(range(0, 586)),# no-ref image
19 | }
20 | folder_path = {
21 | 'pipal':'./dataset/PIPAL',
22 | 'live': './dataset/LIVE/',
23 | 'csiq': './dataset/CSIQ/',
24 | 'tid2013': './dataset/TID2013/',
25 | 'livec': './dataset/LIVEC/',
26 | 'koniq-10k': './dataset/koniq-10k/',
27 | 'bid': './dataset/BID/',
28 | 'kadid10k':'./dataset/kadid10k/'
29 | }
30 |
31 |
32 | class DistillationFRIQASolver(object):
33 | def __init__(self, config):
34 | self.config = config
35 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu')
36 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt')
37 | with open(self.txt_log_path, "w+"):
38 | pass  # just create / truncate the log file
39 |
40 | #model
41 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
42 | if config.teacherNet_model_path:
43 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path))  # load the same path checked above
44 | self.teacherNet = self.teacherNet.to(self.device)
45 | self.teacherNet.train(True)
46 | #lr,opt,loss,epoch
47 | self.lr = config.lr
48 | self.lr_ratio = 10
49 | self.feature_loss_ratio = 0.1
50 | resnet_params = list(map(id, self.teacherNet.feature_extractor.parameters()))
51 | res_params = filter(lambda p: id(p) not in resnet_params, self.teacherNet.parameters())
52 | paras = [{'params': res_params, 'lr': self.lr * self.lr_ratio },
53 | {'params': self.teacherNet.feature_extractor.parameters(), 'lr': self.lr}
54 | ]
55 | self.optimizer = torch.optim.Adam(paras, weight_decay=config.weight_decay)
56 | self.mse_loss = torch.nn.MSELoss()
57 | self.l1_loss = torch.nn.L1Loss()
58 | self.epochs = config.epochs
59 |
60 | #data
61 | config.train_index = img_num[config.train_dataset]
62 | random.shuffle(config.train_index)
63 | train_loader = DataLoader(config.train_dataset, folder_path[config.train_dataset], config.train_index, config.patch_size, config.train_patch_num, batch_size=config.batch_size, istrain=True, self_patch_num=config.self_patch_num)
64 | test_loader_LIVE = DataLoader('live', folder_path['live'], img_num['live'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
65 | test_loader_CSIQ = DataLoader('csiq', folder_path['csiq'], img_num['csiq'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
66 | test_loader_TID = DataLoader('tid2013', folder_path['tid2013'], img_num['tid2013'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
67 |
68 | self.train_data = train_loader.get_dataloader()
69 | self.test_data_LIVE = test_loader_LIVE.get_dataloader()
70 | self.test_data_CSIQ = test_loader_CSIQ.get_dataloader()
71 | self.test_data_TID = test_loader_TID.get_dataloader()
72 |
73 | def train(self):
74 | best_srcc_LIVE, best_srcc_CSIQ, best_srcc_TID = 0.0, 0.0, 0.0
75 | best_plcc_LIVE, best_plcc_CSIQ, best_plcc_TID = 0.0, 0.0, 0.0
76 | best_krcc_LIVE, best_krcc_CSIQ, best_krcc_TID = 0.0, 0.0, 0.0
77 |
78 | print('Epoch\tTrain_Loss\tTrain_SRCC\tTest_SRCC\tTest_PLCC\tTest_KRCC')
79 | # NEW
80 | scaler = torch.cuda.amp.GradScaler()
81 |
82 | for t in range(self.epochs):
83 | epoch_loss = []
84 | pred_scores = []
85 | gt_scores = []
86 |
87 | for LQ_patches, refHQ_patches, label in self.train_data:
88 | LQ_patches, refHQ_patches, label = LQ_patches.to(self.device), refHQ_patches.to(self.device), label.to(self.device)
89 | self.optimizer.zero_grad()
90 |
91 | with torch.cuda.amp.autocast():
92 | _, _, pred = self.teacherNet(LQ_patches, refHQ_patches)
93 |
94 | pred_scores = pred_scores + pred.cpu().tolist()
95 | gt_scores = gt_scores + label.cpu().tolist()
96 | loss = self.l1_loss(pred.squeeze(), label.float().detach())
97 |
98 | epoch_loss.append(loss.item())
99 | scaler.scale(loss).backward()
100 | scaler.step(self.optimizer)
101 | scaler.update()
102 |
103 | train_srcc, _ = stats.spearmanr(pred_scores, gt_scores)
104 | test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc = self.test(self.test_data_LIVE)
105 | test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc = self.test(self.test_data_CSIQ)
106 | test_TID_srcc, test_TID_plcc, test_TID_krcc = self.test(self.test_data_TID)
107 |
108 | if test_LIVE_srcc + test_LIVE_plcc + test_LIVE_krcc > best_srcc_LIVE + best_plcc_LIVE + best_krcc_LIVE:
109 | best_srcc_LIVE, best_plcc_LIVE, best_krcc_LIVE = test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc
110 | print('%d:live\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
111 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc))
112 |
113 | if test_CSIQ_srcc + test_CSIQ_plcc + test_CSIQ_krcc > best_srcc_CSIQ + best_plcc_CSIQ + best_krcc_CSIQ:
114 | best_srcc_CSIQ, best_plcc_CSIQ, best_krcc_CSIQ = test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc
115 | print('%d:csiq\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
116 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc))
117 |
118 | if test_TID_srcc + test_TID_plcc + test_TID_krcc > best_srcc_TID + best_plcc_TID + best_krcc_TID:
119 | best_srcc_TID, best_plcc_TID, best_krcc_TID = test_TID_srcc, test_TID_plcc, test_TID_krcc
120 | print('%d:tid\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
121 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_TID_srcc, test_TID_plcc, test_TID_krcc))
122 |
123 | torch.save(self.teacherNet.state_dict(), os.path.join(self.config.model_checkpoint_dir, 'FRIQA_{}_saved_model.pth'.format(t)))
124 |
125 | self.lr = self.lr / pow(10, (t // self.config.update_opt_epoch))
126 | if t > 20:
127 | self.lr_ratio = 1
128 | resnet_params = list(map(id, self.teacherNet.feature_extractor.parameters()))
129 | rest_params = filter(lambda p: id(p) not in resnet_params, self.teacherNet.parameters())
130 | paras = [{'params': rest_params, 'lr': self.lr * self.lr_ratio },
131 | {'params': self.teacherNet.feature_extractor.parameters(), 'lr': self.lr}
132 | ]
133 | self.optimizer = torch.optim.Adam(paras, weight_decay=self.config.weight_decay)
134 |
135 | print('Best live test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_LIVE, best_plcc_LIVE, best_krcc_LIVE))
136 | print('Best csiq test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_CSIQ, best_plcc_CSIQ, best_krcc_CSIQ))
137 | print('Best tid2013 test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_TID, best_plcc_TID, best_krcc_TID))
138 |
139 |
140 | def test(self, test_data):
141 | self.teacherNet.train(False)
142 | test_pred_scores, test_gt_scores = [], []
143 | for LQ_patches, refHQ_patches, label in test_data:
144 | LQ_patches, refHQ_patches, label = LQ_patches.to(self.device), refHQ_patches.to(self.device), label.to(self.device)
145 | with torch.no_grad():
146 | _, _, pred = self.teacherNet(LQ_patches, refHQ_patches)
147 | test_pred_scores.append(float(pred.item()))
148 | test_gt_scores = test_gt_scores + label.cpu().tolist()
149 | if self.config.use_fitting_prcc_srcc:
150 | fitting_pred_scores = convert_obj_score(test_pred_scores, test_gt_scores)
151 | test_pred_scores = np.mean(np.reshape(np.array(test_pred_scores), (-1, self.config.test_patch_num)), axis=1)
152 | test_gt_scores = np.mean(np.reshape(np.array(test_gt_scores), (-1, self.config.test_patch_num)), axis=1)
153 | test_srcc, _ = stats.spearmanr(test_pred_scores, test_gt_scores)
154 | if self.config.use_fitting_prcc_srcc:
155 | test_plcc, _ = stats.pearsonr(fitting_pred_scores, test_gt_scores)
156 | else:
157 | test_plcc, _ = stats.pearsonr(test_pred_scores, test_gt_scores)
158 | test_krcc, _ = stats.kendalltau(test_pred_scores, test_gt_scores)
159 | test_srcc, test_plcc, test_krcc = abs(test_srcc), abs(test_plcc), abs(test_krcc)
160 | self.teacherNet.train(True)
161 | return test_srcc, test_plcc, test_krcc
162 |
163 | if __name__ == "__main__":
164 | config = set_args()
165 | config = check_args(config)
166 | solver = DistillationFRIQASolver(config=config)
167 | solver.train()
168 |
169 |
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
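Two details of the teacher training above are easy to miss: the optimizer uses two parameter groups (the ResNet backbone at lr, everything else at lr * lr_ratio), and the update is wrapped in torch.cuda.amp autocast/GradScaler. Below is a standalone sketch of that pattern with a stand-in model; the model, lr, and weight_decay values are placeholders, not the repo's settings.

```python
# Hedged sketch: per-group learning rates + mixed precision, as in
# DistillationFRIQASolver.train(). "Toy" stands in for DistillationIQANet.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Linear(8, 8)   # stands in for the ResNet backbone
        self.head = nn.Linear(8, 1)                # stands in for the remaining layers

    def forward(self, x):
        return self.head(self.feature_extractor(x))

def make_optimizer(model, lr=2e-5, lr_ratio=10, weight_decay=5e-4):
    backbone_ids = set(map(id, model.feature_extractor.parameters()))
    rest = [p for p in model.parameters() if id(p) not in backbone_ids]
    return torch.optim.Adam(
        [{'params': rest, 'lr': lr * lr_ratio},
         {'params': model.feature_extractor.parameters(), 'lr': lr}],
        weight_decay=weight_decay)

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'
    model, l1 = Toy().to(device), nn.L1Loss()
    optimizer = make_optimizer(model)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    x, y = torch.randn(4, 8, device=device), torch.randn(4, 1, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = l1(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    print(float(loss))
```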
/train_DistillationIQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import random
4 | from dataloaders.dataloader_LQ_HQ_diff_content_HQ import DataLoader
5 | from option_train_DistillationIQA import set_args, check_args
6 | from scipy import stats
7 | import numpy as np
8 | from tools.nonlinear_convert import convert_obj_score
9 | from models.DistillationIQA import DistillationIQANet
10 |
11 | img_num = {
12 | 'kadid10k': list(range(0,10125)),
13 | 'live': list(range(0, 29)),#ref HR image
14 | 'csiq': list(range(0, 30)),#ref HR image
15 | 'tid2013': list(range(0, 25)),
16 | 'livec': list(range(0, 1162)),# no-ref image
17 | 'koniq-10k': list(range(0, 10073)),# no-ref image
18 | 'bid': list(range(0, 586)),# no-ref image
19 | }
20 | folder_path = {
21 | 'pipal':'./dataset/PIPAL',
22 | 'live': './dataset/LIVE/',
23 | 'csiq': './dataset/CSIQ/',
24 | 'tid2013': './dataset/TID2013/',
25 | 'livec': './dataset/LIVEC/',
26 | 'koniq-10k': './dataset/koniq-10k/',
27 | 'bid': './dataset/BID/',
28 | 'kadid10k':'./dataset/kadid10k/'
29 | }
30 |
31 |
32 | class DistillationIQASolver(object):
33 | def __init__(self, config):
34 | self.config = config
35 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu')
36 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt')
37 | with open(self.txt_log_path, "w+"):
38 | pass  # just create / truncate the log file
39 |
40 | #model
41 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
42 | if config.teacherNet_model_path:
43 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path))
44 | self.teacherNet = self.teacherNet.to(self.device)
45 | self.teacherNet.train(False)
46 | self.studentNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer)
47 | if config.studentNet_model_path:
48 | self.studentNet._load_state_dict(torch.load(config.studentNet_model_path))
49 | self.studentNet = self.studentNet.to(self.device)
50 | self.studentNet.train(True)
51 |
52 | #lr,opt,loss,epoch
53 | self.lr = config.lr
54 | self.lr_ratio = 1
55 | self.feature_loss_ratio = 1
56 | resnet_params = list(map(id, self.studentNet.feature_extractor.parameters()))
57 | res_params = filter(lambda p: id(p) not in resnet_params, self.studentNet.parameters())
58 | paras = [{'params': res_params, 'lr': self.lr * self.lr_ratio },
59 | {'params': self.studentNet.feature_extractor.parameters(), 'lr': self.lr}
60 | ]
61 | self.optimizer = torch.optim.Adam(paras, weight_decay=config.weight_decay)
62 | self.mse_loss = torch.nn.MSELoss()
63 | self.l1_loss = torch.nn.L1Loss()
64 | self.epochs = config.epochs
65 |
66 | #data
67 | config.train_index = img_num[config.train_dataset]
68 | random.shuffle(config.train_index)
69 | train_loader = DataLoader(config.train_dataset, folder_path[config.train_dataset], config.ref_train_dataset_path, config.train_index, config.patch_size, config.train_patch_num, batch_size=config.batch_size, istrain=True, self_patch_num=config.self_patch_num)
70 | test_loader_LIVE = DataLoader('live', folder_path['live'], config.ref_test_dataset_path, img_num['live'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
71 | test_loader_CSIQ = DataLoader('csiq', folder_path['csiq'], config.ref_test_dataset_path, img_num['csiq'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
72 | test_loader_TID = DataLoader('tid2013', folder_path['tid2013'], config.ref_test_dataset_path, img_num['tid2013'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
73 | test_loader_Koniq = DataLoader('koniq-10k', folder_path['koniq-10k'], config.ref_test_dataset_path, img_num['koniq-10k'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num)
74 |
75 | self.train_data = train_loader.get_dataloader()
76 | self.test_data_LIVE = test_loader_LIVE.get_dataloader()
77 | self.test_data_CSIQ = test_loader_CSIQ.get_dataloader()
78 | self.test_data_TID = test_loader_TID.get_dataloader()
79 | self.test_data_Koniq = test_loader_Koniq.get_dataloader()
80 |
81 |
82 | def train(self):
83 | best_srcc_LIVE, best_srcc_CSIQ, best_srcc_TID, best_srcc_Koniq = 0.0, 0.0, 0.0, 0.0
84 | best_plcc_LIVE, best_plcc_CSIQ, best_plcc_TID, best_plcc_Koniq = 0.0, 0.0, 0.0, 0.0
85 | best_krcc_LIVE, best_krcc_CSIQ, best_krcc_TID, best_krcc_Koniq = 0.0, 0.0, 0.0, 0.0
86 |
87 | print('Epoch\tTrain_Loss\tTrain_SRCC\tTest_SRCC\tTest_PLCC\tTest_KRCC')
88 |
89 | # NEW
90 | scaler = torch.cuda.amp.GradScaler()
91 |
92 | for t in range(self.epochs):
93 | epoch_loss = []
94 | pred_scores = []
95 | gt_scores = []
96 |
97 | for LQ_patches, refHQ_patches, ref_patches, label in self.train_data:
98 | LQ_patches, refHQ_patches, ref_patches, label = LQ_patches.to(self.device), refHQ_patches.to(self.device), ref_patches.to(self.device), label.to(self.device)
99 | self.optimizer.zero_grad()
100 |
101 | with torch.cuda.amp.autocast():
102 | t_encode_diff_inner_feature, t_decode_inner_feature, _ = self.teacherNet(LQ_patches, refHQ_patches)
103 | s_encode_diff_inner_feature, s_decode_inner_feature, pred = self.studentNet(LQ_patches, ref_patches)
104 |
105 | pred_scores = pred_scores + pred.cpu().tolist()
106 | gt_scores = gt_scores + label.cpu().tolist()
107 | pred_loss = self.l1_loss(pred.squeeze(), label.float().detach())
108 |
109 | feature_loss, encode_diff_loss, decode_loss = 0.0, 0.0, 0.0  # feature_loss is accumulated below
110 | for t_encode_diff_feature, s_encode_diff_feature, t_decode_feature, s_decode_feature in zip(t_encode_diff_inner_feature, s_encode_diff_inner_feature, t_decode_inner_feature, s_decode_inner_feature):
111 | #mse_loss
112 | feature_loss += self.mse_loss(t_encode_diff_feature, s_encode_diff_feature)
113 | # encode_diff_loss += self.mse_loss(t_encode_diff_feature, s_encode_diff_feature)
114 | # decode_loss += self.mse_loss(t_decode_feature, s_decode_feature)
115 | # feature_loss = encode_diff_loss + decode_loss
116 | loss = pred_loss + feature_loss*self.feature_loss_ratio
117 | epoch_loss.append(loss.item())
118 | scaler.scale(loss).backward()
119 | scaler.step(self.optimizer)
120 | scaler.update()
121 |
122 | train_srcc, _ = stats.spearmanr(pred_scores, gt_scores)
123 | test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc = self.test(self.test_data_LIVE)
124 | test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc = self.test(self.test_data_CSIQ)
125 | test_TID_srcc, test_TID_plcc, test_TID_krcc = self.test(self.test_data_TID)
126 | test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc = self.test(self.test_data_Koniq)
127 |
128 | if test_LIVE_srcc + test_LIVE_plcc + test_LIVE_krcc > best_srcc_LIVE + best_plcc_LIVE + best_krcc_LIVE:
129 | best_srcc_LIVE, best_plcc_LIVE, best_krcc_LIVE = test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc
130 | print('%d:live\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
131 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc))
132 |
133 | if test_CSIQ_srcc + test_CSIQ_plcc + test_CSIQ_krcc > best_srcc_CSIQ + best_plcc_CSIQ + best_krcc_CSIQ:
134 | best_srcc_CSIQ, best_plcc_CSIQ, best_krcc_CSIQ = test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc
135 | print('%d:csiq\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
136 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc))
137 |
138 | if test_TID_srcc + test_TID_plcc + test_TID_krcc > best_srcc_TID + best_plcc_TID + best_krcc_TID:
139 | best_srcc_TID, best_plcc_TID, best_krcc_TID = test_TID_srcc, test_TID_plcc, test_TID_krcc
140 | print('%d:tid\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
141 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_TID_srcc, test_TID_plcc, test_TID_krcc))
142 |
143 | if test_Koniq_srcc + test_Koniq_plcc + test_Koniq_krcc > best_srcc_Koniq + best_plcc_Koniq + best_krcc_Koniq:
144 | print('%d:koniq-10k\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' %
145 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc))
146 | best_srcc_Koniq, best_plcc_Koniq, best_krcc_Koniq = test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc
147 |
148 | torch.save(self.studentNet.state_dict(), os.path.join(self.config.model_checkpoint_dir, 'Distillation_inner_{}_saved_model.pth'.format(t)))
149 |
150 | self.lr = self.lr / pow(10, (t // self.config.update_opt_epoch))
151 | if t > 20:
152 | self.lr_ratio = 1
153 | resnet_params = list(map(id, self.studentNet.feature_extractor.parameters()))
154 | rest_params = filter(lambda p: id(p) not in resnet_params, self.studentNet.parameters())
155 | paras = [{'params': rest_params, 'lr': self.lr * self.lr_ratio },
156 | {'params': self.studentNet.feature_extractor.parameters(), 'lr': self.lr}
157 | ]
158 | self.optimizer = torch.optim.Adam(paras, weight_decay=self.config.weight_decay)
159 | print('Best live test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_LIVE, best_plcc_LIVE, best_krcc_LIVE))
160 | print('Best csiq test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_CSIQ, best_plcc_CSIQ, best_krcc_CSIQ))
161 | print('Best tid2013 test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_TID, best_plcc_TID, best_krcc_TID))
162 | print('Best koniq-10k test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_Koniq, best_plcc_Koniq, best_krcc_Koniq))
163 |
164 |
165 | def test(self, test_data):
166 | self.studentNet.train(False)
167 | test_pred_scores, test_gt_scores = [], []
168 | for LQ_patches, _, ref_patches, label in test_data:
169 | LQ_patches, ref_patches, label = LQ_patches.to(self.device), ref_patches.to(self.device), label.to(self.device)
170 | with torch.no_grad():
171 | _, _, pred = self.studentNet(LQ_patches, ref_patches)
172 | test_pred_scores.append(float(pred.item()))
173 | test_gt_scores = test_gt_scores + label.cpu().tolist()
174 | if self.config.use_fitting_prcc_srcc:
175 | fitting_pred_scores = convert_obj_score(test_pred_scores, test_gt_scores)
176 | test_pred_scores = np.mean(np.reshape(np.array(test_pred_scores), (-1, self.config.test_patch_num)), axis=1)
177 | test_gt_scores = np.mean(np.reshape(np.array(test_gt_scores), (-1, self.config.test_patch_num)), axis=1)
178 | test_srcc, _ = stats.spearmanr(test_pred_scores, test_gt_scores)
179 | if self.config.use_fitting_prcc_srcc:
180 | test_plcc, _ = stats.pearsonr(fitting_pred_scores, test_gt_scores)
181 | else:
182 | test_plcc, _ = stats.pearsonr(test_pred_scores, test_gt_scores)
183 | test_krcc, _ = stats.kendalltau(test_pred_scores, test_gt_scores)
184 | test_srcc, test_plcc, test_krcc = abs(test_srcc), abs(test_plcc), abs(test_krcc)
185 | self.studentNet.train(True)
186 | return test_srcc, test_plcc, test_krcc
187 |
188 | if __name__ == "__main__":
189 | config = set_args()
190 | config = check_args(config)
191 | solver = DistillationIQASolver(config=config)
192 | solver.train()
193 |
194 |
195 |
196 |
197 |
198 |
--------------------------------------------------------------------------------
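The student objective above is an L1 loss on the predicted score plus an MSE feature-matching term between corresponding teacher and student inner features, weighted by feature_loss_ratio. Below is a minimal sketch of that composition with dummy tensors; in the real trainer the feature lists come from DistillationIQANet's encoder-difference and decoder features.

```python
# Hedged sketch: score loss + feature-distillation loss, as combined in
# DistillationIQASolver.train(). Dummy tensors stand in for the inner features.
import torch

def distillation_loss(pred, label, teacher_feats, student_feats, feature_loss_ratio=1.0):
    l1, mse = torch.nn.L1Loss(), torch.nn.MSELoss()
    pred_loss = l1(pred.squeeze(), label.float())
    feature_loss = sum(mse(t, s) for t, s in zip(teacher_feats, student_feats))
    return pred_loss + feature_loss_ratio * feature_loss

if __name__ == "__main__":
    pred, label = torch.randn(4, 1), torch.randn(4)
    t_feats = [torch.randn(4, 16) for _ in range(3)]               # teacher inner features
    s_feats = [f + 0.1 * torch.randn_like(f) for f in t_feats]     # student inner features
    print(float(distillation_loss(pred, label, t_feats, s_feats)))
```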
/models/TRIQ.py:
--------------------------------------------------------------------------------
1 | import torch as torch
2 | from torch._C import device
3 | import torch.nn as nn
4 | import math
5 | from torch.nn.modules.conv import Conv2d
6 | import torch.utils.model_zoo as model_zoo
7 | from torch.nn import Dropout, Softmax, Linear, LayerNorm
8 | # from SemanticResNet50 import ResNetBackbone, Bottleneck
9 | # from models.ResNet50_MLP import ResNetBackbone
10 | # from models.MLP_return_inner_feature import MLPMixer
11 |
12 |
13 | #ResNet
14 | model_urls = {
15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
20 | }
21 |
22 | class Bottleneck(nn.Module):
23 | expansion = 4
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(Bottleneck, self).__init__()
27 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
28 | self.bn1 = nn.BatchNorm2d(planes)
29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
30 | padding=1, bias=False)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
33 | self.bn3 = nn.BatchNorm2d(planes * 4)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv3(out)
50 | out = self.bn3(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 |
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 | class ResNetBackbone(nn.Module):
61 |
62 | # def __init__(self, outc=176, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000, pretrained=True):
63 | def __init__(self, outc=2048, block=Bottleneck, layers=[3, 4, 6, 3], pretrained=True):
64 | super(ResNetBackbone, self).__init__()
65 | self.pretrained = pretrained
66 | self.inplanes = 64
67 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
68 | self.bn1 = nn.BatchNorm2d(64)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
71 | self.layer1 = self._make_layer(block, 64, layers[0])
72 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
73 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
74 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
75 |
76 | self.lda_out_channels = int(outc // 4)
77 |
78 | # local distortion aware module
79 | self.lda1_pool = nn.Sequential(
80 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
81 | nn.AvgPool2d(7, stride=7),
82 | )
83 | self.lda1_fc = nn.Linear(16 * 64, self.lda_out_channels)
84 |
85 | self.lda2_pool = nn.Sequential(
86 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
87 | nn.AvgPool2d(7, stride=7),
88 | )
89 | self.lda2_fc = nn.Linear(32 * 16, self.lda_out_channels)
90 |
91 | self.lda3_pool = nn.Sequential(
92 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
93 | nn.AvgPool2d(7, stride=7),
94 | )
95 | self.lda3_fc = nn.Linear(64 * 4, self.lda_out_channels)
96 |
97 | self.lda4_pool = nn.AvgPool2d(7, stride=7)
98 | self.lda4_fc = nn.Linear(2048, self.lda_out_channels)
99 |
100 | for m in self.modules():
101 | if isinstance(m, nn.Conv2d):
102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
103 | m.weight.data.normal_(0, math.sqrt(2. / n))
104 | elif isinstance(m, nn.BatchNorm2d):
105 | m.weight.data.fill_(1)
106 | m.bias.data.zero_()
107 | if self.pretrained:
108 | self.load_resnet50_backbone()
109 |
110 | def _make_layer(self, block, planes, blocks, stride=1):
111 | downsample = None
112 | if stride != 1 or self.inplanes != planes * block.expansion:
113 | downsample = nn.Sequential(
114 | nn.Conv2d(self.inplanes, planes * block.expansion,
115 | kernel_size=1, stride=stride, bias=False),
116 | nn.BatchNorm2d(planes * block.expansion),
117 | )
118 |
119 | layers = []
120 | layers.append(block(self.inplanes, planes, stride, downsample))
121 | self.inplanes = planes * block.expansion
122 | for i in range(1, blocks):
123 | layers.append(block(self.inplanes, planes))
124 |
125 | return nn.Sequential(*layers)
126 |
127 |
128 | def forward(self, x):
129 |
130 | x = self.conv1(x)
131 | x = self.bn1(x)
132 | x = self.relu(x)
133 | x = self.maxpool(x)
134 | x = self.layer1(x)
135 | x = self.layer2(x)
136 | x = self.layer3(x)
137 | x = self.layer4(x)
138 | return x
139 |
140 | def load_resnet50_backbone(self):
141 | """Constructs a ResNet-50 model_hyper.
142 | Args:
143 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet
144 | """
145 | save_model = model_zoo.load_url(model_urls['resnet50'])
146 | model_dict = self.state_dict()
147 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
148 | model_dict.update(state_dict)
149 | self.load_state_dict(model_dict)
150 |
151 | class Mlp(nn.Module):
152 | def __init__(self):
153 | super(Mlp, self).__init__()
154 | self.hidden_size = 32
155 | self.mlp_size = 64
156 | self.dropout_rate = 0.1
157 | self.fc1 = Linear(self.hidden_size, self.mlp_size)
158 | self.fc2 = Linear(self.mlp_size, self.hidden_size)
159 | self.act_fn = nn.ReLU()
160 | self.dropout = Dropout(self.dropout_rate)
161 |
162 | self._init_weights()
163 |
164 | def _init_weights(self):
165 | nn.init.xavier_uniform_(self.fc1.weight)
166 | nn.init.xavier_uniform_(self.fc2.weight)
167 | nn.init.normal_(self.fc1.bias, std=1e-6)
168 | nn.init.normal_(self.fc2.bias, std=1e-6)
169 |
170 | def forward(self, x):
171 | x = self.fc1(x)
172 | x = self.act_fn(x)
173 | x = self.dropout(x)
174 | x = self.fc2(x)
175 | x = self.dropout(x)
176 | return x
177 |
178 | class MHA(nn.Module):
179 | def __init__(self):
180 | super(MHA, self).__init__()
181 | self.vis = True
182 | self.num_attention_heads = 8
183 | self.hidden_size = 32
184 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
185 | self.all_head_size = self.num_attention_heads * self.attention_head_size
186 |
187 | self.query = Linear(self.hidden_size, self.all_head_size)
188 | self.key = Linear(self.hidden_size, self.all_head_size)
189 | self.value = Linear(self.hidden_size, self.all_head_size)
190 |
191 | self.out = Linear(self.hidden_size, self.hidden_size)
192 | self.attn_dropout = Dropout(0.0)
193 | self.proj_dropout = Dropout(0.0)
194 |
195 | self.softmax = Softmax(dim=-1)
196 |
197 | def transpose_for_scores(self, x):
198 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
199 | x = x.view(*new_x_shape)
200 | return x.permute(0, 2, 1, 3)
201 |
202 | def forward(self, x):
203 | mixed_key_layer = self.key(x)
204 | mixed_value_layer = self.value(x)
205 | mixed_query_layer = self.query(x)
206 |
207 | query_layer = self.transpose_for_scores(mixed_query_layer)
208 | key_layer = self.transpose_for_scores(mixed_key_layer)
209 | value_layer = self.transpose_for_scores(mixed_value_layer)
210 |
211 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
212 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
213 | attention_probs = self.softmax(attention_scores)
214 | weights = attention_probs
215 | attention_probs = self.attn_dropout(attention_probs)
216 |
217 | context_layer = torch.matmul(attention_probs, value_layer)
218 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220 | context_layer = context_layer.view(*new_context_layer_shape)
221 | attention_output = self.out(context_layer)
222 | attention_output = self.proj_dropout(attention_output)
223 | return attention_output
224 |
225 | class TransformerBlock(nn.Module):
226 | def __init__(self):
227 | super(TransformerBlock, self).__init__()
228 | self.hidden_size = 32
229 | self.ffn_norm = LayerNorm(self.hidden_size, eps=1e-6)
230 | self.ffn = Mlp()
231 | self.attention_norm = LayerNorm(self.hidden_size, eps=1e-6)
232 | self.attn = MHA()
233 |
234 | self.dropout_rate = 0.1
235 | self.dropout1 = Dropout(self.dropout_rate)
236 | self.dropout2 = Dropout(self.dropout_rate)
237 |
238 | def forward(self, x):
239 | x1 = self.attn(x)
240 | x1 = self.dropout1(x1)
241 | x1 += x
242 | x1 = self.attention_norm(x1)
243 |
244 | x2 = self.ffn(x1)
245 | x2 = self.dropout2(x2)
246 | x2 += x1
247 | x2 = self.ffn_norm(x2)
248 |
249 | return x2
250 |
251 | class RegressionFCNet(nn.Module):
252 | def __init__(self):
253 | super(RegressionFCNet, self).__init__()
254 | self.target_in_size=32
255 | self.target_fc1_size=64
256 |
257 | self.l1 = nn.Linear(self.target_in_size, self.target_fc1_size)
258 | self.relu = nn.ReLU()
259 | self.l2 = nn.Linear(self.target_fc1_size, 1)
260 |
261 |
262 | def forward(self, x):
263 | q = self.l1(x)
264 | q = self.relu(q)
265 | q = self.l2(q).squeeze()
266 | return q
267 |
268 | class TRIQ(nn.Module):
269 | def __init__(self):
270 | super(TRIQ, self).__init__()
271 | self.feature_extractor = ResNetBackbone()
272 | # for param in self.feature_extractor.parameters():
273 | # param.requires_grad = False
274 |
275 | self.conv = Conv2d(2048,32,kernel_size=1, stride=1, padding=0)
276 | self.position_embeddings = nn.Parameter(torch.zeros(1, 49+1, 32))
277 | self.quality_token = nn.Parameter(torch.zeros(1, 1, 32))
278 |
279 | self.transformer_block1 = TransformerBlock()
280 | self.transformer_block2 = TransformerBlock()
281 |
282 | self.regressor = RegressionFCNet()
283 |
284 | def cal_params(self):
285 | params = list(self.parameters())
286 | k = 0
287 | for i in params:
288 | l = 1
289 | for j in i.size():
290 | l *= j
291 | k = k + l
292 | print("Total parameters is :" + str(k))
293 |
294 | def forward(self, LQ):
295 | B = LQ.shape[0]
296 | feature_LQ = self.feature_extractor(LQ)
297 | feature_LQ = self.conv(feature_LQ)
298 |
299 | quality_tokens = self.quality_token.expand(B,-1,-1)
300 |
301 | flat_feature_LQ = feature_LQ.flatten(2).transpose(-1,-2)
302 | flat_feature_LQ = torch.cat((quality_tokens,flat_feature_LQ), dim=1) + self.position_embeddings
303 |
304 | f = self.transformer_block1(flat_feature_LQ)
305 | f = self.transformer_block2(f)
306 |
307 | y = self.regressor(f)
308 |
309 | return y
310 |
311 | '''
312 | TEST
313 | Run this module directly to check the parameter count and time a single
314 | forward pass, e.g.:
315 | ```
316 | python models/TRIQ.py
317 | ```
318 | '''
319 | if __name__ == '__main__':
320 | import time
321 | device = torch.device('cuda')
322 | LQ = torch.rand((1,3,224, 224)).to(device)
323 | net = TRIQ().to(device)
324 | net.cal_params()
325 |
326 | torch.cuda.synchronize()
327 | start = time.time()
328 | a = net(LQ)
329 | torch.cuda.synchronize()
330 | end = time.time()
331 | print("run time is :" + str(end-start))
332 |
--------------------------------------------------------------------------------
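TRIQ above reduces the 2048-channel ResNet output to 32 channels, flattens the 7x7 map into 49 tokens, prepends a learnable quality token, adds position embeddings, and runs two Transformer blocks. The following is a short shape walk-through of the scaled dot-product attention with those sizes (hidden size 32, 8 heads, 50 tokens), using plain tensors rather than the modules defined above.

```python
# Hedged sketch of the attention shape flow in MHA (hidden_size=32, 8 heads,
# 49 spatial tokens + 1 quality token). Illustrative only.
import math
import torch

B, N, H, D = 2, 50, 8, 32                  # batch, tokens, heads, hidden size
d_head = D // H                            # 4 dims per head

x = torch.randn(B, N, D)                   # [quality token; 49 flattened features] + pos. emb.
q = k = v = x.view(B, N, H, d_head).permute(0, 2, 1, 3)            # B x H x N x d_head

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_head)  # B x H x N x N
probs = scores.softmax(dim=-1)
ctx = torch.matmul(probs, v)                                        # B x H x N x d_head
ctx = ctx.permute(0, 2, 1, 3).reshape(B, N, D)                      # back to B x N x D
print(ctx.shape)                           # torch.Size([2, 50, 32])
```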
/models/HyperIQA.py:
--------------------------------------------------------------------------------
1 | import torch as torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from torch.nn import init
5 | import math
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 |
17 | class HyperNet(nn.Module):
18 | """
19 | Hyper network for learning perceptual rules.
20 |
21 | Args:
22 | lda_out_channels: local distortion aware module output size.
23 | hyper_in_channels: input feature channels for hyper network.
24 | target_in_size: input vector size for target network.
25 | target_fc(i)_size: fully connection layer size of target network.
26 | feature_size: input feature map width/height for hyper network.
27 |
28 | Note:
29 | For size match, input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'.
30 |
31 | """
32 | def __init__(self, lda_out_channels=16, hyper_in_channels=112, target_in_size=224, target_fc1_size=112, target_fc2_size=56, target_fc3_size=28, target_fc4_size=14, feature_size=7):
33 | super(HyperNet, self).__init__()
34 |
35 | self.hyperInChn = hyper_in_channels
36 | self.target_in_size = target_in_size
37 | self.f1 = target_fc1_size
38 | self.f2 = target_fc2_size
39 | self.f3 = target_fc3_size
40 | self.f4 = target_fc4_size
41 | self.feature_size = feature_size
42 |
43 | self.res = resnet50_backbone(lda_out_channels, target_in_size, pretrained=True)
44 |
45 | self.pool = nn.AdaptiveAvgPool2d((1, 1))
46 |
47 | # Conv layers for resnet output features
48 | self.conv1 = nn.Sequential(
49 | nn.Conv2d(2048, 1024, 1, padding=(0, 0)),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(1024, 512, 1, padding=(0, 0)),
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)),
54 | nn.ReLU(inplace=True)
55 | )
56 |
57 | # Hyper network part, conv for generating target fc weights, fc for generating target fc biases
58 | self.fc1w_conv = nn.Conv2d(self.hyperInChn, int(self.target_in_size * self.f1 / feature_size ** 2), 3, padding=(1, 1))
59 | self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1)
60 |
61 | self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size ** 2), 3, padding=(1, 1))
62 | self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2)
63 |
64 | self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size ** 2), 3, padding=(1, 1))
65 | self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3)
66 |
67 | self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size ** 2), 3, padding=(1, 1))
68 | self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4)
69 |
70 | self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4)
71 | self.fc5b_fc = nn.Linear(self.hyperInChn, 1)
72 |
73 | # initialize
74 | for i, m_name in enumerate(self._modules):
75 | if i > 2:
76 | nn.init.kaiming_normal_(self._modules[m_name].weight.data)
77 |
78 | def forward(self, img):
79 | feature_size = self.feature_size
80 |
81 | res_out = self.res(img)
82 |
83 | # input vector for target net
84 | target_in_vec = res_out['target_in_vec'].view(-1, self.target_in_size, 1, 1)
85 |
86 | # input features for hyper net
87 | hyper_in_feat = self.conv1(res_out['hyper_in_feat']).view(-1, self.hyperInChn, feature_size, feature_size)
88 |
89 | # generating target net weights & biases
90 | target_fc1w = self.fc1w_conv(hyper_in_feat).view(-1, self.f1, self.target_in_size, 1, 1)
91 | target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f1)
92 |
93 | target_fc2w = self.fc2w_conv(hyper_in_feat).view(-1, self.f2, self.f1, 1, 1)
94 | target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f2)
95 |
96 | target_fc3w = self.fc3w_conv(hyper_in_feat).view(-1, self.f3, self.f2, 1, 1)
97 | target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f3)
98 |
99 | target_fc4w = self.fc4w_conv(hyper_in_feat).view(-1, self.f4, self.f3, 1, 1)
100 | target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f4)
101 |
102 | target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1, self.f4, 1, 1)
103 | target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1)
104 |
105 | out = {}
106 | out['target_in_vec'] = target_in_vec
107 | out['target_fc1w'] = target_fc1w
108 | out['target_fc1b'] = target_fc1b
109 | out['target_fc2w'] = target_fc2w
110 | out['target_fc2b'] = target_fc2b
111 | out['target_fc3w'] = target_fc3w
112 | out['target_fc3b'] = target_fc3b
113 | out['target_fc4w'] = target_fc4w
114 | out['target_fc4b'] = target_fc4b
115 | out['target_fc5w'] = target_fc5w
116 | out['target_fc5b'] = target_fc5b
117 |
118 | return out
119 |
120 |
121 | class TargetNet(nn.Module):
122 | """
123 | Target network for quality prediction.
124 | """
125 | def __init__(self, paras):
126 | super(TargetNet, self).__init__()
127 | self.l1 = nn.Sequential(
128 | TargetFC(paras['target_fc1w'], paras['target_fc1b']),
129 | nn.Sigmoid(),
130 | )
131 | self.l2 = nn.Sequential(
132 | TargetFC(paras['target_fc2w'], paras['target_fc2b']),
133 | nn.Sigmoid(),
134 | )
135 |
136 | self.l3 = nn.Sequential(
137 | TargetFC(paras['target_fc3w'], paras['target_fc3b']),
138 | nn.Sigmoid(),
139 | )
140 |
141 | self.l4 = nn.Sequential(
142 | TargetFC(paras['target_fc4w'], paras['target_fc4b']),
143 | nn.Sigmoid(),
144 | TargetFC(paras['target_fc5w'], paras['target_fc5b']),
145 | )
146 |
147 | def forward(self, x):
148 | q = self.l1(x)
149 | # q = F.dropout(q)
150 | q = self.l2(q)
151 | q = self.l3(q)
152 | q = self.l4(q).squeeze()
153 | return q
154 |
155 |
156 | class TargetFC(nn.Module):
157 | """
158 | Fully connection operations for target net
159 |
160 | Note:
161 | Weights & biases are different for different images in a batch,
162 | thus here we use group convolution for calculating images in a batch with individual weights & biases.
163 | """
164 | def __init__(self, weight, bias):
165 | super(TargetFC, self).__init__()
166 | self.weight = weight
167 | self.bias = bias
168 |
169 | def forward(self, input_):
170 |
171 | input_re = input_.view(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3])
172 | weight_re = self.weight.view(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], self.weight.shape[3], self.weight.shape[4])
173 | bias_re = self.bias.view(self.bias.shape[0] * self.bias.shape[1])
174 | out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0])
175 |
176 | return out.view(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3])
177 |
178 |
179 | class Bottleneck(nn.Module):
180 | expansion = 4
181 |
182 | def __init__(self, inplanes, planes, stride=1, downsample=None):
183 | super(Bottleneck, self).__init__()
184 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
185 | self.bn1 = nn.BatchNorm2d(planes)
186 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
187 | padding=1, bias=False)
188 | self.bn2 = nn.BatchNorm2d(planes)
189 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
190 | self.bn3 = nn.BatchNorm2d(planes * 4)
191 | self.relu = nn.ReLU(inplace=True)
192 | self.downsample = downsample
193 | self.stride = stride
194 |
195 | def forward(self, x):
196 | residual = x
197 |
198 | out = self.conv1(x)
199 | out = self.bn1(out)
200 | out = self.relu(out)
201 |
202 | out = self.conv2(out)
203 | out = self.bn2(out)
204 | out = self.relu(out)
205 |
206 | out = self.conv3(out)
207 | out = self.bn3(out)
208 |
209 | if self.downsample is not None:
210 | residual = self.downsample(x)
211 |
212 | out += residual
213 | out = self.relu(out)
214 |
215 | return out
216 |
217 |
218 | class ResNetBackbone(nn.Module):
219 |
220 | def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000):
221 | super(ResNetBackbone, self).__init__()
222 | self.inplanes = 64
223 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
224 | self.bn1 = nn.BatchNorm2d(64)
225 | self.relu = nn.ReLU(inplace=True)
226 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
227 | self.layer1 = self._make_layer(block, 64, layers[0])
228 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
229 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
230 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
231 |
232 | # local distortion aware module
233 | self.lda1_pool = nn.Sequential(
234 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
235 | nn.AvgPool2d(7, stride=7),
236 | )
237 | self.lda1_fc = nn.Linear(16 * 64, lda_out_channels)
238 |
239 | self.lda2_pool = nn.Sequential(
240 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
241 | nn.AvgPool2d(7, stride=7),
242 | )
243 | self.lda2_fc = nn.Linear(32 * 16, lda_out_channels)
244 |
245 | self.lda3_pool = nn.Sequential(
246 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
247 | nn.AvgPool2d(7, stride=7),
248 | )
249 | self.lda3_fc = nn.Linear(64 * 4, lda_out_channels)
250 |
251 | self.lda4_pool = nn.AvgPool2d(7, stride=7)
252 | self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3)
253 |
254 | for m in self.modules():
255 | if isinstance(m, nn.Conv2d):
256 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
257 | m.weight.data.normal_(0, math.sqrt(2. / n))
258 | elif isinstance(m, nn.BatchNorm2d):
259 | m.weight.data.fill_(1)
260 | m.bias.data.zero_()
261 |
262 | # initialize
263 | nn.init.kaiming_normal_(self.lda1_pool._modules['0'].weight.data)
264 | nn.init.kaiming_normal_(self.lda2_pool._modules['0'].weight.data)
265 | nn.init.kaiming_normal_(self.lda3_pool._modules['0'].weight.data)
266 | nn.init.kaiming_normal_(self.lda1_fc.weight.data)
267 | nn.init.kaiming_normal_(self.lda2_fc.weight.data)
268 | nn.init.kaiming_normal_(self.lda3_fc.weight.data)
269 | nn.init.kaiming_normal_(self.lda4_fc.weight.data)
270 |
271 | def _make_layer(self, block, planes, blocks, stride=1):
272 | downsample = None
273 | if stride != 1 or self.inplanes != planes * block.expansion:
274 | downsample = nn.Sequential(
275 | nn.Conv2d(self.inplanes, planes * block.expansion,
276 | kernel_size=1, stride=stride, bias=False),
277 | nn.BatchNorm2d(planes * block.expansion),
278 | )
279 |
280 | layers = []
281 | layers.append(block(self.inplanes, planes, stride, downsample))
282 | self.inplanes = planes * block.expansion
283 | for i in range(1, blocks):
284 | layers.append(block(self.inplanes, planes))
285 |
286 | return nn.Sequential(*layers)
287 |
288 | def forward(self, x):
289 | x = self.conv1(x)
290 | x = self.bn1(x)
291 | x = self.relu(x)
292 | x = self.maxpool(x)
293 | x = self.layer1(x)
294 |
295 | # the same effect as lda operation in the paper, but save much more memory
296 | lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1))
297 | x = self.layer2(x)
298 | lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1))
299 | x = self.layer3(x)
300 | lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1))
301 | x = self.layer4(x)
302 | lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1))
303 |
304 | vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1)
305 |
306 | out = {}
307 | out['hyper_in_feat'] = x
308 | out['target_in_vec'] = vec
309 |
310 | return out
311 |
312 |
313 | def resnet50_backbone(lda_out_channels, in_chn, pretrained=False, **kwargs):
314 | """Constructs a ResNet-50 model_hyper.
315 |
316 | Args:
317 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet
318 | """
319 | model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs)
320 | if pretrained:
321 | save_model = model_zoo.load_url(model_urls['resnet50'])
322 | model_dict = model.state_dict()
323 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
324 | model_dict.update(state_dict)
325 | model.load_state_dict(model_dict)
326 | else:
327 | model.apply(weights_init_xavier)
328 | return model
329 |
330 |
331 | def weights_init_xavier(m):
332 | classname = m.__class__.__name__
333 | # print(classname)
334 | # if isinstance(m, nn.Conv2d):
335 | if classname.find('Conv') != -1:
336 | init.kaiming_normal_(m.weight.data)
337 | elif classname.find('Linear') != -1:
338 | init.kaiming_normal_(m.weight.data)
339 | elif classname.find('BatchNorm2d') != -1:
340 | init.uniform_(m.weight.data, 1.0, 0.02)
341 | init.constant_(m.bias.data, 0.0)
342 |
--------------------------------------------------------------------------------
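TargetFC above implements per-image fully-connected layers by folding the batch dimension into channels and running a grouped 1x1 convolution, so every image in a batch is evaluated with its own HyperNet-generated weights. Below is a standalone sketch of that grouped-convolution trick, with shapes chosen for illustration and a brute-force check against an explicit per-image matmul.

```python
# Hedged sketch: per-sample "FC" via grouped 1x1 convolution, the trick used by
# HyperIQA's TargetFC. Shapes are illustrative.
import torch
import torch.nn.functional as F

B, C_in, C_out = 3, 224, 112
x = torch.randn(B, C_in, 1, 1)               # per-image input vectors as 1x1 feature maps
weight = torch.randn(B, C_out, C_in, 1, 1)   # one weight tensor per image (from the hyper net)
bias = torch.randn(B, C_out)

x_re = x.view(1, B * C_in, 1, 1)                         # fold the batch into channels
w_re = weight.view(B * C_out, C_in, 1, 1)
b_re = bias.view(B * C_out)
out = F.conv2d(x_re, w_re, bias=b_re, groups=B)          # group b only sees image b's channels
out = out.view(B, C_out, 1, 1)
print(out.shape)                                         # torch.Size([3, 112, 1, 1])

# Brute-force check: same result as an explicit per-image matrix multiply.
ref = torch.stack([weight[b, :, :, 0, 0] @ x[b, :, 0, 0] + bias[b] for b in range(B)])
print(torch.allclose(out[..., 0, 0], ref, atol=1e-5))    # True
```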
/folders/folders_LQ_HQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import scipy.io
8 | import numpy as np
9 | import csv
10 | import random
11 | from openpyxl import load_workbook
12 | import cv2
13 | from torchvision import transforms
14 |
15 | class Kadid10kFolder(data.Dataset):
16 |
17 | def __init__(self, root, index, patch_num, patch_size=224, self_patch_num=1):
18 | self.patch_size = patch_size
19 | self.self_patch_num = self_patch_num
20 |
21 | imgname = []
22 | refimgname = []
23 | mos_all = []
24 | csv_file = os.path.join(root, 'dmos.csv')
25 | with open(csv_file) as f:
26 | reader = csv.DictReader(f)
27 | for row in reader:
28 | imgname.append(row['dist_img'])
29 | refimgname.append(row['ref_img'])
30 | mos = np.array(float(row['dmos'])).astype(np.float32)
31 | mos_all.append(mos)
32 |
33 | sample = []
34 | for i, item in enumerate(index):
35 | for aug in range(patch_num):
36 | sample.append((os.path.join(root, 'images', imgname[item]),os.path.join(root, 'images', refimgname[item]), mos_all[item]))
37 |
38 | self.samples = sample
39 | self.transform = torchvision.transforms.Compose([
40 | torchvision.transforms.ToTensor(),
41 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
42 | std=(0.229, 0.224, 0.225))
43 | ])
44 |
45 | def __getitem__(self, index):
46 | """
47 | Args:
48 | index (int): Index
49 | Returns:
50 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ.
51 | """
52 | LQ_path, HQ_path, target = self.samples[index]
53 | LQ = pil_loader(LQ_path)
54 | HQ = pil_loader(HQ_path)
55 | LQ_patches, HQ_patches = [], []
56 | for _ in range(self.self_patch_num):
57 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size)
58 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch)
59 |
60 | LQ_patch = self.transform(LQ_patch)
61 | HQ_patch = self.transform(HQ_patch)
62 |
63 | LQ_patches.append(LQ_patch.unsqueeze(0))
64 | HQ_patches.append(HQ_patch.unsqueeze(0))
65 | #[self_patch_num, 3, patch_size, patch_size]
66 | LQ_patches = torch.cat(LQ_patches, 0)
67 | HQ_patches = torch.cat(HQ_patches, 0)
68 |
69 | return LQ_patches, HQ_patches, target
70 |
71 | def __len__(self):
72 | length = len(self.samples)
73 | return length
74 |
75 | class LIVEFolder(data.Dataset):
76 |
77 | def __init__(self, root, index, patch_num, patch_size=224, self_patch_num=1):
78 | self.patch_size =patch_size
79 | self.self_patch_num = self_patch_num
80 |
81 | refpath = os.path.join(root, 'refimgs')
82 | refname = getFileName(refpath, '.bmp')
83 |
84 | jp2kroot = os.path.join(root, 'jp2k')
85 | jp2kname = self.getDistortionTypeFileName(jp2kroot, 227)
86 |
87 | jpegroot = os.path.join(root, 'jpeg')
88 | jpegname = self.getDistortionTypeFileName(jpegroot, 233)
89 |
90 | wnroot = os.path.join(root, 'wn')
91 | wnname = self.getDistortionTypeFileName(wnroot, 174)
92 |
93 | gblurroot = os.path.join(root, 'gblur')
94 | gblurname = self.getDistortionTypeFileName(gblurroot, 174)
95 |
96 | fastfadingroot = os.path.join(root, 'fastfading')
97 | fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174)
98 |
99 | imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname
100 |
101 | dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat'))
102 | labels = dmos['dmos_new'].astype(np.float32)
103 |
104 | orgs = dmos['orgs']
105 | refnames_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat'))
106 | refnames_all = refnames_all['refnames_all']
107 |
108 | sample = []
109 | for i in range(0, len(index)):
110 | train_sel = (refname[index[i]] == refnames_all)
111 | train_sel = train_sel * ~orgs.astype(np.bool_)
112 | train_sel = np.where(train_sel == True)
113 | train_sel = train_sel[1].tolist()
114 | for j, item in enumerate(train_sel):
115 | for aug in range(patch_num):
116 | LQ_path = imgpath[item]
117 | HQ_path = os.path.join(root, 'refimgs', refnames_all[0][item][0])
118 | label = labels[0][item]
119 | sample.append((LQ_path, HQ_path, label))
120 |
121 | self.samples = sample
122 | self.transform = torchvision.transforms.Compose([
123 | torchvision.transforms.ToTensor(),
124 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
125 | std=(0.229, 0.224, 0.225))
126 | ])
127 |
128 | def __getitem__(self, index):
129 | """
130 | Args:
131 | index (int): Index
132 | Returns:
133 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ.
134 | """
135 | LQ_path, HQ_path, target = self.samples[index]
136 | LQ = pil_loader(LQ_path)
137 | HQ = pil_loader(HQ_path)
138 | LQ_patches, HQ_patches = [], []
139 | for _ in range(self.self_patch_num):
140 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size)
141 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch)
142 |
143 | LQ_patch = self.transform(LQ_patch)
144 | HQ_patch = self.transform(HQ_patch)
145 |
146 | LQ_patches.append(LQ_patch.unsqueeze(0))
147 | HQ_patches.append(HQ_patch.unsqueeze(0))
148 | LQ_patches = torch.cat(LQ_patches, 0)
149 | HQ_patches = torch.cat(HQ_patches, 0)
150 |
151 | return LQ_patches, HQ_patches, target
152 |
153 | def __len__(self):
154 | length = len(self.samples)
155 | return length
156 |
157 | def getDistortionTypeFileName(self, path, num):
158 | filename = []
159 | index = 1
160 | for i in range(0, num):
161 | name = '{:0>3d}{}'.format(index, '.bmp')
162 | filename.append(os.path.join(path, name))
163 | index = index + 1
164 | return filename
165 |
166 | class CSIQFolder(data.Dataset):
167 |
168 | def __init__(self, root, index, patch_num, patch_size =224, self_patch_num=1):
169 | self.patch_size =patch_size
170 | self.self_patch_num = self_patch_num
171 |
172 | refpath = os.path.join(root, 'src_imgs')
173 | refname = getFileName(refpath,'.png')
174 | txtpath = os.path.join(root, 'csiq_label.txt')
175 | fh = open(txtpath, 'r')
176 | imgnames = []
177 | target = []
178 | refnames_all = []
179 | for line in fh:
180 | line = line.split('\n')
181 | words = line[0].split()
182 | imgnames.append((words[0]))
183 | target.append(words[1])
184 | ref_temp = words[0].split(".")
185 | refnames_all.append(ref_temp[0] + '.' + ref_temp[-1])
186 |
187 | labels = np.array(target).astype(np.float32)
188 | refnames_all = np.array(refnames_all)
189 |
190 | sample = []
191 |
192 | for i, item in enumerate(index):
193 | train_sel = (refname[index[i]] == refnames_all)
194 | train_sel = np.where(train_sel == True)
195 | train_sel = train_sel[0].tolist()
196 | for j, item in enumerate(train_sel):
197 | for aug in range(patch_num):
198 | LQ_path = os.path.join(root, 'dst_imgs_all', imgnames[item])
199 | HQ_path = os.path.join(root, 'src_imgs', refnames_all[item])
200 | label = labels[item]
201 | sample.append((LQ_path, HQ_path, label))
202 | self.samples = sample
203 | self.transform = torchvision.transforms.Compose([
204 | torchvision.transforms.ToTensor(),
205 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
206 | std=(0.229, 0.224, 0.225))
207 | ])
208 |
209 | def __getitem__(self, index):
210 | """
211 | Args:
212 | index (int): Index
213 | Returns:
214 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ.
215 | """
216 | LQ_path, HQ_path, target = self.samples[index]
217 | LQ = pil_loader(LQ_path)
218 | HQ = pil_loader(HQ_path)
219 | LQ_patches, HQ_patches = [], []
220 | for _ in range(self.self_patch_num):
221 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size)
222 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch)
223 |
224 | LQ_patch = self.transform(LQ_patch)
225 | HQ_patch = self.transform(HQ_patch)
226 |
227 | LQ_patches.append(LQ_patch.unsqueeze(0))
228 | HQ_patches.append(HQ_patch.unsqueeze(0))
229 | LQ_patches = torch.cat(LQ_patches, 0)
230 | HQ_patches = torch.cat(HQ_patches, 0)
231 |
232 | return LQ_patches, HQ_patches, target
233 |
234 | def __len__(self):
235 | length = len(self.samples)
236 | return length
237 |
238 | class TID2013Folder(data.Dataset):
239 |
240 | def __init__(self, root, index, patch_num, patch_size=224, self_patch_num=1):
241 | self.patch_size =patch_size
242 | self.self_patch_num = self_patch_num
243 |
244 | refpath = os.path.join(root, 'reference_images')
245 | refname = self._getTIDFileName(refpath,'.bmp.BMP')
246 | txtpath = os.path.join(root, 'mos_with_names.txt')
247 | fh = open(txtpath, 'r')
248 | imgnames = []
249 | target = []
250 | refnames_all = []
251 | for line in fh:
252 | line = line.split('\n')
253 | words = line[0].split()
254 | imgnames.append((words[1]))
255 | target.append(words[0])
256 | ref_temp = words[1].split("_")
257 | refnames_all.append(ref_temp[0][1:])
258 | labels = np.array(target).astype(np.float32)
259 | refnames_all = np.array(refnames_all)
260 |
261 | sample = []
262 | for i, item in enumerate(index):
263 | train_sel = (refname[index[i]] == refnames_all)
264 | train_sel = np.where(train_sel == True)
265 | train_sel = train_sel[0].tolist()
266 | for j, item in enumerate(train_sel):
267 | for aug in range(patch_num):
268 | LQ_path = os.path.join(root, 'distorted_images', imgnames[item])
269 | HQ_name = 'I' + imgnames[item].split("_")[0][1:] + '.BMP'
270 | HQ_path = os.path.join(refpath, HQ_name)
271 | label = labels[item]
272 | sample.append((LQ_path, HQ_path, label))
273 | self.samples = sample
274 | self.transform = torchvision.transforms.Compose([
275 | torchvision.transforms.ToTensor(),
276 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
277 | std=(0.229, 0.224, 0.225))
278 | ])
279 |
280 | def _getTIDFileName(self, path, suffix):
281 | filename = []
282 | f_list = os.listdir(path)
283 | for i in f_list:
284 | if suffix.find(os.path.splitext(i)[1]) != -1:
285 | filename.append(i[1:3])
286 | return filename
287 |
288 | def __getitem__(self, index):
289 | """
290 | Args:
291 | index (int): Index
292 | Returns:
293 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ.
294 | """
295 | LQ_path, HQ_path, target = self.samples[index]
296 | LQ = pil_loader(LQ_path)
297 | HQ = pil_loader(HQ_path)
298 | LQ_patches, HQ_patches = [], []
299 | for _ in range(self.self_patch_num):
300 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size)
301 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch)
302 |
303 | LQ_patch = self.transform(LQ_patch)
304 | HQ_patch = self.transform(HQ_patch)
305 |
306 | LQ_patches.append(LQ_patch.unsqueeze(0))
307 | HQ_patches.append(HQ_patch.unsqueeze(0))
308 | LQ_patches = torch.cat(LQ_patches, 0)
309 | HQ_patches = torch.cat(HQ_patches, 0)
310 |
311 | return LQ_patches, HQ_patches, target
312 |
313 | def __len__(self):
314 | length = len(self.samples)
315 | return length
316 |
317 |
318 | def getFileName(path, suffix):
319 | filename = []
320 | f_list = os.listdir(path)
321 | for i in f_list:
322 | if os.path.splitext(i)[1] == suffix:
323 | filename.append(i)
324 | return filename
325 |
326 |
327 | def getPairRandomPatch(img1, img2, crop_size=512):
328 | (iw,ih) = img1.size
329 | # print(ih,iw)
330 |
331 | ip = int(crop_size)
332 |
333 | ix = random.randrange(0, iw - ip + 1)
334 | iy = random.randrange(0, ih - ip + 1)
335 |
336 |
337 | img1_patch = img1.crop((ix, iy, ix+ip, iy+ip))# (left, upper, right, lower)
338 | img2_patch = img2.crop((ix, iy, ix+ip, iy+ip))# (left, upper, right, lower)
339 |
340 | return img1_patch, img2_patch
341 |
342 | def getPairAugment(img1, img2, hflip=True, vflip=True, rot=True):
343 | hflip = hflip and random.random() < 0.5
344 | vflip = vflip and random.random() < 0.5
345 | rot180 = rot and random.random() < 0.5
346 |
347 | if hflip:
348 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
349 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
350 | if vflip:
351 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM)
352 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM)
353 | if rot180:
354 | img1 = img1.transpose(Image.ROTATE_180)
355 | img2 = img2.transpose(Image.ROTATE_180)
356 |
357 | return img1, img2
358 |
359 |
360 | def getSelfPatch(img, patch_size, patch_num, is_random=True):
361 | (iw,ih) = img.size
362 | patches = []
363 | for i in range(patch_num):
364 | if is_random:
365 | ix = random.randrange(0, iw - patch_size + 1)
366 | iy = random.randrange(0, ih - patch_size + 1)
367 | else: ix, iy = (iw - patch_size + 1)//2, (ih - patch_size + 1)//2
368 |
369 | # patch = img[iy:iy + lr_size, ix:ix + lr_size, :]# (rows: top to bottom, cols: left to right)
370 | patch = img.crop((ix, iy, ix+patch_size, iy+patch_size))# (left, upper, right, lower)
371 | patches.append(patch)
372 |
373 | return patches
374 |
375 |
376 | def pil_loader(path):
377 | with open(path, 'rb') as f:
378 | img = Image.open(f)
379 | return img.convert('RGB')
380 |
--------------------------------------------------------------------------------
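Each folder class above returns (LQ_patches, HQ_patches, target), where the two patch tensors stack self_patch_num identically-cropped and identically-augmented pairs into shape [self_patch_num, 3, patch_size, patch_size]. A hedged usage sketch follows; the dataset root and the hyper-parameter values are placeholders, not prescribed by the repo.

```python
# Hedged usage sketch for Kadid10kFolder as defined above. Paths/values are placeholders.
import torch
from folders.folders_LQ_HQ import Kadid10kFolder

dataset = Kadid10kFolder(root='./dataset/kadid10k/', index=list(range(0, 10125)),
                         patch_num=1, patch_size=224, self_patch_num=10)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for LQ_patches, HQ_patches, target in loader:
    # LQ/HQ: [batch, self_patch_num, 3, 224, 224]; target: [batch]
    print(LQ_patches.shape, HQ_patches.shape, target.shape)
    break
```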
/models/DistillationIQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from torch.nn import init
5 | import math
6 | import torch.utils.model_zoo as model_zoo
7 | # from SemanticResNet50 import ResNetBackbone, Bottleneck
8 | # from models.ResNet50_MLP import ResNetBackbone
9 | # from models.MLP_return_inner_feature import MLPMixer
10 |
11 | from functools import partial
12 | from einops.layers.torch import Rearrange, Reduce
13 |
14 | #ResNet
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | }
22 |
23 | class PatchEmbed(nn.Module):
24 | """ Feature to Patch Embedding
25 | input : N C H W
26 | output: N num_patch P^2*C
27 | """
28 |
29 | def __init__(self, patch_size=7, in_channels=2048):
30 | super().__init__()
31 | self.patch_size = patch_size
32 | self.dim = self.patch_size ** 2 * in_channels
33 |
34 | def forward(self, x):
35 | N, C, H, W = ori_shape = x.shape
36 |
37 | p = self.patch_size
38 | num_patches = (H // p) * (W // p)
39 |
40 | fold_out = torch.nn.functional.unfold(x, (p, p), stride=p) # B, num_dim, num_patch
41 | out = fold_out.permute(0, 2, 1 )# B, num_patch, num_dim
42 |
43 | return out
44 |
45 | class Bottleneck(nn.Module):
46 | expansion = 4
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(Bottleneck, self).__init__()
50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
53 | padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(planes * 4)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | residual = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv3(out)
73 | out = self.bn3(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 | class ResNetBackbone(nn.Module):
84 |
85 | # def __init__(self, outc=176, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000, pretrained=True):
86 | def __init__(self, outc=2048, block=Bottleneck, layers=[3, 4, 6, 3], pretrained=True):
87 | super(ResNetBackbone, self).__init__()
88 | self.pretrained = pretrained
89 | self.inplanes = 64
90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
91 | self.bn1 = nn.BatchNorm2d(64)
92 | self.relu = nn.ReLU(inplace=True)
93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
94 | self.layer1 = self._make_layer(block, 64, layers[0])
95 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
96 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
98 |
99 | self.lda_out_channels = int(outc // 4)
100 |
101 | # local distortion aware module
102 | self.lda1_pool = nn.Sequential(
103 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
104 | nn.AvgPool2d(7, stride=7),
105 | )
106 | self.lda1_fc = nn.Linear(16 * 64, self.lda_out_channels)
107 |
108 | self.lda2_pool = nn.Sequential(
109 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
110 | nn.AvgPool2d(7, stride=7),
111 | )
112 | self.lda2_fc = nn.Linear(32 * 16, self.lda_out_channels)
113 |
114 | self.lda3_pool = nn.Sequential(
115 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
116 | nn.AvgPool2d(7, stride=7),
117 | )
118 | self.lda3_fc = nn.Linear(64 * 4, self.lda_out_channels)
119 |
120 | self.lda4_pool = nn.AvgPool2d(7, stride=7)
121 | self.lda4_fc = nn.Linear(2048, self.lda_out_channels)
122 |
123 | for m in self.modules():
124 | if isinstance(m, nn.Conv2d):
125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
126 | m.weight.data.normal_(0, math.sqrt(2. / n))
127 | elif isinstance(m, nn.BatchNorm2d):
128 | m.weight.data.fill_(1)
129 | m.bias.data.zero_()
130 | if self.pretrained:
131 | self.load_resnet50_backbone()
132 |
133 | def _make_layer(self, block, planes, blocks, stride=1):
134 | downsample = None
135 | if stride != 1 or self.inplanes != planes * block.expansion:
136 | downsample = nn.Sequential(
137 | nn.Conv2d(self.inplanes, planes * block.expansion,
138 | kernel_size=1, stride=stride, bias=False),
139 | nn.BatchNorm2d(planes * block.expansion),
140 | )
141 |
142 | layers = []
143 | layers.append(block(self.inplanes, planes, stride, downsample))
144 | self.inplanes = planes * block.expansion
145 | for i in range(1, blocks):
146 | layers.append(block(self.inplanes, planes))
147 |
148 | return nn.Sequential(*layers)
149 |
150 | def cal_params(self):
151 | params = list(self.parameters())
152 | k = 0
153 | for i in params:
154 | l = 1
155 | for j in i.size():
156 | l *= j
157 | k = k + l
158 | print("Total parameters is :" + str(k))
159 |
160 | def forward(self, x):
161 |
162 | x = self.conv1(x)
163 | x = self.bn1(x)
164 | x = self.relu(x)
165 | x = self.maxpool(x)
166 | x = self.layer1(x)
167 | # lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1))
168 | lda_1 = x #[b, 256, 56, 56]
169 | x = self.layer2(x)
170 | # lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1))
171 | lda_2 = x #[b, 512, 28, 28]
172 | x = self.layer3(x)
173 | # lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1))
174 | lda_3 = x #[b, 1024, 14, 14]
175 | x = self.layer4(x)
176 | # lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1))
177 | lda_4 = x #[b, 2048, 7, 7]
178 | # return x #[b, 2048, 7, 7]
179 | return [lda_1, lda_2, lda_3, lda_4]
180 |
181 | def load_resnet50_backbone(self):
182 | """Constructs a ResNet-50 model_hyper.
183 | Args:
184 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet
185 | """
186 | save_model = model_zoo.load_url(model_urls['resnet50'])
187 | model_dict = self.state_dict()
188 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
189 | model_dict.update(state_dict)
190 | self.load_state_dict(model_dict)
191 |
192 | def weights_init_xavier(m):
193 | classname = m.__class__.__name__
194 | # print(classname)
195 | # if isinstance(m, nn.Conv2d):
196 | if classname.find('Conv') != -1:
197 | init.kaiming_normal_(m.weight.data)
198 | elif classname.find('Linear') != -1:
199 | init.kaiming_normal_(m.weight.data)
200 | elif classname.find('BatchNorm2d') != -1:
201 |         init.normal_(m.weight.data, 1.0, 0.02)  # uniform_(tensor, 1.0, 0.02) would fail since low > high
202 | init.constant_(m.bias.data, 0.0)
203 |
204 | #MLP
205 | class PreNormResidual(nn.Module):
206 | def __init__(self, dim, fn):
207 | super().__init__()
208 | self.fn = fn
209 | self.norm = nn.LayerNorm(dim)
210 |
211 | def forward(self, x):
212 | return self.fn(self.norm(x)) + x
213 |
214 | class MLPMixer(nn.Module):
215 | def __init__(self, image_size, channels, patch_size, dim, depth, expansion_factor = 4, dropout = 0.):
216 | super().__init__()
217 | assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
218 | self.num_patches = (image_size // patch_size) ** 2
219 | self.chan_first, self.chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
220 |
221 | self.mlp = nn.Sequential(
222 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
223 | nn.Linear((patch_size ** 2) * channels, dim),
224 | *[nn.Sequential(
225 | PreNormResidual(dim, self.FeedForward(self.num_patches, expansion_factor, dropout, self.chan_first)),
226 | PreNormResidual(dim, self.FeedForward(dim, expansion_factor, dropout, self.chan_last))
227 | ) for _ in range(depth)],
228 | nn.LayerNorm(dim),
229 | Reduce('b n c -> b c', 'mean'),
230 | # nn.Linear(dim, num_classes)
231 | )
232 | # print(self.mlp)
233 |
234 | def FeedForward(self, dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
235 | return nn.Sequential(
236 | dense(dim, dim * expansion_factor),
237 | nn.GELU(),
238 | nn.Dropout(dropout),
239 | dense(dim * expansion_factor, dim),
240 | nn.Dropout(dropout)
241 | )
242 |
243 | def forward(self, x, distillation_layer_num=None):
244 | # [3, 256*self_patch_num, 7, 7]
245 | mlp_inner_feature = []
246 | layer_idx = 0
247 | for mlp_single in self.mlp:
248 | x = mlp_single(x)
249 | mlp_inner_feature.append(x)
250 | if distillation_layer_num:
251 | return x, mlp_inner_feature[-distillation_layer_num-2:-2]
252 | else:
253 | return x, mlp_inner_feature
254 |
255 |
256 | def initialize_weights(net_l, scale=1):
257 | if not isinstance(net_l, list):
258 | net_l = [net_l]
259 | for net in net_l:
260 | for m in net.modules():
261 | if isinstance(m, nn.Conv2d):
262 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
263 | m.weight.data *= scale # for residual block
264 | if m.bias is not None:
265 | m.bias.data.zero_()
266 | elif isinstance(m, nn.Linear):
267 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
268 | m.weight.data *= scale
269 | if m.bias is not None:
270 | m.bias.data.zero_()
271 | elif isinstance(m, nn.BatchNorm2d):
272 | init.constant_(m.weight, 1)
273 | init.constant_(m.bias.data, 0.0)
274 | else:
275 | pass
276 |
277 | class RegressionFCNet(nn.Module):
278 | """
279 | Target network for quality prediction.
280 | """
281 | def __init__(self):
282 | super(RegressionFCNet, self).__init__()
283 | self.target_in_size=512
284 | self.target_fc1_size=256
285 |
286 | self.sigmoid = nn.Sigmoid()
287 | self.l1 = nn.Linear(self.target_in_size, self.target_fc1_size)
288 | self.relu1 = nn.PReLU()
289 | self.drop1 = nn.Dropout(0.5)
290 | self.bn1 = nn.BatchNorm1d(256)
291 |
292 | self.l2 = nn.Linear(self.target_fc1_size, 1)
293 |
294 |
295 | def forward(self, x):
296 | q = self.l1(x)
297 | q = self.l2(q).squeeze()
298 | return q
299 |
300 | class DistillationIQANet(nn.Module):
301 | def __init__(self, self_patch_num=10, lda_channel=64, encode_decode_channel=64, MLP_depth=9, distillation_layer=9):
302 | super(DistillationIQANet, self).__init__()
303 |
304 | self.self_patch_num = self_patch_num
305 | self.lda_channel = lda_channel
306 | self.encode_decode_channel = encode_decode_channel
307 | self.MLP_depth = MLP_depth
308 | self.distillation_layer_num = distillation_layer
309 |
310 | self.feature_extractor = ResNetBackbone()
311 | for param in self.feature_extractor.parameters():
312 | param.requires_grad = False
313 |
314 | self.lda1_process = nn.Sequential(nn.Conv2d(256, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7)))
315 | self.lda2_process = nn.Sequential(nn.Conv2d(512, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7)))
316 | self.lda3_process = nn.Sequential(nn.Conv2d(1024, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7)))
317 | self.lda4_process = nn.Sequential(nn.Conv2d(2048, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7)))
318 |
319 | self.lda_process = [self.lda1_process, self.lda2_process, self.lda3_process, self.lda4_process]
320 |
321 | self.MLP_encoder_diff = MLPMixer(image_size = 7, channels = self.self_patch_num*self.lda_channel*4, patch_size = 1, dim = self.encode_decode_channel*4, depth = self.MLP_depth*2)
322 | self.MLP_encoder_lq = MLPMixer(image_size = 7, channels = self.self_patch_num*self.lda_channel*4, patch_size = 1, dim = self.encode_decode_channel*4, depth = self.MLP_depth)
323 |
324 | self.regressor = RegressionFCNet()
325 |
326 | initialize_weights(self.MLP_encoder_diff,0.1)
327 | initialize_weights(self.MLP_encoder_lq,0.1)
328 | initialize_weights(self.regressor,0.1)
329 |
330 | initialize_weights(self.lda1_process,0.1)
331 | initialize_weights(self.lda2_process,0.1)
332 | initialize_weights(self.lda3_process,0.1)
333 | initialize_weights(self.lda4_process,0.1)
334 |
335 | def forward(self, LQ_patches, refHQ_patches):
336 | device = LQ_patches.device
337 | b, p, c, h, w = LQ_patches.shape
338 | LQ_patches_reshape = LQ_patches.view(b*p, c, h, w)
339 | refHQ_patches_reshape = refHQ_patches.view(b*p, c, h, w)
340 |
341 | # [b*p, 256, 56, 56], [b*p, 512, 28, 28], [b*p, 1024, 14, 14], [b*p, 2048, 7, 7]
342 | lq_lda_features = self.feature_extractor(LQ_patches_reshape)
343 | refHQ_lda_features = self.feature_extractor(refHQ_patches_reshape)
344 |
345 | # encode_diff_feature, encode_lq_feature, feature = [], [], []
346 | multi_scale_diff_feature, multi_scale_lq_feature, feature = [], [], []
347 | for lq_lda_feature, refHQ_lda_feature, lda_process in zip(lq_lda_features, refHQ_lda_features, self.lda_process):
348 | # [b, p, 64, 7, 7]
349 | lq_lda_feature = lda_process(lq_lda_feature).view(b, -1, 7, 7)
350 | refHQ_lda_feature = lda_process(refHQ_lda_feature).view(b, -1, 7, 7)
351 | diff_lda_feature = refHQ_lda_feature - lq_lda_feature
352 |
353 |
354 | multi_scale_diff_feature.append(diff_lda_feature)
355 | multi_scale_lq_feature.append(lq_lda_feature)
356 |
357 | multi_scale_lq_feature = torch.cat(multi_scale_lq_feature, 1).to(device)
358 | multi_scale_diff_feature = torch.cat(multi_scale_diff_feature, 1).to(device)
359 | encode_lq_feature, encode_lq_inner_feature = self.MLP_encoder_lq(multi_scale_lq_feature, self.distillation_layer_num)
360 | encode_diff_feature, encode_diff_inner_feature = self.MLP_encoder_diff(multi_scale_diff_feature, self.distillation_layer_num)
361 | feature = torch.cat((encode_lq_feature, encode_diff_feature), 1)
362 |
363 | pred = self.regressor(feature)
364 | return encode_diff_inner_feature, encode_lq_inner_feature, pred
365 |
366 | def _load_state_dict(self, state_dict, strict=True):
367 | own_state = self.state_dict()
368 | for name, param in state_dict.items():
369 | if name in own_state:
370 | if isinstance(param, nn.Parameter):
371 | param = param.data
372 | try:
373 | own_state[name].copy_(param)
374 | except Exception:
375 | if name.find('tail') >= 0:
376 | print('Replace pre-trained upsampler to new one...')
377 | else:
378 | raise RuntimeError('While copying the parameter named {}, '
379 | 'whose dimensions in the model are {} and '
380 | 'whose dimensions in the checkpoint are {}.'
381 | .format(name, own_state[name].size(), param.size()))
382 | elif strict:
383 | if name.find('tail') == -1:
384 | raise KeyError('unexpected key "{}" in state_dict'
385 | .format(name))
386 |
387 | if strict:
388 | missing = set(own_state.keys()) - set(state_dict.keys())
389 | if len(missing) > 0:
390 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
391 |
392 |
393 |
394 | if __name__ == "__main__":
395 | net = ResNetBackbone()
396 | x = torch.rand((1,3,224,224))
397 | y = net(x)
398 |     print([f.shape for f in y])  # ResNetBackbone returns a list of four multi-scale feature maps
399 |
400 | model = MLPMixer(image_size = 7, channels = 1280, patch_size = 1, dim = 512, depth = 12)
401 |     img = torch.randn(96, 1280, 7, 7)  # channel count must match the MLPMixer's `channels` argument
402 |     pred, inner_features = model(img)  # forward returns (pooled output, list of inner features)
403 | print(pred.shape)
404 |
405 | m = DistillationIQANet()
406 | lq = torch.rand((3,10,3,224,224))
407 | hq = torch.rand((3,10,3,224,224))
408 | encode_diff_feature, encode_lq_feature, pred = m(lq, hq)
409 | print(pred.shape)
410 |
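DistillationIQANet returns the inner MLP-Mixer features together with the score so that a training script can align a student branch (which sees only the LQ image or a different-content HQ reference) with a teacher branch that sees the true reference-minus-LQ difference. The exact objective is defined in the training scripts and is not reproduced here; the sketch below only illustrates the general shape of such a feature-matching term, assuming a plain L1 penalty and a detached (frozen) teacher:

```
import torch.nn.functional as F

def feature_distillation_loss(student_inner_feats, teacher_inner_feats):
    # Both arguments are equal-length lists of equally shaped feature tensors,
    # as returned by MLPMixer when distillation_layer_num is set. The teacher
    # side is detached so only the student branch receives gradients.
    # (Illustrative sketch; the real loss weighting lives in the train scripts.)
    loss = 0.0
    for s, t in zip(student_inner_feats, teacher_inner_feats):
        loss = loss + F.l1_loss(s, t.detach())
    return loss
```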
--------------------------------------------------------------------------------
/folders/folders_LQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import scipy.io
8 | import numpy as np
9 | import csv
10 | import random
11 | from openpyxl import load_workbook
12 | import cv2
13 | from torchvision import transforms
14 |
15 | class Kadid10kFolder(data.Dataset):
16 |
17 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False):
18 | self.patch_size = patch_size
19 | self.self_patch_num = self_patch_num
20 | self.use_L = use_L
21 | self.transform = transform
22 |
23 | imgname = []
24 | refimgname = []
25 | mos_all = []
26 | csv_file = os.path.join(root, 'dmos.csv')
27 | with open(csv_file) as f:
28 | reader = csv.DictReader(f)
29 | for row in reader:
30 | imgname.append(row['dist_img'])
31 | refimgname.append(row['ref_img'])
32 | mos = np.array(float(row['dmos'])).astype(np.float32)
33 | mos_all.append(mos)
34 |
35 | sample = []
36 | for i, item in enumerate(index):
37 | for aug in range(patch_num):
38 | sample.append((os.path.join(root, 'images', imgname[item]),os.path.join(root, 'images', refimgname[item]), mos_all[item]))
39 |
40 | self.samples = sample
41 |
42 |
43 | def __getitem__(self, index):
44 | """
45 | Args:
46 | index (int): Index
47 | Returns:
48 |             tuple: (LQ, target) where target is the IQA value of the LQ image.
49 | """
50 | LQ_path, HQ_path, target = self.samples[index]
51 |
52 | LQ = pil_loader(LQ_path, self.use_L)
53 | LQ_patches = []
54 | for _ in range(self.self_patch_num):
55 | LQ_patch = self.transform(LQ)
56 | LQ_patches.append(LQ_patch.unsqueeze(0))
57 | #[self_patch_num, 3, patch_size, patch_size]
58 | LQ_patches = torch.cat(LQ_patches, 0)
59 |
60 | return LQ_patches, target
61 |
62 | def __len__(self):
63 | length = len(self.samples)
64 | return length
65 |
66 | class LIVEFolder(data.Dataset):
67 |
68 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False):
69 | self.patch_size =patch_size
70 | self.self_patch_num = self_patch_num
71 | self.transform = transform
72 | self.use_L = use_L
73 |
74 | refpath = os.path.join(root, 'refimgs')
75 | refname = getFileName(refpath, '.bmp')
76 |
77 | jp2kroot = os.path.join(root, 'jp2k')
78 | jp2kname = self.getDistortionTypeFileName(jp2kroot, 227)
79 |
80 | jpegroot = os.path.join(root, 'jpeg')
81 | jpegname = self.getDistortionTypeFileName(jpegroot, 233)
82 |
83 | wnroot = os.path.join(root, 'wn')
84 | wnname = self.getDistortionTypeFileName(wnroot, 174)
85 |
86 | gblurroot = os.path.join(root, 'gblur')
87 | gblurname = self.getDistortionTypeFileName(gblurroot, 174)
88 |
89 | fastfadingroot = os.path.join(root, 'fastfading')
90 | fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174)
91 |
92 | imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname
93 |
94 | dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat'))
95 | labels = dmos['dmos_new'].astype(np.float32)
96 |
97 | orgs = dmos['orgs']
98 | refnames_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat'))
99 | refnames_all = refnames_all['refnames_all']
100 |
101 | sample = []
102 | for i in range(0, len(index)):
103 | train_sel = (refname[index[i]] == refnames_all)
104 | train_sel = train_sel * ~orgs.astype(np.bool_)
105 | train_sel = np.where(train_sel == True)
106 | train_sel = train_sel[1].tolist()
107 | for j, item in enumerate(train_sel):
108 | for aug in range(patch_num):
109 | LQ_path = imgpath[item]
110 | HQ_path = os.path.join(root, 'refimgs', refnames_all[0][item][0])
111 | label = labels[0][item]
112 | sample.append((LQ_path, HQ_path, label))
113 |
114 | self.samples = sample
115 |
116 |
117 | def __getitem__(self, index):
118 | """
119 | Args:
120 | index (int): Index
121 | Returns:
122 |             tuple: (LQ, target) where target is the IQA value of the LQ image.
123 | """
124 | LQ_path, HQ_path, target = self.samples[index]
125 | LQ = pil_loader(LQ_path, self.use_L)
126 | LQ_patches = []
127 | for _ in range(self.self_patch_num):
128 | LQ_patch = self.transform(LQ)
129 | LQ_patches.append(LQ_patch.unsqueeze(0))
130 | LQ_patches = torch.cat(LQ_patches, 0)
131 |
132 | return LQ_patches, target
133 |
134 | def __len__(self):
135 | length = len(self.samples)
136 | return length
137 |
138 | def getDistortionTypeFileName(self, path, num):
139 | filename = []
140 | index = 1
141 | for i in range(0, num):
142 | name = '{:0>3d}{}'.format(index, '.bmp')
143 | filename.append(os.path.join(path, name))
144 | index = index + 1
145 | return filename
146 |
147 | class CSIQFolder(data.Dataset):
148 |
149 | def __init__(self, root, index, transform, patch_num, patch_size =224, self_patch_num=1, use_L=False):
150 | self.patch_size =patch_size
151 | self.self_patch_num = self_patch_num
152 | self.transform = transform
153 | self.use_L = use_L
154 |
155 | refpath = os.path.join(root, 'src_imgs')
156 | refname = getFileName(refpath,'.png')
157 | txtpath = os.path.join(root, 'csiq_label.txt')
158 | fh = open(txtpath, 'r')
159 | imgnames = []
160 | target = []
161 | refnames_all = []
162 | for line in fh:
163 | line = line.split('\n')
164 | words = line[0].split()
165 | imgnames.append((words[0]))
166 | target.append(words[1])
167 | ref_temp = words[0].split(".")
168 | refnames_all.append(ref_temp[0] + '.' + ref_temp[-1])
169 |
170 | labels = np.array(target).astype(np.float32)
171 | refnames_all = np.array(refnames_all)
172 |
173 | sample = []
174 |
175 | for i, item in enumerate(index):
176 | train_sel = (refname[index[i]] == refnames_all)
177 | train_sel = np.where(train_sel == True)
178 | train_sel = train_sel[0].tolist()
179 | for j, item in enumerate(train_sel):
180 | for aug in range(patch_num):
181 | LQ_path = os.path.join(root, 'dst_imgs_all', imgnames[item])
182 | HQ_path = os.path.join(root, 'src_imgs', refnames_all[item])
183 | label = labels[item]
184 | sample.append((LQ_path, HQ_path, label))
185 | self.samples = sample
186 |
187 |
188 | def __getitem__(self, index):
189 | """
190 | Args:
191 | index (int): Index
192 | Returns:
193 |             tuple: (LQ, target) where target is the IQA value of the LQ image.
194 | """
195 | LQ_path, HQ_path, target = self.samples[index]
196 | LQ = pil_loader(LQ_path, self.use_L)
197 | LQ_patches = []
198 | for _ in range(self.self_patch_num):
199 | LQ_patch = self.transform(LQ)
200 | LQ_patches.append(LQ_patch.unsqueeze(0))
201 | LQ_patches = torch.cat(LQ_patches, 0)
202 |
203 | return LQ_patches, target
204 |
205 | def __len__(self):
206 | length = len(self.samples)
207 | return length
208 |
209 | class TID2013Folder(data.Dataset):
210 |
211 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False):
212 | self.patch_size =patch_size
213 | self.self_patch_num = self_patch_num
214 | self.transform = transform
215 | self.use_L = use_L
216 |
217 | refpath = os.path.join(root, 'reference_images')
218 | refname = self._getTIDFileName(refpath,'.bmp.BMP')
219 | txtpath = os.path.join(root, 'mos_with_names.txt')
220 | fh = open(txtpath, 'r')
221 | imgnames = []
222 | target = []
223 | refnames_all = []
224 | for line in fh:
225 | line = line.split('\n')
226 | words = line[0].split()
227 | imgnames.append((words[1]))
228 | target.append(words[0])
229 | ref_temp = words[1].split("_")
230 | refnames_all.append(ref_temp[0][1:])
231 | labels = np.array(target).astype(np.float32)
232 | refnames_all = np.array(refnames_all)
233 |
234 | sample = []
235 | for i, item in enumerate(index):
236 | train_sel = (refname[index[i]] == refnames_all)
237 | train_sel = np.where(train_sel == True)
238 | train_sel = train_sel[0].tolist()
239 | for j, item in enumerate(train_sel):
240 | for aug in range(patch_num):
241 | LQ_path = os.path.join(root, 'distorted_images', imgnames[item])
242 | HQ_name = 'I' + imgnames[item].split("_")[0][1:] + '.BMP'
243 | HQ_path = os.path.join(refpath, HQ_name)
244 | label = labels[item]
245 | sample.append((LQ_path, HQ_path, label))
246 | self.samples = sample
247 |
248 |
249 | def _getTIDFileName(self, path, suffix):
250 | filename = []
251 | f_list = os.listdir(path)
252 | for i in f_list:
253 | if suffix.find(os.path.splitext(i)[1]) != -1:
254 | filename.append(i[1:3])
255 | return filename
256 |
257 | def __getitem__(self, index):
258 | """
259 | Args:
260 | index (int): Index
261 | Returns:
262 |             tuple: (LQ, target) where target is the IQA value of the LQ image.
263 | """
264 | LQ_path, HQ_path, target = self.samples[index]
265 | LQ = pil_loader(LQ_path, self.use_L)
266 | LQ_patches = []
267 | for _ in range(self.self_patch_num):
268 | LQ_patch = self.transform(LQ)
269 | LQ_patches.append(LQ_patch.unsqueeze(0))
270 | LQ_patches = torch.cat(LQ_patches, 0)
271 |
272 | return LQ_patches, target
273 |
274 | def __len__(self):
275 | length = len(self.samples)
276 | return length
277 |
278 | class LIVEChallengeFolder(data.Dataset):
279 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False):
280 | self.patch_size =patch_size
281 | self.self_patch_num = self_patch_num
282 | self.transform = transform
283 | self.use_L = use_L
284 |
285 | LQ_pathes = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat'))
286 | LQ_pathes = LQ_pathes['AllImages_release']
287 | LQ_pathes = LQ_pathes[7:1169]
288 | mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat'))
289 | labels = mos['AllMOS_release'].astype(np.float32)
290 | labels = labels[0][7:1169]
291 |
292 | sample = []
293 | for _, item in enumerate(index):
294 | for _ in range(patch_num):
295 | sample.append((os.path.join(root, 'Images', LQ_pathes[item][0][0]), labels[item]))
296 | self.samples = sample
297 |
298 | def __getitem__(self, index):
299 | """
300 | Args:
301 | index (int): Index
302 | Returns:
303 | tuple: (LQ, target) where target is IQA values of the target LQ.
304 | """
305 | LQ_path, target = self.samples[index]
306 | LQ = pil_loader(LQ_path, self.use_L)
307 | LQ_patches = []
308 | for _ in range(self.self_patch_num):
309 | LQ_patch = self.transform(LQ)
310 |
311 | LQ_patches.append(LQ_patch.unsqueeze(0))
312 | #[self_patch_num, 3, patch_size, patch_size]
313 | LQ_patches = torch.cat(LQ_patches, 0)
314 |
315 | return LQ_patches, target
316 |
317 | def __len__(self):
318 | length = len(self.samples)
319 | return length
320 |
321 | class BIDChallengeFolder(data.Dataset):
322 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False):
323 | self.patch_size =patch_size
324 | self.self_patch_num = self_patch_num
325 | self.transform = transform
326 | self.use_L = use_L
327 |
328 | LQ_pathes = []
329 | labels = []
330 |
331 | xls_file = os.path.join(root, 'DatabaseGrades.xlsx')
332 | workbook = load_workbook(xls_file)
333 | booksheet = workbook.active
334 | rows = booksheet.rows
335 | count = 1
336 | for _ in rows:
337 | count += 1
338 | img_num = (booksheet.cell(row=count, column=1).value)
339 | img_name = "DatabaseImage%04d.JPG" % (img_num)
340 | LQ_pathes.append(img_name)
341 | mos = (booksheet.cell(row=count, column=2).value)
342 | mos = np.array(mos)
343 | mos = mos.astype(np.float32)
344 | labels.append(mos)
345 | if count == 587:
346 | break
347 |
348 | sample = []
349 | for _, item in enumerate(index):
350 | for _ in range(patch_num):
351 | sample.append((os.path.join(root, LQ_pathes[item]), labels[item]))
352 | self.samples = sample
353 |
354 | def __getitem__(self, index):
355 | """
356 | Args:
357 | index (int): Index
358 | Returns:
359 | tuple: (LQ, target) where target is IQA values of the target LQ.
360 | """
361 | LQ_path, target = self.samples[index]
362 | LQ = pil_loader(LQ_path, self.use_L)
363 | LQ_patches = []
364 | for _ in range(self.self_patch_num):
365 | LQ_patch = self.transform(LQ)
366 |
367 | LQ_patches.append(LQ_patch.unsqueeze(0))
368 | #[self_patch_num, 3, patch_size, patch_size]
369 | LQ_patches = torch.cat(LQ_patches, 0)
370 |
371 | return LQ_patches, target
372 |
373 | def __len__(self):
374 | length = len(self.samples)
375 | return length
376 |
377 | class Koniq_10kFolder(data.Dataset):
378 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False):
379 | self.patch_size =patch_size
380 | self.self_patch_num = self_patch_num
381 | self.transform = transform
382 | self.use_L = use_L
383 |
384 | imgname = []
385 | mos_all = []
386 | csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv')
387 | with open(csv_file) as f:
388 | reader = csv.DictReader(f)
389 | for row in reader:
390 | imgname.append(row['image_name'])
391 | mos = np.array(float(row['MOS_zscore'])).astype(np.float32)
392 | mos_all.append(mos)
393 |
394 | sample = []
395 | for _, item in enumerate(index):
396 | for _ in range(patch_num):
397 | sample.append((os.path.join(root, '1024x768', imgname[item]), mos_all[item]))
398 |
399 | self.samples = sample
400 |
401 | def __getitem__(self, index):
402 | """
403 | Args:
404 | index (int): Index
405 | Returns:
406 | tuple: (LQ, target) where target is IQA values of the target LQ.
407 | """
408 | LQ_path, target = self.samples[index]
409 | LQ = pil_loader(LQ_path, self.use_L)
410 | LQ_patches = []
411 | for _ in range(self.self_patch_num):
412 | LQ_patch = self.transform(LQ)
413 |
414 | LQ_patches.append(LQ_patch.unsqueeze(0))
415 | #[self_patch_num, 3, patch_size, patch_size]
416 | LQ_patches = torch.cat(LQ_patches, 0)
417 |
418 | return LQ_patches, target
419 |
420 | def __len__(self):
421 | length = len(self.samples)
422 | return length
423 |
424 |
425 | def getFileName(path, suffix):
426 | filename = []
427 | f_list = os.listdir(path)
428 | for i in f_list:
429 | if os.path.splitext(i)[1] == suffix:
430 | filename.append(i)
431 | return filename
432 |
433 |
434 | def getPairRandomPatch(img1, img2, crop_size=512):
435 | (iw,ih) = img1.size
436 | # print(ih,iw)
437 |
438 | ip = int(crop_size)
439 |
440 | ix = random.randrange(0, iw - ip + 1)
441 | iy = random.randrange(0, ih - ip + 1)
442 |
443 |
444 |     img1_patch = img1.crop((ix, iy, ix+ip, iy+ip))  # (left, top, right, bottom)
445 |     img2_patch = img2.crop((ix, iy, ix+ip, iy+ip))  # (left, top, right, bottom)
446 |
447 | return img1_patch, img2_patch
448 |
449 | def getPairAugment(img1, img2, hflip=True, vflip=True, rot=True):
450 | hflip = hflip and random.random() < 0.5
451 | vflip = vflip and random.random() < 0.5
452 | rot180 = rot and random.random() < 0.5
453 |
454 |     if hflip:
455 |         img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip mirrors left-right
456 |         img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
457 |     if vflip:
458 |         img1 = img1.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip mirrors top-bottom
459 |         img2 = img2.transpose(Image.FLIP_TOP_BOTTOM)
460 | if rot180:
461 | img1 = img1.transpose(Image.ROTATE_180)
462 | img2 = img2.transpose(Image.ROTATE_180)
463 |
464 | return img1, img2
465 |
466 |
467 | def getSelfPatch(img, patch_size, patch_num, is_random=True):
468 | (iw,ih) = img.size
469 | patches = []
470 | for i in range(patch_num):
471 | if is_random:
472 | ix = random.randrange(0, iw - patch_size + 1)
473 | iy = random.randrange(0, ih - patch_size + 1)
474 |         else: ix, iy = (iw - patch_size + 1) // 2, (ih - patch_size + 1) // 2
475 | 
476 |         # patch = img[iy:iy + lr_size, ix:ix + lr_size, :]  # (top, bottom, left, right)
477 |         patch = img.crop((ix, iy, ix+patch_size, iy+patch_size))  # (left, top, right, bottom)
478 |         patches.append(patch)
479 |
480 | return patches
481 |
482 |
483 | def pil_loader(path, use_L=False):
484 | if use_L:
485 | return Image.open(path).convert('L')
486 | else:
487 | return Image.open(path).convert('RGB')
488 |
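All of these LQ-only folders return a [self_patch_num, 3, patch_size, patch_size] patch stack plus a scalar quality label, so the default collate of a DataLoader yields batches shaped [batch, self_patch_num, 3, patch_size, patch_size]. A minimal sketch for KonIQ-10k, assuming a hypothetical dataset root and a plain random-crop transform (the real splits and transforms are configured by the option/training scripts):

```
import torch
import torchvision

root = '/path/to/koniq10k'   # hypothetical; must contain the scores CSV and the 1024x768/ folder
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225)),
])

index = list(range(10073))   # KonIQ-10k has 10,073 images; real train/test splits are made elsewhere
dataset = Koniq_10kFolder(root, index, transform, patch_num=1, self_patch_num=1)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

for LQ_patches, target in loader:
    print(LQ_patches.shape, target.shape)   # torch.Size([16, 1, 3, 224, 224]), torch.Size([16])
    break
```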
--------------------------------------------------------------------------------
/models/IQT.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | from torch.nn import Dropout, Softmax, Linear, LayerNorm
5 | import torch.utils.model_zoo as model_zoo
6 | import math
7 | import copy
8 |
9 | __all__ = ['InceptionResNetV2', 'inceptionresnetv2']
10 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}
11 | hidden_size = 256
12 |
13 | pretrained_settings = {
14 | 'inceptionresnetv2': {
15 | 'imagenet': {
16 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
17 | 'input_space': 'RGB',
18 | 'input_size': [3, 299, 299],
19 | 'input_range': [0, 1],
20 | 'mean': [0.5, 0.5, 0.5],
21 | 'std': [0.5, 0.5, 0.5],
22 | 'num_classes': 1000
23 | },
24 | 'imagenet+background': {
25 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
26 | 'input_space': 'RGB',
27 | 'input_size': [3, 299, 299],
28 | 'input_range': [0, 1],
29 | 'mean': [0.5, 0.5, 0.5],
30 | 'std': [0.5, 0.5, 0.5],
31 | 'num_classes': 1001
32 | }
33 | }
34 | }
35 |
36 |
37 | class BasicConv2d(nn.Module):
38 |
39 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
40 | super(BasicConv2d, self).__init__()
41 | self.conv = nn.Conv2d(in_planes, out_planes,
42 | kernel_size=kernel_size, stride=stride,
43 | padding=padding, bias=False) # verify bias false
44 | self.bn = nn.BatchNorm2d(out_planes,
45 | eps=0.001, # value found in tensorflow
46 | momentum=0.1, # default pytorch value
47 | affine=True)
48 | self.relu = nn.ReLU(inplace=False)
49 |
50 | def forward(self, x):
51 | x = self.conv(x)
52 | x = self.bn(x)
53 | x = self.relu(x)
54 | return x
55 |
56 |
57 | class Mixed_5b(nn.Module):
58 |
59 | def __init__(self):
60 | super(Mixed_5b, self).__init__()
61 |
62 | self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
63 |
64 | self.branch1 = nn.Sequential(
65 | BasicConv2d(192, 48, kernel_size=1, stride=1),
66 | BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
67 | )
68 |
69 | self.branch2 = nn.Sequential(
70 | BasicConv2d(192, 64, kernel_size=1, stride=1),
71 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
72 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
73 | )
74 |
75 | self.branch3 = nn.Sequential(
76 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
77 | BasicConv2d(192, 64, kernel_size=1, stride=1)
78 | )
79 |
80 | def forward(self, x):
81 | x0 = self.branch0(x)
82 | x1 = self.branch1(x)
83 | x2 = self.branch2(x)
84 | x3 = self.branch3(x)
85 | out = torch.cat((x0, x1, x2, x3), 1)
86 | return out
87 |
88 |
89 | class Block35(nn.Module):
90 |
91 | def __init__(self, scale=1.0):
92 | super(Block35, self).__init__()
93 |
94 | self.scale = scale
95 |
96 | self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
97 |
98 | self.branch1 = nn.Sequential(
99 | BasicConv2d(320, 32, kernel_size=1, stride=1),
100 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
101 | )
102 |
103 | self.branch2 = nn.Sequential(
104 | BasicConv2d(320, 32, kernel_size=1, stride=1),
105 | BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
106 | BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
107 | )
108 |
109 | self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
110 | self.relu = nn.ReLU(inplace=False)
111 |
112 | def forward(self, x):
113 | x0 = self.branch0(x)
114 | x1 = self.branch1(x)
115 | x2 = self.branch2(x)
116 | out = torch.cat((x0, x1, x2), 1)
117 | out = self.conv2d(out)
118 | out = out * self.scale + x
119 | out = self.relu(out)
120 | return out
121 |
122 |
123 | class Mixed_6a(nn.Module):
124 |
125 | def __init__(self):
126 | super(Mixed_6a, self).__init__()
127 |
128 | self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
129 |
130 | self.branch1 = nn.Sequential(
131 | BasicConv2d(320, 256, kernel_size=1, stride=1),
132 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
133 | BasicConv2d(256, 384, kernel_size=3, stride=2)
134 | )
135 |
136 | self.branch2 = nn.MaxPool2d(3, stride=2)
137 |
138 | def forward(self, x):
139 | x0 = self.branch0(x)
140 | x1 = self.branch1(x)
141 | x2 = self.branch2(x)
142 | out = torch.cat((x0, x1, x2), 1)
143 | return out
144 |
145 |
146 | class Block17(nn.Module):
147 |
148 | def __init__(self, scale=1.0):
149 | super(Block17, self).__init__()
150 |
151 | self.scale = scale
152 |
153 | self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
154 |
155 | self.branch1 = nn.Sequential(
156 | BasicConv2d(1088, 128, kernel_size=1, stride=1),
157 | BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
158 | BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
159 | )
160 |
161 | self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
162 | self.relu = nn.ReLU(inplace=False)
163 |
164 | def forward(self, x):
165 | x0 = self.branch0(x)
166 | x1 = self.branch1(x)
167 | out = torch.cat((x0, x1), 1)
168 | out = self.conv2d(out)
169 | out = out * self.scale + x
170 | out = self.relu(out)
171 | return out
172 |
173 |
174 | class Mixed_7a(nn.Module):
175 |
176 | def __init__(self):
177 | super(Mixed_7a, self).__init__()
178 |
179 | self.branch0 = nn.Sequential(
180 | BasicConv2d(1088, 256, kernel_size=1, stride=1),
181 | BasicConv2d(256, 384, kernel_size=3, stride=2)
182 | )
183 |
184 | self.branch1 = nn.Sequential(
185 | BasicConv2d(1088, 256, kernel_size=1, stride=1),
186 | BasicConv2d(256, 288, kernel_size=3, stride=2)
187 | )
188 |
189 | self.branch2 = nn.Sequential(
190 | BasicConv2d(1088, 256, kernel_size=1, stride=1),
191 | BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
192 | BasicConv2d(288, 320, kernel_size=3, stride=2)
193 | )
194 |
195 | self.branch3 = nn.MaxPool2d(3, stride=2)
196 |
197 | def forward(self, x):
198 | x0 = self.branch0(x)
199 | x1 = self.branch1(x)
200 | x2 = self.branch2(x)
201 | x3 = self.branch3(x)
202 | out = torch.cat((x0, x1, x2, x3), 1)
203 | return out
204 |
205 |
206 | class Block8(nn.Module):
207 |
208 | def __init__(self, scale=1.0, noReLU=False):
209 | super(Block8, self).__init__()
210 |
211 | self.scale = scale
212 | self.noReLU = noReLU
213 |
214 | self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
215 |
216 | self.branch1 = nn.Sequential(
217 | BasicConv2d(2080, 192, kernel_size=1, stride=1),
218 | BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
219 | BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
220 | )
221 |
222 | self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
223 | if not self.noReLU:
224 | self.relu = nn.ReLU(inplace=False)
225 |
226 | def forward(self, x):
227 | x0 = self.branch0(x)
228 | x1 = self.branch1(x)
229 | out = torch.cat((x0, x1), 1)
230 | out = self.conv2d(out)
231 | out = out * self.scale + x
232 | if not self.noReLU:
233 | out = self.relu(out)
234 | return out
235 |
236 |
237 | class InceptionResNetV2(nn.Module):
238 |
239 | def __init__(self, num_classes=1001):
240 | super(InceptionResNetV2, self).__init__()
241 | # Special attributs
242 | self.input_space = None
243 | self.input_size = (299, 299, 3)
244 | self.mean = None
245 | self.std = None
246 | # Modules
247 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
248 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
249 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
250 | self.maxpool_3a = nn.MaxPool2d(3, stride=2)
251 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
252 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
253 | self.maxpool_5a = nn.MaxPool2d(3, stride=2)
254 | self.mixed_5b = Mixed_5b()
255 | self.repeat = nn.Sequential(
256 | Block35(scale=0.17),
257 | Block35(scale=0.17),
258 | Block35(scale=0.17),
259 | Block35(scale=0.17),
260 | Block35(scale=0.17),
261 | Block35(scale=0.17),
262 | Block35(scale=0.17),
263 | Block35(scale=0.17),
264 | Block35(scale=0.17),
265 | Block35(scale=0.17)
266 | )
267 | self.mixed_6a = Mixed_6a()
268 | self.repeat_1 = nn.Sequential(
269 | Block17(scale=0.10),
270 | Block17(scale=0.10),
271 | Block17(scale=0.10),
272 | Block17(scale=0.10),
273 | Block17(scale=0.10),
274 | Block17(scale=0.10),
275 | Block17(scale=0.10),
276 | Block17(scale=0.10),
277 | Block17(scale=0.10),
278 | Block17(scale=0.10),
279 | Block17(scale=0.10),
280 | Block17(scale=0.10),
281 | Block17(scale=0.10),
282 | Block17(scale=0.10),
283 | Block17(scale=0.10),
284 | Block17(scale=0.10),
285 | Block17(scale=0.10),
286 | Block17(scale=0.10),
287 | Block17(scale=0.10),
288 | Block17(scale=0.10)
289 | )
290 | self.mixed_7a = Mixed_7a()
291 | self.repeat_2 = nn.Sequential(
292 | Block8(scale=0.20),
293 | Block8(scale=0.20),
294 | Block8(scale=0.20),
295 | Block8(scale=0.20),
296 | Block8(scale=0.20),
297 | Block8(scale=0.20),
298 | Block8(scale=0.20),
299 | Block8(scale=0.20),
300 | Block8(scale=0.20)
301 | )
302 | self.block8 = Block8(noReLU=True)
303 | self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
304 | self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False)
305 | self.last_linear = nn.Linear(1536, num_classes)
306 |
307 | def features(self, input):
308 | device = input.device
309 | concat_x = []
310 | x = self.conv2d_1a(input)
311 | x = self.conv2d_2a(x)
312 | x = self.conv2d_2b(x)
313 | x = self.maxpool_3a(x)
314 | x = self.conv2d_3b(x)
315 | x = self.conv2d_4a(x)
316 | x = self.maxpool_5a(x)
317 | x = self.mixed_5b(x)
318 | concat_x.append(x)
319 | for i, block in enumerate(self.repeat):
320 | x = block(x)
321 | if (i+1)%2 == 0:
322 | concat_x.append(x)
323 | concat_x = torch.cat(concat_x, 1).to(device)
324 | return concat_x
325 |
326 | def logits(self, features):
327 | x = self.avgpool_1a(features)
328 | x = x.view(x.size(0), -1)
329 | x = self.last_linear(x)
330 | return x
331 |
332 | def forward(self, input):
333 | x = self.features(input)
334 | return x
335 |
336 | def inceptionresnetv2_feature_extractor(num_classes=1000, pretrained='imagenet'):
337 | r"""InceptionResNetV2 model architecture from the
338 |     `"Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning" <https://arxiv.org/abs/1602.07261>`_ paper.
339 | """
340 | if pretrained:
341 | settings = pretrained_settings['inceptionresnetv2'][pretrained]
342 | assert num_classes == settings['num_classes'], \
343 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
344 |
345 | # both 'imagenet'&'imagenet+background' are loaded from same parameters
346 | model = InceptionResNetV2(num_classes=1001)
347 | model.load_state_dict(model_zoo.load_url(settings['url']))
348 |
349 | if pretrained == 'imagenet':
350 | new_last_linear = nn.Linear(1536, 1000)
351 | new_last_linear.weight.data = model.last_linear.weight.data[1:]
352 | new_last_linear.bias.data = model.last_linear.bias.data[1:]
353 | model.last_linear = new_last_linear
354 |
355 | model.input_space = settings['input_space']
356 | model.input_size = settings['input_size']
357 | model.input_range = settings['input_range']
358 |
359 | model.mean = settings['mean']
360 | model.std = settings['std']
361 | else:
362 | model = InceptionResNetV2(num_classes=num_classes)
363 | return model
364 |
365 | class Mlp(nn.Module):
366 | def __init__(self):
367 | super(Mlp, self).__init__()
368 | self.hidden_size = 256
369 | self.mlp_size = 512
370 | self.dropout_rate = 0.1
371 | self.fc1 = Linear(self.hidden_size, self.mlp_size)
372 | self.fc2 = Linear(self.mlp_size, self.hidden_size)
373 | self.act_fn = ACT2FN["gelu"]
374 | self.dropout = Dropout(self.dropout_rate)
375 |
376 | self._init_weights()
377 |
378 | def _init_weights(self):
379 | nn.init.xavier_uniform_(self.fc1.weight)
380 | nn.init.xavier_uniform_(self.fc2.weight)
381 | nn.init.normal_(self.fc1.bias, std=1e-6)
382 | nn.init.normal_(self.fc2.bias, std=1e-6)
383 |
384 | def forward(self, x):
385 | x = self.fc1(x)
386 | x = self.act_fn(x)
387 | x = self.dropout(x)
388 | x = self.fc2(x)
389 | x = self.dropout(x)
390 | return x
391 |
392 | class MHA(nn.Module):
393 | def __init__(self):
394 | super(MHA, self).__init__()
395 | self.vis = True
396 | self.num_attention_heads = 4
397 | self.hidden_size = 256
398 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
399 | self.all_head_size = self.num_attention_heads * self.attention_head_size
400 |
401 | self.query = Linear(self.hidden_size, self.all_head_size)
402 | self.key = Linear(self.hidden_size, self.all_head_size)
403 | self.value = Linear(self.hidden_size, self.all_head_size)
404 |
405 | self.out = Linear(self.hidden_size, self.hidden_size)
406 | self.attn_dropout = Dropout(0.0)
407 | self.proj_dropout = Dropout(0.0)
408 |
409 | self.softmax = Softmax(dim=-1)
410 |
411 | def transpose_for_scores(self, x):
412 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
413 | x = x.view(*new_x_shape)
414 | return x.permute(0, 2, 1, 3)
415 |
416 | def forward(self, k, v, q):
417 | mixed_key_layer = self.key(k)
418 | mixed_value_layer = self.value(v)
419 | mixed_query_layer = self.query(q)
420 |
421 | query_layer = self.transpose_for_scores(mixed_query_layer)
422 | key_layer = self.transpose_for_scores(mixed_key_layer)
423 | value_layer = self.transpose_for_scores(mixed_value_layer)
424 |
425 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
426 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
427 | attention_probs = self.softmax(attention_scores)
428 | weights = attention_probs
429 | attention_probs = self.attn_dropout(attention_probs)
430 |
431 | context_layer = torch.matmul(attention_probs, value_layer)
432 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
433 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
434 | context_layer = context_layer.view(*new_context_layer_shape)
435 | attention_output = self.out(context_layer)
436 | attention_output = self.proj_dropout(attention_output)
437 | return attention_output
438 |
439 | class IQT_Encoder_Block(nn.Module):
440 | def __init__(self):
441 | super(IQT_Encoder_Block, self).__init__()
442 | self.hidden_size = 256
443 | self.ffn_norm = LayerNorm(self.hidden_size, eps=1e-6)
444 | self.ffn = Mlp()
445 | self.attention_norm = LayerNorm(self.hidden_size, eps=1e-6)
446 | self.attn = MHA()
447 |
448 | def forward(self, k, v, q):
449 | x = self.attn(k, v, q)
450 | x += q
451 | x1 = self.attention_norm(x)
452 | x = self.ffn(x1)
453 | x += x1
454 | y = self.ffn_norm(x)
455 | return y
456 |
457 | class IQT_Encoder(nn.Module):
458 | def __init__(self):
459 | super(IQT_Encoder, self).__init__()
460 | self.layer = nn.ModuleList()
461 | for _ in range(2):
462 | layer = IQT_Encoder_Block()
463 | self.layer.append(copy.deepcopy(layer))
464 |
465 | def forward(self, x_diff):
466 | for encoder_block in self.layer:
467 | x_diff = encoder_block(x_diff, x_diff, x_diff)
468 | return x_diff
469 |
470 | class IQT_Decoder_Block(nn.Module):
471 | def __init__(self):
472 | super(IQT_Decoder_Block, self).__init__()
473 | self.hidden_size = 256
474 | self.ffn_norm = LayerNorm(self.hidden_size, eps=1e-6)
475 | self.ffn = Mlp()
476 | self.attention_norm = LayerNorm(self.hidden_size, eps=1e-6)
477 | self.attn = MHA()
478 | self.self_attention_norm = LayerNorm(self.hidden_size, eps=1e-6)
479 | self.self_attn = MHA()
480 |
481 | def forward(self, k, v, q, k1, v1):
482 | x = self.self_attn(k, v, q)
483 | x += q
484 | x1 = self.self_attention_norm(x)
485 | x = self.attn(k1, v1, x1)
486 | x += x1
487 | x2 = self.attention_norm(x)
488 | x = self.ffn(x2)
489 | x += x2
490 | y = self.ffn_norm(x)
491 | return y
492 |
493 | class IQT_Decoder(nn.Module):
494 | def __init__(self):
495 | super(IQT_Decoder, self).__init__()
496 | self.layer = nn.ModuleList()
497 | for _ in range(2):
498 | layer = IQT_Decoder_Block()
499 | self.layer.append(copy.deepcopy(layer))
500 |
501 | def forward(self, x_HQ, x_diff):
502 | for decoder_block in self.layer:
503 | x_HQ = decoder_block(x_HQ, x_HQ, x_HQ, x_diff, x_diff)
504 | return x_HQ
505 |
506 | class RegressionFCNet(nn.Module):
507 | def __init__(self):
508 | super(RegressionFCNet, self).__init__()
509 | self.target_in_size=256
510 | self.target_fc1_size=512
511 |
512 | self.l1 = nn.Linear(self.target_in_size, self.target_fc1_size)
513 | self.relu = nn.ReLU()
514 | self.l2 = nn.Linear(self.target_fc1_size, 1)
515 |
516 |
517 | def forward(self, x):
518 | q = self.l1(x)
519 | q = self.relu(q)
520 | q = self.l2(q).squeeze()
521 | return q
522 |
523 | class IQT(nn.Module):
524 |
525 | def __init__(self):
526 | super(IQT, self).__init__()
527 | self.feature_extractor = inceptionresnetv2_feature_extractor(num_classes=1000, pretrained='imagenet')
528 | for param in self.feature_extractor.parameters():
529 | param.requires_grad = False
530 |
531 | self.conv = nn.Conv2d(1920, 256, kernel_size=1, stride=1, padding=0)
532 |
533 | self.position_embeddings = nn.Parameter(torch.zeros(1, 625+1, 256))
534 | self.quality_token = nn.Parameter(torch.zeros(1, 1, 256))
535 |
536 | self.encoder = IQT_Encoder()
537 | self.decoder = IQT_Decoder()
538 |
539 | self.regressor = RegressionFCNet()
540 |
541 | def cal_params(self):
542 | params = list(self.parameters())
543 | k = 0
544 | for i in params:
545 | l = 1
546 | for j in i.size():
547 | l *= j
548 | k = k + l
549 | print("Total parameters is :" + str(k))
550 |
551 | def forward(self, LQ_patch, HQ_patch):
552 |         B = LQ_patch.shape[0]
553 | feature_LQ = self.feature_extractor(LQ_patch)
554 | feature_HQ = self.feature_extractor(HQ_patch)
555 |
556 | feature_diff = self.conv(feature_HQ-feature_LQ)
557 | feature_HQ = self.conv(feature_HQ)
558 |
559 | quality_tokens = self.quality_token.expand(B,-1,-1)
560 |
561 | flat_feature_HQ = feature_HQ.flatten(2).transpose(-1,-2)
562 | flat_feature_HQ = torch.cat((quality_tokens,flat_feature_HQ), dim=1) + self.position_embeddings
563 |
564 | flat_feature_diff = feature_diff.flatten(2).transpose(-1,-2)
565 | flat_feature_diff = torch.cat((quality_tokens,flat_feature_diff), dim=1) + self.position_embeddings
566 |
567 | flat_feature_diff = self.encoder(flat_feature_diff)
568 | f = self.decoder(flat_feature_diff, flat_feature_HQ)
569 | y = self.regressor(f[:,0])
570 |
571 | return y
572 |
573 | '''
574 | TEST
575 | Run this code with:
576 | ```
577 | python models/IQT.py
578 | # requires a CUDA device; downloads the pretrained Inception-ResNet-V2 weights on first run
579 | ```
580 | '''
581 | if __name__ == '__main__':
582 | import time
583 | device = torch.device('cuda')
584 | LQ = torch.rand((1,3,224, 224)).to(device)
585 | HQ = torch.rand((1,3,224, 224)).to(device)
586 | net = IQT().to(device)
587 | net.cal_params()
588 |
589 | torch.cuda.synchronize()
590 | start = time.time()
591 | a = net(LQ, HQ)
592 | torch.cuda.synchronize()
593 | end = time.time()
594 | print("run time is :" + str(end-start))
595 |
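The hard-coded sizes in IQT (the nn.Conv2d(1920, 256, ...) projection and the 625+1 position embeddings) follow from the concatenated Inception-ResNet-V2 features: a 224x224 crop leaves the stem on a 25x25 grid, and features() concatenates the Mixed_5b output with every second Block35 output, i.e. 6 * 320 = 1920 channels over 25*25 = 625 spatial tokens plus one quality token. A quick shape check as a sketch (it downloads the ImageNet weights on first use):

```
import torch

backbone = inceptionresnetv2_feature_extractor(num_classes=1000, pretrained='imagenet')
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
print(feats.shape)   # expected: torch.Size([1, 1920, 25, 25]) -> 25*25 = 625 tokens
```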
--------------------------------------------------------------------------------
/folders/folders_LQ_HQ_diff_content_HQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import scipy.io
8 | import numpy as np
9 | import csv
10 | import random
11 | from openpyxl import load_workbook
12 | import cv2
13 | from torchvision import transforms
14 |
15 | class Kadid10kFolder(data.Dataset):
16 | def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
17 | self.patch_size = patch_size
18 | self.self_patch_num = self_patch_num
19 | self.HQ_diff_content_root = HQ_diff_content_root
20 |
21 | LQ_paths = []
22 | HQ_paths = []
23 | mos_all = []
24 | csv_file = os.path.join(root, 'dmos.csv')
25 | with open(csv_file) as f:
26 | reader = csv.DictReader(f)
27 | for row in reader:
28 | LQ_paths.append(row['dist_img'])
29 | HQ_paths.append(row['ref_img'])
30 | mos = np.array(float(row['dmos'])).astype(np.float32)
31 | mos_all.append(mos)
32 |
33 | sample = []
34 | for _, item in enumerate(index):
35 | for _ in range(patch_num):
36 | sample.append((os.path.join(root, 'images', LQ_paths[item]),os.path.join(root, 'images', HQ_paths[item]), mos_all[item]))
37 |
38 | self.HQ_diff_content_paths = []
39 | for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
40 | if HQ_diff_content_img_path[-3:] == 'png' or HQ_diff_content_img_path[-3:] == 'jpg' or HQ_diff_content_img_path[-3:] == 'bmp':
41 | self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))
42 |
43 | self.samples = sample
44 | self.transform = transform
45 | self.HQ_diff_content_transform = HQ_diff_content_transform
46 |
47 | def __getitem__(self, index):
48 | """
49 | Args:
50 | index (int): Index
51 | Returns:
52 | tuple: (LQ, HQ, HQ_diff_content, target) where target is IQA values of the target LQ.
53 | """
54 | LQ_path, HQ_path, target = self.samples[index]
55 | HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths)-1)]
56 | LQ = pil_loader(LQ_path)
57 | HQ_diff_content = pil_loader(HQ_diff_content_path)
58 | HQ = pil_loader(HQ_path)
59 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], []
60 | for _ in range(self.self_patch_num):
61 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size)
62 |
63 | LQ_patch = self.transform(LQ_patch)
64 | HQ_patch = self.transform(HQ_patch)
65 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
66 |
67 | LQ_patches.append(LQ_patch.unsqueeze(0))
68 | HQ_patches.append(HQ_patch.unsqueeze(0))
69 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
70 | #[self_patch_num, 3, patch_size, patch_size]
71 | LQ_patches = torch.cat(LQ_patches, 0)
72 | HQ_patches = torch.cat(HQ_patches, 0)
73 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
74 |
75 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target
76 |
77 | def __len__(self):
78 | length = len(self.samples)
79 | return length
80 |
81 | class LIVEFolder(data.Dataset):
82 |
83 | def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
84 | self.patch_size =patch_size
85 | self.self_patch_num = self_patch_num
86 | self.root = root
87 | self.HQ_diff_content_root = HQ_diff_content_root
88 |
89 | refpath = os.path.join(root, 'refimgs')
90 | refname = getFileName(refpath, '.bmp')
91 |
92 | jp2kroot = os.path.join(root, 'jp2k')
93 | jp2kname = self.getDistortionTypeFileName(jp2kroot, 227)
94 |
95 | jpegroot = os.path.join(root, 'jpeg')
96 | jpegname = self.getDistortionTypeFileName(jpegroot, 233)
97 |
98 | wnroot = os.path.join(root, 'wn')
99 | wnname = self.getDistortionTypeFileName(wnroot, 174)
100 |
101 | gblurroot = os.path.join(root, 'gblur')
102 | gblurname = self.getDistortionTypeFileName(gblurroot, 174)
103 |
104 | fastfadingroot = os.path.join(root, 'fastfading')
105 | fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174)
106 |
107 | imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname
108 |
109 | dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat'))
110 | labels = dmos['dmos_new'].astype(np.float32)
111 |
112 | orgs = dmos['orgs']
113 | refpaths_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat'))
114 | refpaths_all = refpaths_all['refnames_all']
115 |
116 | sample = []
117 | for i in range(0, len(index)):
118 | train_sel = (refname[index[i]] == refpaths_all)
119 | train_sel = train_sel * ~orgs.astype(np.bool_)
120 | train_sel = np.where(train_sel == True)
121 | train_sel = train_sel[1].tolist()
122 | for j, item in enumerate(train_sel):
123 | for aug in range(patch_num):
124 | LQ_path = imgpath[item]
125 | HQ_path = os.path.join(root, 'refimgs', refpaths_all[0][item][0])
126 | label = labels[0][item]
127 | sample.append((LQ_path, HQ_path, label))
128 | # print(self.imgpath[item])
129 |
130 | self.HQ_diff_content_path = []
131 | for HQ_diff_content_img_name in os.listdir(HQ_diff_content_root):
132 | if HQ_diff_content_img_name[-3:] == 'png' or HQ_diff_content_img_name[-3:] == 'jpg' or HQ_diff_content_img_name[-3:] == 'bmp':
133 | self.HQ_diff_content_path.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_name))
134 |
135 | self.samples = sample
136 | self.transform = transform
137 | self.HQ_diff_content_transform = HQ_diff_content_transform
138 |
139 | def __getitem__(self, index):
140 | """
141 | Args:
142 | index (int): Index
143 | Returns:
144 | tuple: (LQ, HQ, HQ_diff_content, target) where target is IQA values of the target LQ.
145 | """
146 | LQ_path, HQ_path, target = self.samples[index]
147 | HQ_diff_content_path = self.HQ_diff_content_path[random.randint(0, len(self.HQ_diff_content_path)-1)]
148 | LQ = pil_loader(LQ_path)
149 | HQ = pil_loader(HQ_path)
150 | HQ_diff_content = pil_loader(HQ_diff_content_path)
151 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], []
152 | for _ in range(self.self_patch_num):
153 | LQ_patch, HQ_patch = getPairRandomPatch(LQ, HQ, crop_size=self.patch_size)
154 |
155 | LQ_patch = self.transform(LQ_patch)
156 | HQ_patch = self.transform(HQ_patch)
157 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
158 |
159 | LQ_patches.append(LQ_patch.unsqueeze(0))
160 | HQ_patches.append(HQ_patch.unsqueeze(0))
161 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
162 | #[self_patch_num, 3, patch_size, patch_size]
163 | LQ_patches = torch.cat(LQ_patches, 0)
164 | HQ_patches = torch.cat(HQ_patches, 0)
165 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
166 |
167 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target
168 |
169 | def __len__(self):
170 | length = len(self.samples)
171 | return length
172 |
173 | def getDistortionTypeFileName(self, path, num):
174 | filename = []
175 | index = 1
176 | for i in range(0, num):
177 | name = '{:0>3d}{}'.format(index, '.bmp')
178 | filename.append(os.path.join(path, name))
179 | index = index + 1
180 | return filename
181 |
182 | class CSIQFolder(data.Dataset):
183 |
184 |     def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
185 |         self.patch_size = patch_size
186 | self.self_patch_num = self_patch_num
187 |
188 | refpath = os.path.join(root, 'src_imgs')
189 | refname = getFileName(refpath,'.png')
190 | txtpath = os.path.join(root, 'csiq_label.txt')
191 | fh = open(txtpath, 'r')
192 | LQ_pathes = []
193 | target = []
194 | refpaths_all = []
195 | for line in fh:
196 | line = line.split('\n')
197 | words = line[0].split()
198 | LQ_pathes.append((words[0]))
199 | target.append(words[1])
200 | ref_temp = words[0].split(".")
201 | refpaths_all.append(ref_temp[0] + '.' + ref_temp[-1])
202 |
203 | labels = np.array(target).astype(np.float32)
204 | refpaths_all = np.array(refpaths_all)
205 |
206 | sample = []
207 |
208 | for i, item in enumerate(index):
209 | train_sel = (refname[index[i]] == refpaths_all)
210 | train_sel = np.where(train_sel == True)
211 | train_sel = train_sel[0].tolist()
212 | for j, item in enumerate(train_sel):
213 | for aug in range(patch_num):
214 | LQ_path = os.path.join(root, 'dst_imgs_all', LQ_pathes[item])
215 | HQ_path = os.path.join(root, 'src_imgs', refpaths_all[item])
216 | label = labels[item]
217 | sample.append((LQ_path, HQ_path, label))
218 |
219 | self.HQ_diff_content = []
220 | for HQ_diff_content_img_name in os.listdir(HQ_diff_content_root):
221 | if HQ_diff_content_img_name[-3:] == 'png' or HQ_diff_content_img_name[-3:] == 'jpg' or HQ_diff_content_img_name[-3:] == 'bmp':
222 | self.HQ_diff_content.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_name))
223 |
224 | self.samples = sample
225 | self.transform = transform
226 | self.HQ_diff_content_transform = HQ_diff_content_transform
227 |
228 | def __getitem__(self, index):
229 | """
230 | Args:
231 | index (int): Index
232 | Returns:
233 |             tuple: (LQ, HQ, HQ_diff_content, target) where target is the IQA score of the LQ image.
234 | """
235 | LQ_path, HQ_path, target = self.samples[index]
236 | HQ_diff_content_path = self.HQ_diff_content[random.randint(0, len(self.HQ_diff_content)-1)]
237 | LQ = pil_loader(LQ_path)
238 | HQ = pil_loader(HQ_path)
239 | HQ_diff_content = pil_loader(HQ_diff_content_path)
240 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], []
241 | for _ in range(self.self_patch_num):
242 | LQ_patch, HQ_patch = getPairRandomPatch(LQ, HQ, crop_size=self.patch_size)
243 |
244 | LQ_patch = self.transform(LQ_patch)
245 | HQ_patch = self.transform(HQ_patch)
246 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
247 |
248 | LQ_patches.append(LQ_patch.unsqueeze(0))
249 | HQ_patches.append(HQ_patch.unsqueeze(0))
250 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
251 | #[self_patch_num, 3, patch_size, patch_size]
252 | LQ_patches = torch.cat(LQ_patches, 0)
253 | HQ_patches = torch.cat(HQ_patches, 0)
254 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
255 |
256 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target
257 |
258 | def __len__(self):
259 | length = len(self.samples)
260 | return length
261 |
262 | class TID2013Folder(data.Dataset):
263 |
264 |     def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
265 |         self.patch_size = patch_size
266 | self.self_patch_num = self_patch_num
267 |
268 | refpath = os.path.join(root, 'reference_images')
269 | refname = self._getTIDFileName(refpath,'.bmp.BMP')
270 | txtpath = os.path.join(root, 'mos_with_names.txt')
271 | fh = open(txtpath, 'r')
272 | LQ_pathes = []
273 | target = []
274 | refpaths_all = []
275 | for line in fh:
276 | line = line.split('\n')
277 | words = line[0].split()
278 | LQ_pathes.append((words[1]))
279 | target.append(words[0])
280 | ref_temp = words[1].split("_")
281 | refpaths_all.append(ref_temp[0][1:])
282 | labels = np.array(target).astype(np.float32)
283 | refpaths_all = np.array(refpaths_all)
284 |
285 | sample = []
286 | for i, item in enumerate(index):
287 | train_sel = (refname[index[i]] == refpaths_all)
288 | train_sel = np.where(train_sel == True)
289 | train_sel = train_sel[0].tolist()
290 | for j, item in enumerate(train_sel):
291 | for aug in range(patch_num):
292 | LQ_path = os.path.join(root, 'distorted_images', LQ_pathes[item])
293 | refHQ_name = 'I' + LQ_pathes[item].split("_")[0][1:] + '.BMP'
294 | HQ_path = os.path.join(refpath, refHQ_name)
295 | label = labels[item]
296 | sample.append((LQ_path, HQ_path, label))
297 | self.HQ_diff_content = []
298 | for HQ_diff_content_img_name in os.listdir(HQ_diff_content_root):
299 | if HQ_diff_content_img_name[-3:] == 'png' or HQ_diff_content_img_name[-3:] == 'jpg' or HQ_diff_content_img_name[-3:] == 'bmp':
300 | self.HQ_diff_content.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_name))
301 |
302 | self.samples = sample
303 | self.transform = transform
304 | self.HQ_diff_content_transform = HQ_diff_content_transform
305 |
306 | def _getTIDFileName(self, path, suffix):
307 | filename = []
308 | f_list = os.listdir(path)
309 | for i in f_list:
310 | if suffix.find(os.path.splitext(i)[1]) != -1:
311 | filename.append(i[1:3])
312 | return filename
313 |
314 | def __getitem__(self, index):
315 | """
316 | Args:
317 | index (int): Index
318 | Returns:
319 |             tuple: (LQ, HQ, HQ_diff_content, target) where target is the IQA score of the LQ image.
320 | """
321 | LQ_path, HQ_path, target = self.samples[index]
322 | HQ_diff_content_path = self.HQ_diff_content[random.randint(0, len(self.HQ_diff_content)-1)]
323 | LQ = pil_loader(LQ_path)
324 | HQ = pil_loader(HQ_path)
325 | HQ_diff_content = pil_loader(HQ_diff_content_path)
326 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], []
327 | for _ in range(self.self_patch_num):
328 | LQ_patch, HQ_patch = getPairRandomPatch(LQ, HQ, crop_size=self.patch_size)
329 |
330 | LQ_patch = self.transform(LQ_patch)
331 | HQ_patch = self.transform(HQ_patch)
332 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
333 |
334 | LQ_patches.append(LQ_patch.unsqueeze(0))
335 | HQ_patches.append(HQ_patch.unsqueeze(0))
336 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
337 | #[self_patch_num, 3, patch_size, patch_size]
338 | LQ_patches = torch.cat(LQ_patches, 0)
339 | HQ_patches = torch.cat(HQ_patches, 0)
340 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
341 |
342 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target
343 |
344 | def __len__(self):
345 | length = len(self.samples)
346 | return length
347 |
348 | class LIVEChallengeFolder(data.Dataset):
349 |     def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
350 |         self.patch_size = patch_size
351 | self.self_patch_num = self_patch_num
352 |
353 | LQ_pathes = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat'))
354 | LQ_pathes = LQ_pathes['AllImages_release']
355 |         LQ_pathes = LQ_pathes[7:1169]  # skip the first 7 practice images of the LIVE Challenge study
356 | mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat'))
357 | labels = mos['AllMOS_release'].astype(np.float32)
358 | labels = labels[0][7:1169]
359 |
360 | sample = []
361 | for _, item in enumerate(index):
362 | for _ in range(patch_num):
363 | sample.append((os.path.join(root, 'Images', LQ_pathes[item][0][0]), labels[item]))
364 |
365 | self.HQ_diff_content_paths = []
366 | for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
367 | if HQ_diff_content_img_path[-3:] == 'png' or HQ_diff_content_img_path[-3:] == 'jpg' or HQ_diff_content_img_path[-3:] == 'bmp':
368 | self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))
369 |
370 | self.samples = sample
371 | self.transform = transform
372 | self.HQ_diff_content_transform = HQ_diff_content_transform
373 |
374 | def __getitem__(self, index):
375 | """
376 | Args:
377 | index (int): Index
378 | Returns:
379 |             tuple: (LQ, _, HQ_diff_content, target) where target is the IQA score of the LQ image.
380 | """
381 | LQ_path, target = self.samples[index]
382 | HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths)-1)]
383 | LQ = pil_loader(LQ_path)
384 | HQ_diff_content = pil_loader(HQ_diff_content_path)
385 | LQ_patches, HQ_diff_content_patches = [], []
386 | for _ in range(self.self_patch_num):
387 | LQ_patch = self.HQ_diff_content_transform(LQ)
388 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
389 |
390 | LQ_patches.append(LQ_patch.unsqueeze(0))
391 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
392 | #[self_patch_num, 3, patch_size, patch_size]
393 | LQ_patches = torch.cat(LQ_patches, 0)
394 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
395 |
396 |         return LQ_patches, _, HQ_diff_content_patches, target  # second slot is unused: this dataset has no paired HQ reference
397 |
398 | def __len__(self):
399 | length = len(self.samples)
400 | return length
401 |
402 | class BIDChallengeFolder(data.Dataset):
403 |     def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
404 |         self.patch_size = patch_size
405 | self.self_patch_num = self_patch_num
406 |
407 | LQ_pathes = []
408 | labels = []
409 |
410 | xls_file = os.path.join(root, 'DatabaseGrades.xlsx')
411 | workbook = load_workbook(xls_file)
412 | booksheet = workbook.active
413 | rows = booksheet.rows
414 |         count = 1  # row 1 of the sheet is the header, so image grades are read from row 2 onward
415 | for _ in rows:
416 | count += 1
417 | img_num = (booksheet.cell(row=count, column=1).value)
418 | img_name = "DatabaseImage%04d.JPG" % (img_num)
419 | LQ_pathes.append(img_name)
420 | mos = (booksheet.cell(row=count, column=2).value)
421 | mos = np.array(mos)
422 | mos = mos.astype(np.float32)
423 | labels.append(mos)
424 | if count == 587:
425 | break
426 |
427 | sample = []
428 | for _, item in enumerate(index):
429 | for _ in range(patch_num):
430 | sample.append((os.path.join(root, LQ_pathes[item]), labels[item]))
431 |
432 | self.HQ_diff_content_paths = []
433 | for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
434 | if HQ_diff_content_img_path[-3:] == 'png' or HQ_diff_content_img_path[-3:] == 'jpg' or HQ_diff_content_img_path[-3:] == 'bmp':
435 | self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))
436 |
437 | self.samples = sample
438 | self.transform = transform
439 | self.HQ_diff_content_transform = HQ_diff_content_transform
440 |
441 | def __getitem__(self, index):
442 | """
443 | Args:
444 | index (int): Index
445 |             tuple: (LQ, _, HQ_diff_content, target) where target is the IQA score of the LQ image.
446 | tuple: (LQ, _, HQ_diff_content, target) where target is IQA values of the target LQ.
447 | """
448 | LQ_path, target = self.samples[index]
449 | HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths)-1)]
450 | LQ = pil_loader(LQ_path)
451 | HQ_diff_content = pil_loader(HQ_diff_content_path)
452 | LQ_patches, HQ_diff_content_patches = [], []
453 | for _ in range(self.self_patch_num):
454 | LQ_patch = self.HQ_diff_content_transform(LQ)
455 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
456 |
457 | LQ_patches.append(LQ_patch.unsqueeze(0))
458 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
459 | #[self_patch_num, 3, patch_size, patch_size]
460 | LQ_patches = torch.cat(LQ_patches, 0)
461 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
462 |
463 |         return LQ_patches, _, HQ_diff_content_patches, target  # second slot is unused: this dataset has no paired HQ reference
464 |
465 | def __len__(self):
466 | length = len(self.samples)
467 | return length
468 |
469 | class Koniq_10kFolder(data.Dataset):
470 |     def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
471 |         self.patch_size = patch_size
472 | self.self_patch_num = self_patch_num
473 |
474 | imgname = []
475 | mos_all = []
476 | csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv')
477 | with open(csv_file) as f:
478 | reader = csv.DictReader(f)
479 | for row in reader:
480 | imgname.append(row['image_name'])
481 | mos = np.array(float(row['MOS_zscore'])).astype(np.float32)
482 | mos_all.append(mos)
483 |
484 | sample = []
485 | for _, item in enumerate(index):
486 | for _ in range(patch_num):
487 | sample.append((os.path.join(root, '1024x768', imgname[item]), mos_all[item]))
488 |
489 | self.HQ_diff_content_paths = []
490 | for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
491 | if HQ_diff_content_img_path[-3:] == 'png' or HQ_diff_content_img_path[-3:] == 'jpg' or HQ_diff_content_img_path[-3:] == 'bmp':
492 | self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))
493 |
494 |
495 | self.samples = sample
496 | self.transform = transform
497 | self.HQ_diff_content_transform = HQ_diff_content_transform
498 |
499 | def __getitem__(self, index):
500 | """
501 | Args:
502 | index (int): Index
503 | Returns:
504 |             tuple: (LQ, _, HQ_diff_content, target) where target is the IQA score of the LQ image.
505 | """
506 | LQ_path, target = self.samples[index]
507 | HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths)-1)]
508 | LQ = pil_loader(LQ_path)
509 | HQ_diff_content = pil_loader(HQ_diff_content_path)
510 | LQ_patches, HQ_diff_content_patches = [], []
511 | for _ in range(self.self_patch_num):
512 | LQ_patch = self.HQ_diff_content_transform(LQ)
513 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)
514 |
515 | LQ_patches.append(LQ_patch.unsqueeze(0))
516 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
517 | #[self_patch_num, 3, patch_size, patch_size]
518 | LQ_patches = torch.cat(LQ_patches, 0)
519 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)
520 |
521 |         return LQ_patches, _, HQ_diff_content_patches, target  # second slot is unused: this dataset has no paired HQ reference
522 |
523 | def __len__(self):
524 | length = len(self.samples)
525 | return length
526 |
527 | def getFileName(path, suffix):
528 | filename = []
529 | f_list = os.listdir(path)
530 | for i in f_list:
531 | if os.path.splitext(i)[1] == suffix:
532 | filename.append(i)
533 | return filename
534 |
535 |
536 | def getPairRandomPatch(img1, img2, crop_size=512):
537 | (iw,ih) = img1.size
538 | # print(ih,iw)
539 |
540 | ip = int(crop_size)
541 |
542 | ix = random.randrange(0, iw - ip + 1)
543 | iy = random.randrange(0, ih - ip + 1)
544 |
545 |
546 |     img1_patch = img1.crop((ix, iy, ix+ip, iy+ip))  # crop box is (left, upper, right, lower)
547 |     img2_patch = img2.crop((ix, iy, ix+ip, iy+ip))  # identical box keeps the two patches aligned
548 |
549 | return img1_patch, img2_patch
550 |
551 | def getPairAugment(img1, img2, hflip=True, vflip=False, rot=False):
552 | hflip = hflip and random.random() < 0.5
553 | vflip = vflip and random.random() < 0.5
554 | rot180 = rot and random.random() < 0.5
555 |
556 |     if hflip:  # horizontal flip mirrors left-right
557 |         img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
558 |         img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
559 |     if vflip:  # vertical flip mirrors top-bottom
560 |         img1 = img1.transpose(Image.FLIP_TOP_BOTTOM)
561 |         img2 = img2.transpose(Image.FLIP_TOP_BOTTOM)
562 | if rot180:
563 | img1 = img1.transpose(Image.ROTATE_180)
564 | img2 = img2.transpose(Image.ROTATE_180)
565 |
566 | return img1, img2
567 |
568 |
569 | def getSelfPatch(img, patch_size, patch_num, is_random=True):
570 | (iw,ih) = img.size
571 | patches = []
572 | for i in range(patch_num):
573 | if is_random:
574 | ix = random.randrange(0, iw - patch_size + 1)
575 | iy = random.randrange(0, ih - patch_size + 1)
576 |         else: ix, iy = (iw - patch_size + 1) // 2, (ih - patch_size + 1) // 2  # centered patch
577 |
578 |         # patch = img[iy:iy + lr_size, ix:ix + lr_size, :]  # array-style crop (rows, then columns)
579 |         patch = img.crop((ix, iy, ix+patch_size, iy+patch_size))  # crop box is (left, upper, right, lower)
580 | patches.append(patch)
581 |
582 | return patches
583 |
584 |
585 | def pil_loader(path):
586 | with open(path, 'rb') as f:
587 | img = Image.open(f)
588 | return img.convert('RGB')
589 |
590 |
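591 | # ---------------------------------------------------------------------------
592 | # Editor's note: hedged usage sketch, not part of the original repository.
593 | # It shows one plausible way to wrap a folder dataset defined above in a
594 | # DataLoader. The dataset root, the HQ_diff_content directory and the index
595 | # split are hypothetical placeholders, and the ImageNet-style transforms are
596 | # only an assumption about what the training scripts pass in.
597 | # ---------------------------------------------------------------------------
598 | if __name__ == '__main__':
599 |     import torchvision.transforms as T
600 |     from torch.utils.data import DataLoader
601 | 
602 |     norm = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
603 |     transform = T.Compose([T.ToTensor(), norm])
604 |     hq_transform = T.Compose([T.RandomCrop(224), T.ToTensor(), norm])
605 | 
606 |     # Hypothetical paths and index split; replace with real locations/splits.
607 |     dataset = Koniq_10kFolder(root='/path/to/koniq10k',
608 |                               HQ_diff_content_root='/path/to/HQ_images',
609 |                               index=list(range(100)),
610 |                               transform=transform,
611 |                               HQ_diff_content_transform=hq_transform,
612 |                               patch_num=1, patch_size=224, self_patch_num=10)
613 |     loader = DataLoader(dataset, batch_size=4, shuffle=True)
614 | 
615 |     # Each batch: LQ patches [B, self_patch_num, 3, 224, 224], an unused
616 |     # placeholder, HQ patches of different content, and the MOS targets [B].
617 |     LQ_patches, _, HQ_diff_patches, targets = next(iter(loader))
618 |     print(LQ_patches.shape, HQ_diff_patches.shape, targets.shape)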
--------------------------------------------------------------------------------