├── imgs
│   ├── readme
│   ├── KD.png
│   ├── real_example.png
│   ├── distillationIQA.png
│   └── synthetic_example.png
├── model_zoo
│   └── readme.txt
├── tools.py
├── models
│   ├── CNNIQA.py
│   ├── DCNN_NARIQA.py
│   ├── LinearityIQA.py
│   ├── WaDIQaM.py
│   ├── TRIQ.py
│   ├── HyperIQA.py
│   ├── DistillationIQA.py
│   └── IQT.py
├── LICENSE
├── dataloaders
│   ├── dataloader_LQ_HQ.py
│   ├── dataloader_LQ.py
│   └── dataloader_LQ_HQ_diff_content_HQ.py
├── test_DistillationIQA_single.py
├── README.md
├── option_train_DistillationIQA_FR.py
├── option_train_DistillationIQA.py
├── test_DistillationIQA.py
├── train_DistillationIQA_FR.py
├── train_DistillationIQA.py
└── folders
    ├── folders_LQ_HQ.py
    ├── folders_LQ.py
    └── folders_LQ_HQ_diff_content_HQ.py
/imgs/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model_zoo/readme.txt: -------------------------------------------------------------------------------- 1 | Put trained model here! 2 | -------------------------------------------------------------------------------- /imgs/KD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/KD.png -------------------------------------------------------------------------------- /imgs/real_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/real_example.png -------------------------------------------------------------------------------- /imgs/distillationIQA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/distillationIQA.png -------------------------------------------------------------------------------- /imgs/synthetic_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanghaoyin/CVRKD-IQA/HEAD/imgs/synthetic_example.png -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import curve_fit 3 | 4 | def convert_obj_score(ori_obj_score, MOS): 5 | """ 6 | func: 7 | fitting the objective score to the MOS scale.
8 | nonlinear regression fit 9 | """ 10 | def logistic_fun(x, a, b, c, d): 11 | return (a - b)/(1 + np.exp(-(x - c)/abs(d))) + b 12 | # nolinear fit the MOSp 13 | param_init = [np.max(MOS), np.min(MOS), np.mean(ori_obj_score), 1] 14 | popt, pcov = curve_fit(logistic_fun, ori_obj_score, MOS, 15 | p0 = param_init, ftol =1e-8, maxfev=500000) 16 | #a, b, c, d = popt[0], popt[1], popt[2], popt[3] 17 | 18 | obj_fit_score = logistic_fun(ori_obj_score, popt[0], popt[1], popt[2], popt[3]) 19 | 20 | return obj_fit_score 21 | -------------------------------------------------------------------------------- /models/CNNIQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | class CNNIQAnet(nn.Module): 7 | def __init__(self): 8 | super(CNNIQAnet, self).__init__() 9 | 10 | self.conv = nn.Conv2d(in_channels=1, out_channels=50, kernel_size=7) 11 | self.fc1 = nn.Linear(100, 800) 12 | self.fc2 = nn.Linear(800, 800) 13 | self.fc3 = nn.Linear(800, 1) 14 | 15 | def forward(self, input): 16 | x = input.view(-1, input.size(-3), input.size(-2), input.size(-1)) 17 | 18 | x = self.conv(x) 19 | 20 | x1 = F.max_pool2d(x, (x.size(-2), x.size(-1))) 21 | x2 = -F.max_pool2d(-x, (x.size(-2), x.size(-1))) 22 | 23 | h = torch.cat((x1, x2), 1) 24 | h = h.squeeze(3).squeeze(2) 25 | 26 | h = F.relu(self.fc1(h)) 27 | h = F.dropout(h) 28 | h = F.relu(self.fc2(h)) 29 | 30 | q = self.fc3(h) 31 | 32 | return q 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 guanghaoyin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /dataloaders/dataloader_LQ_HQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import folders.folders_LQ_HQ as folders 3 | 4 | class DataLoader(object): 5 | """Dataset class for IQA databases""" 6 | 7 | def __init__(self, dataset, path, img_indx, patch_size, patch_num, batch_size=1, istrain=True, self_patch_num=10): 8 | 9 | self.batch_size = batch_size 10 | self.istrain = istrain 11 | 12 | if dataset == 'live': 13 | self.data = folders.LIVEFolder( 14 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 15 | elif dataset == 'csiq': 16 | self.data = folders.CSIQFolder( 17 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 18 | elif dataset == 'kadid10k': 19 | self.data = folders.Kadid10kFolder( 20 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 21 | elif dataset == 'tid2013': 22 | self.data = folders.TID2013Folder( 23 | root=path, index=img_indx, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 24 | 25 | def get_dataloader(self): 26 | if self.istrain: 27 | dataloader = torch.utils.data.DataLoader( 28 | self.data, batch_size=self.batch_size, shuffle=True) 29 | else: 30 | dataloader = torch.utils.data.DataLoader( 31 | self.data, batch_size=1, shuffle=False) 32 | return dataloader 33 | -------------------------------------------------------------------------------- /models/DCNN_NARIQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | class DCNN_NARIQA(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | #ref path 10 | self.block1_ref = nn.Sequential( 11 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=4), 12 | nn.ReLU(), 13 | nn.MaxPool2d(kernel_size=5, stride=1)) 14 | self.block2_ref = nn.Sequential( 15 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=9, stride=1), 16 | nn.ReLU()) 17 | self.fc3_ref = nn.Linear(in_features=59168, out_features=1024) 18 | 19 | #LQ path 20 | self.block1_lq = nn.Sequential( 21 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=4), 22 | nn.ReLU(), 23 | nn.MaxPool2d(kernel_size=5, stride=1)) 24 | self.block2_lq = nn.Sequential( 25 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=9, stride=1), 26 | nn.ReLU()) 27 | self.fc3_lq = nn.Linear(in_features=59168, out_features=1024) 28 | 29 | self.fc4 = nn.Linear(in_features=2048, out_features=1024) 30 | self.fc5 = nn.Linear(in_features=1024, out_features=1) 31 | 32 | def forward(self, lq_patches, ref_patches): 33 | feature_lq = self.block1_lq(lq_patches) 34 | feature_lq = self.block2_lq(feature_lq) 35 | feature_lq = self.fc3_lq(feature_lq.view(feature_lq.size(0), -1)) 36 | 37 | feature_ref = self.block1_ref(ref_patches) 38 | feature_ref = self.block2_ref(feature_ref) 39 | feature_ref = self.fc3_ref(feature_ref.view(feature_ref.size(0), -1)) 40 | 41 | concat_feature = torch.cat((feature_ref, feature_lq), 1) 42 | concat_feature = self.fc4(concat_feature) 43 | pred = self.fc5(concat_feature) 44 | return pred 45 | 46 | if __name__ == "__main__": 47 | x = torch.rand((1,3,224,224)) 48 | y = torch.rand((1,3,224,224)) 49 | net = DCNN_NARIQA() 50 | pred = net(x, y) 51 | 
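    # Illustrative addition (not part of the original file): for 1x3x224x224 inputs, the
    # two-branch network returns a single quality score of shape (1, 1).
    print(pred.shape)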
-------------------------------------------------------------------------------- /test_DistillationIQA_single.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from option_train_DistillationIQA import set_args, check_args 4 | import numpy as np 5 | from models.DistillationIQA import DistillationIQANet 6 | from PIL import Image 7 | import torchvision 8 | 9 | img_num = { 10 | 'kadid10k': list(range(0,10125)), 11 | 'live': list(range(0, 29)),#ref HR image 12 | 'csiq': list(range(0, 30)),#ref HR image 13 | 'tid2013': list(range(0, 25)), 14 | 'livec': list(range(0, 1162)),# no-ref image 15 | 'koniq-10k': list(range(0, 10073)),# no-ref image 16 | 'bid': list(range(0, 586)),# no-ref image 17 | } 18 | folder_path = { 19 | 'pipal':'./dataset/PIPAL', 20 | 'live': './dataset/LIVE/', 21 | 'csiq': './dataset/CSIQ/', 22 | 'tid2013': './dataset/TID2013/', 23 | 'livec': './dataset/LIVEC/', 24 | 'koniq-10k': './dataset/koniq-10k/', 25 | 'bid': './dataset/BID/', 26 | 'kadid10k':'./dataset/kadid10k/' 27 | } 28 | 29 | 30 | class DistillationIQASolver(object): 31 | def __init__(self, config, lq_path, ref_path): 32 | self.config = config 33 | self.config.teacherNet_model_path = './model_zoo/FR_teacher_cross_dataset.pth' 34 | self.config.studentNet_model_path = './model_zoo/NAR_student_cross_dataset.pth' 35 | 36 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu') 37 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt') 38 | with open(self.txt_log_path,"w+") as f: 39 | f.close() 40 | 41 | #model 42 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 43 | if config.teacherNet_model_path: 44 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path)) 45 | self.teacherNet = self.teacherNet.to(self.device) 46 | self.teacherNet.train(False) 47 | self.studentNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 48 | if config.studentNet_model_path: 49 | self.studentNet._load_state_dict(torch.load(config.studentNet_model_path)) 50 | self.studentNet = self.studentNet.to(self.device) 51 | self.studentNet.train(True) 52 | 53 | self.transform = torchvision.transforms.Compose([ 54 | torchvision.transforms.RandomCrop(size=self.config.patch_size), 55 | torchvision.transforms.ToTensor(), 56 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 57 | std=(0.229, 0.224, 0.225)) 58 | ]) 59 | #data 60 | self.LQ_patches = self.preprocess(lq_path) 61 | self.ref_patches = self.preprocess(ref_path) 62 | 63 | def preprocess(self, path): 64 | with open(path, 'rb') as f: 65 | img = Image.open(f) 66 | img= img.convert('RGB') 67 | patches = [] 68 | for _ in range(self.config.self_patch_num): 69 | patch = self.transform(img) 70 | patches.append(patch.unsqueeze(0)) 71 | patches = torch.cat(patches, 0) 72 | return patches.unsqueeze(0) 73 | 74 | def test(self): 75 | self.studentNet.train(False) 76 | LQ_patches, ref_patches = self.LQ_patches.to(self.device), self.ref_patches.to(self.device) 77 | with torch.no_grad(): 78 | _, _, pred = self.studentNet(LQ_patches, ref_patches) 79 | return float(pred.item()) 80 | 81 | if __name__ == "__main__": 82 | config = set_args() 83 | config = check_args(config) 84 | 85 | lq_path = './dataset/koniq-10k/1024x768/28311109.jpg' 86 | ref_path = './dataset/DIV2K_ref/val_HR/0801.png' 87 | label = 1.15686274509804 88 | solver = 
DistillationIQASolver(config=config, lq_path=lq_path, ref_path=ref_path) 89 | scores = [] 90 | for _ in range(10): 91 | scores.append(solver.test()) 92 | print(np.mean(scores)) 93 | # result 1.2577123641967773 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CVRKD-IQA (DistillationIQA) 2 | This repository is for CVRKD-IQA, introduced in the following paper: 3 | 4 | Guanghao Yin, Wei Wang, Zehuan Yuan, Chuchu Han, Wei Ji, Shouqian Sun and Changhu Wang, "Content-Variant Reference Image Quality Assessment via Knowledge Distillation", AAAI Oral, 2022 [arxiv](https://arxiv.org/abs/2202.13123) 5 | 6 | ## Introduction 7 | CVRKD-IQA is the first content-variant reference IQA method based on knowledge distillation. The practicality of previous FR-IQA methods is limited by their requirement for pixel-level aligned reference images, while NR-IQA methods still leave room for improvement because HQ image information is not fully exploited. We therefore use non-aligned reference (NAR) images to introduce diverse priors of high-quality images, and comparing the distribution differences between HQ and LQ images helps our model assess image quality more accurately. Knowledge distillation further transfers HQ-LQ distribution-difference information from the FR-teacher to the NAR-student and stabilizes CVRKD-IQA performance. Since content-variant, non-aligned reference HQ images are easy to obtain, our model can support more IQA applications and is relatively robust to content variations. 8 | 9 |
Distillation
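For a quick look at the model interface, the sketch below (condensed from test_DistillationIQA_single.py; the to_patches helper is only an illustrative wrapper) scores one LQ image against an arbitrary content-variant HQ reference with the released NAR-student:
```
import torch, torchvision
from PIL import Image
from models.DistillationIQA import DistillationIQANet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self_patch_num, patch_size = 10, 224   # defaults from option_train_DistillationIQA.py

net = DistillationIQANet(self_patch_num=self_patch_num, distillation_layer=18)
net._load_state_dict(torch.load('./model_zoo/NAR_student_cross_dataset.pth'))
net = net.to(device)
net.train(False)

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(size=patch_size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))])

def to_patches(path):
    # sample self_patch_num random crops and stack them into a (1, N, 3, H, W) tensor
    img = Image.open(path).convert('RGB')
    return torch.stack([transform(img) for _ in range(self_patch_num)], dim=0).unsqueeze(0)

lq = to_patches('./dataset/koniq-10k/1024x768/28311109.jpg').to(device)   # distorted image
ref = to_patches('./dataset/DIV2K_ref/val_HR/0801.png').to(device)        # content-variant HQ reference

with torch.no_grad():
    _, _, pred = net(lq, ref)   # the last output is the predicted quality score
print(float(pred.item()))
```
Because the patches are random crops, repeated runs give slightly different scores; test_DistillationIQA_single.py therefore averages the predictions over 10 runs.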
10 | 11 | ## Prepare data 12 | ### Training datasets 13 | Download the synthetic [KADID-10k](http://database.mmsp-kn.de/kadid-10k-database.html) dataset, and download the training HQ images of [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) as the reference dataset. 14 | 15 | ### Testing datasets 16 | Download the synthetic [LIVE](http://live.ece.utexas.edu/index.php), [CSIQ](https://qualinet.github.io/databases/image/categorical_image_quality_csiq_database/), [TID2013](http://www.ponomarenko.info/tid2013.htm) and authentic [KonIQ-10K](http://database.mmsp-kn.de/koniq-10k-database.html) datasets, and download the testing HQ images of [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) as the reference dataset. 17 | 18 | Place the unzipped data in the ./dataset folder. 19 | ## Train 20 | ### 1. Train the FR-teacher 21 | (1) (optional) Download the models for our paper and place them in './model_zoo/'. 22 | The model for the FR-teacher can be downloaded from [Google Cloud](https://drive.google.com/file/d/1niFBV-ysJeVoaXUPQp-08ovrjS9tPhGW/view?usp=sharing) 23 | 24 | (2) Quick start (you can change the options in option_train_DistillationIQA_FR.py) 25 | ``` 26 | python train_DistillationIQA_FR.py --self_patch_num 10 --patch_size 224 27 | ``` 28 | ### 2. Fix the pretrained FR-teacher and train the NAR-student 29 | (1) (optional) Download the models for our paper and place them in './model_zoo/'. 30 | The models for the FR-teacher and NAR-student can be downloaded from [Google Cloud](https://drive.google.com/file/d/107TI1pa0TDxs3V8tO2KhmhKJmfc9ZOl4/view?usp=sharing) 31 | 32 | (2) Quick start (you can change the options in option_train_DistillationIQA.py) 33 | ``` 34 | python train_DistillationIQA.py --self_patch_num 10 --patch_size 224 35 | ``` 36 | 37 | ## Test 38 | (1) Make sure the trained models are placed at './model_zoo/FR_teacher_cross_dataset.pth' and './model_zoo/NAR_student_cross_dataset.pth'. 39 | 40 | (2) Quick start (you can change the options in option_train_DistillationIQA.py) 41 | ``` 42 | python test_DistillationIQA.py 43 | ``` 44 | (3) Test a single image 45 | ``` 46 | python test_DistillationIQA_single.py 47 | ``` 48 | ## More visual results 49 | Synthetic examples of IQA scores predicted by our NAR-student when gradually increasing the distortion levels. 50 |
Distillation
51 | 52 | Real-data examples of IQA scores predicted by our NAR-student. 53 |
Distillation
54 | 55 | ## T-SNE visualization of the HQ-LQ difference-aware features of the NAR-student w/ and w/o KD 56 | After distillation with the FR-teacher, the HQ-LQ features in Fig(b) are clustered. This shows that the HQ-LQ distribution-difference prior from the FR-teacher can indeed 57 | help the NAR-student extract quality-sensitive, discriminative features for more accurate and consistent performance. 58 | 59 |
Distillation
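A plot like the one above can be reproduced with scikit-learn once the difference-aware features are extracted. The sketch below is only illustrative: student_features and labels stand in for hypothetical dumps of the pooled HQ-LQ difference features and their distortion labels collected from the NAR-student during testing.
```
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# student_features: (N, D) HQ-LQ difference-aware features, labels: (N,) distortion levels/types
student_features = np.load('student_features.npy')   # hypothetical feature dump
labels = np.load('feature_labels.npy')                # hypothetical label dump

embedded = TSNE(n_components=2, perplexity=30, init='pca', random_state=0).fit_transform(student_features)
plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, s=3, cmap='tab10')
plt.title('t-SNE of HQ-LQ difference-aware features')
plt.savefig('tsne_student_features.png', dpi=200)
```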
60 | 61 | ## Citation 62 | 63 | If you find the code helpful in your resarch or work, please cite the following papers. 64 | 65 | ``` 66 | @article{yin2022content, 67 | title={Content-Variant Reference Image Quality Assessment via Knowledge Distillation}, 68 | author={Yin, Guanghao and Wang, Wei and Yuan, Zehuan and Han, Chuchu and Ji, Wei and Sun, Shouqian and Wang, Changhu}, 69 | journal={arXiv preprint arXiv:2202.13123}, 70 | year={2022} 71 | } 72 | ``` 73 | ## Acknowledgements 74 | Part of our code is built on [HyperIQA](https://github.com/SSL92/hyperIQA). We thank the authors for sharing their codes. Also thanks for the support of [Bytedance.Inc](https://github.com/bytedance) 75 | -------------------------------------------------------------------------------- /models/LinearityIQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | import numpy as np 6 | 7 | 8 | def SPSP(x, P=1, method='avg'): 9 | batch_size = x.size(0) 10 | map_size = x.size()[-2:] 11 | pool_features = [] 12 | for p in range(1, P+1): 13 | pool_size = [np.int(d / p) for d in map_size] 14 | if method == 'maxmin': 15 | M = F.max_pool2d(x, pool_size) 16 | m = -F.max_pool2d(-x, pool_size) 17 | pool_features.append(torch.cat((M, m), 1).view(batch_size, -1)) # max & min pooling 18 | elif method == 'max': 19 | M = F.max_pool2d(x, pool_size) 20 | pool_features.append(M.view(batch_size, -1)) # max pooling 21 | elif method == 'min': 22 | m = -F.max_pool2d(-x, pool_size) 23 | pool_features.append(m.view(batch_size, -1)) # min pooling 24 | elif method == 'avg': 25 | a = F.avg_pool2d(x, pool_size) 26 | pool_features.append(a.view(batch_size, -1)) # average pooling 27 | else: 28 | m1 = F.avg_pool2d(x, pool_size) 29 | rm2 = torch.sqrt(F.relu(F.avg_pool2d(torch.pow(x, 2), pool_size) - torch.pow(m1, 2))) 30 | if method == 'std': 31 | pool_features.append(rm2.view(batch_size, -1)) # std pooling 32 | else: 33 | pool_features.append(torch.cat((m1, rm2), 1).view(batch_size, -1)) # statistical pooling: mean & std 34 | return torch.cat(pool_features, dim=1) 35 | 36 | 37 | class LinearityIQA(nn.Module): 38 | def __init__(self, arch='resnext101_32x8d', pool='avg', use_bn_end=False, P6=1, P7=1): 39 | super(LinearityIQA, self).__init__() 40 | self.pool = pool 41 | self.use_bn_end = use_bn_end 42 | if pool in ['max', 'min', 'avg', 'std']: 43 | c = 1 44 | else: 45 | c = 2 46 | self.P6 = P6 # 47 | self.P7 = P7 # 48 | features = list(models.__dict__[arch](pretrained=True).children())[:-2] 49 | if arch == 'alexnet': 50 | in_features = [256, 256] 51 | self.id1 = 9 52 | self.id2 = 12 53 | features = features[0] 54 | elif arch == 'vgg16': 55 | in_features = [512, 512] 56 | self.id1 = 23 57 | self.id2 = 30 58 | features = features[0] 59 | elif 'res' in arch: 60 | self.id1 = 6 61 | self.id2 = 7 62 | if arch == 'resnet18' or arch == 'resnet34': 63 | in_features = [256, 512] 64 | else: 65 | in_features = [1024, 2048] 66 | else: 67 | print('The arch is not implemented!') 68 | self.features = nn.Sequential(*features) 69 | self.dr6 = nn.Sequential(nn.Linear(in_features[0] * c * sum([p * p for p in range(1, self.P6+1)]), 1024), 70 | nn.BatchNorm1d(1024), 71 | nn.Linear(1024, 256), 72 | nn.BatchNorm1d(256), 73 | nn.Linear(256, 64), 74 | nn.BatchNorm1d(64), nn.ReLU()) 75 | self.dr7 = nn.Sequential(nn.Linear(in_features[1] * c * sum([p * p for p in range(1, self.P7+1)]), 1024), 76 | nn.BatchNorm1d(1024), 77 | 
nn.Linear(1024, 256), 78 | nn.BatchNorm1d(256), 79 | nn.Linear(256, 64), 80 | nn.BatchNorm1d(64), nn.ReLU()) 81 | 82 | if self.use_bn_end: 83 | self.regr6 = nn.Sequential(nn.Linear(64, 1), nn.BatchNorm1d(1)) 84 | self.regr7 = nn.Sequential(nn.Linear(64, 1), nn.BatchNorm1d(1)) 85 | self.regression = nn.Sequential(nn.Linear(64 * 2, 1), nn.BatchNorm1d(1)) 86 | else: 87 | self.regr6 = nn.Linear(64, 1) 88 | self.regr7 = nn.Linear(64, 1) 89 | self.regression = nn.Linear(64 * 2, 1) 90 | 91 | def extract_features(self, x): 92 | f, pq = [], [] 93 | for ii, model in enumerate(self.features): 94 | x = model(x) 95 | if ii == self.id1: 96 | x6 = SPSP(x, P=self.P6, method=self.pool) 97 | x6 = self.dr6(x6) 98 | f.append(x6) 99 | pq.append(self.regr6(x6)) 100 | if ii == self.id2: 101 | x7 = SPSP(x, P=self.P7, method=self.pool) 102 | x7 = self.dr7(x7) 103 | f.append(x7) 104 | pq.append(self.regr7(x7)) 105 | 106 | f = torch.cat(f, dim=1) 107 | 108 | return f, pq 109 | 110 | def forward(self, x): 111 | f, pq = self.extract_features(x) 112 | s = self.regression(f) 113 | pq.append(s) 114 | 115 | return pq, s 116 | 117 | if __name__ == "__main__": 118 | x = torch.rand((1,3,224,224)) 119 | net = LinearityIQA() 120 | net.train(False) 121 | # print(net.dr6) 122 | # print(net.dr7) 123 | y, pred = net(x) 124 | print(pred) 125 | -------------------------------------------------------------------------------- /dataloaders/dataloader_LQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import folders.folders_LQ as folders 4 | 5 | class DataLoader(object): 6 | """Dataset class for IQA databases""" 7 | 8 | def __init__(self, dataset, path, img_indx, patch_size, patch_num, batch_size=1, istrain=True, self_patch_num=1): 9 | 10 | self.batch_size = batch_size 11 | self.istrain = istrain 12 | 13 | if (dataset == 'live') | (dataset == 'csiq') | (dataset == 'tid2013') | (dataset == 'livec') | (dataset == 'kadid10k'): 14 | # Train transforms 15 | if istrain: 16 | transforms = torchvision.transforms.Compose([ 17 | torchvision.transforms.RandomCrop(size=patch_size), 18 | torchvision.transforms.RandomHorizontalFlip(), 19 | torchvision.transforms.RandomVerticalFlip(), 20 | torchvision.transforms.RandomRotation(degrees=180), 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 23 | std=(0.229, 0.224, 0.225)) 24 | ]) 25 | # Test transforms 26 | else: 27 | transforms = torchvision.transforms.Compose([ 28 | torchvision.transforms.RandomCrop(size=patch_size), 29 | torchvision.transforms.ToTensor(), 30 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 31 | std=(0.229, 0.224, 0.225)) 32 | ]) 33 | elif dataset == 'koniq-10k': 34 | if istrain: 35 | transforms = torchvision.transforms.Compose([ 36 | torchvision.transforms.RandomCrop(size=patch_size), 37 | torchvision.transforms.RandomHorizontalFlip(), 38 | torchvision.transforms.RandomVerticalFlip(), 39 | torchvision.transforms.RandomRotation(degrees=180), 40 | torchvision.transforms.ToTensor(), 41 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 42 | std=(0.229, 0.224, 0.225))]) 43 | else: 44 | transforms = torchvision.transforms.Compose([ 45 | torchvision.transforms.RandomCrop(size=patch_size), 46 | torchvision.transforms.ToTensor(), 47 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 48 | std=(0.229, 0.224, 0.225))]) 49 | elif dataset == 'bid': 50 | if istrain: 51 | transforms = torchvision.transforms.Compose([ 52 | 
torchvision.transforms.Resize((512, 512)), 53 | torchvision.transforms.RandomCrop(size=patch_size), 54 | torchvision.transforms.RandomHorizontalFlip(), 55 | torchvision.transforms.RandomVerticalFlip(), 56 | torchvision.transforms.RandomRotation(degrees=180), 57 | torchvision.transforms.ToTensor(), 58 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 59 | std=(0.229, 0.224, 0.225))]) 60 | else: 61 | transforms = torchvision.transforms.Compose([ 62 | torchvision.transforms.Resize((512, 512)), 63 | torchvision.transforms.RandomCrop(size=patch_size), 64 | torchvision.transforms.ToTensor(), 65 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 66 | std=(0.229, 0.224, 0.225))]) 67 | else: 68 | transforms = torchvision.transforms.Compose([ 69 | torchvision.transforms.ToTensor(), 70 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 71 | std=(0.229, 0.224, 0.225))]) 72 | 73 | if dataset == 'live': 74 | self.data = folders.LIVEFolder( 75 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 76 | elif dataset == 'csiq': 77 | self.data = folders.CSIQFolder( 78 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 79 | elif dataset == 'kadid10k': 80 | self.data = folders.Kadid10kFolder( 81 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 82 | elif dataset == 'tid2013': 83 | self.data = folders.TID2013Folder( 84 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 85 | elif dataset == 'koniq-10k': 86 | self.data = folders.Koniq_10kFolder( 87 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 88 | elif dataset == 'livec': 89 | self.data = folders.LIVEChallengeFolder( 90 | root=path, index=img_indx, transform=transforms, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 91 | 92 | def get_dataloader(self): 93 | if self.istrain: 94 | dataloader = torch.utils.data.DataLoader( 95 | self.data, batch_size=self.batch_size, shuffle=True) 96 | else: 97 | dataloader = torch.utils.data.DataLoader( 98 | self.data, batch_size=1, shuffle=False) 99 | return dataloader 100 | -------------------------------------------------------------------------------- /option_train_DistillationIQA_FR.py: -------------------------------------------------------------------------------- 1 | # import template 2 | import argparse 3 | import os 4 | 5 | """ 6 | Configuration file 7 | """ 8 | def check_args(args, rank=0): 9 | if rank == 0: 10 | with open(args.setting_file, 'w') as opt_file: 11 | opt_file.write('------------ Options -------------\n') 12 | print('------------ Options -------------') 13 | for k in args.__dict__: 14 | v = args.__dict__[k] 15 | opt_file.write('%s: %s\n' % (str(k), str(v))) 16 | print('%s: %s' % (str(k), str(v))) 17 | opt_file.write('-------------- End ----------------\n') 18 | print('------------ End -------------') 19 | 20 | return args 21 | 22 | def str2bool(v): 23 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 24 | return True 25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 26 | return False 27 | else: 28 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 29 | 30 | def set_args(): 31 | parser = argparse.ArgumentParser() 32 | 
parser.add_argument('--gpu_ids', type=str, default='0') 33 | parser.add_argument('--test_dataset', type=str, default='live', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k') 34 | parser.add_argument('--train_dataset', type=str, default='kadid10k', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k') 35 | parser.add_argument('--train_patch_num', type=int, default=1, help='Number of sample patches from training image') 36 | parser.add_argument('--test_patch_num', type=int, default=1, help='Number of sample patches from testing image') 37 | parser.add_argument('--lr', dest='lr', type=float, default=2e-5, help='Learning rate') 38 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay') 39 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size') 40 | parser.add_argument('--epochs', type=int, default=100, help='Epochs for training') 41 | parser.add_argument('--patch_size', type=int, default=224, help='Crop size for training & testing image patches') 42 | parser.add_argument('--self_patch_num', type=int, default=10, help='number of training & testing image self patches') 43 | parser.add_argument('--train_test_num', type=int, default=1, help='Train-test times') 44 | parser.add_argument('--update_opt_epoch', type=int, default=300) 45 | 46 | #Ref Img 47 | parser.add_argument('--use_refHQ', type=str2bool, default=True) 48 | parser.add_argument('--ref_folder_paths', type=str, default='./dataset/DIV2K_ref.tar.gz') 49 | parser.add_argument('--distillation_layer', type=int, default=18, help='last xth layers of HQ-MLP for distillation') 50 | 51 | parser.add_argument('--net_print', type=int, default=2000) 52 | parser.add_argument('--setting_file', type=str, default='setting.txt') 53 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_FRIQA_teacher/') 54 | 55 | parser.add_argument('--use_fitting_prcc_srcc', type=str2bool, default=True) 56 | parser.add_argument('--print_netC', type=str2bool, default=False) 57 | 58 | parser.add_argument('--teacherNet_model_path', type=str, default=None, help='./model_zoo/FR_teacher_cross_dataset.pth') 59 | 60 | args = parser.parse_args() 61 | #Dataset 62 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file) 63 | if not os.path.exists('./dataset/'): 64 | os.mkdir('./dataset/') 65 | 66 | folder_path = { 67 | 'live': './dataset/LIVE/', 68 | 'csiq': './dataset/CSIQ/', 69 | 'tid2013': './dataset/TID2013/', 70 | 'koniq-10k': './dataset/koniq-10k/', 71 | } 72 | ref_dataset_path = './dataset/DIV2K_ref/' 73 | args.ref_train_dataset_path = ref_dataset_path + 'train_HR/' 74 | args.ref_test_dataset_path = ref_dataset_path + 'val_HR/' 75 | 76 | #checkpoint files 77 | args.model_checkpoint_dir = args.checkpoint_dir + 'models/' 78 | args.result_checkpoint_dir = args.checkpoint_dir + 'results/' 79 | args.log_checkpoint_dir = args.checkpoint_dir + 'log/' 80 | 81 | if os.path.exists(args.checkpoint_dir) and os.path.isfile(args.checkpoint_dir): 82 | raise IOError('Required dst path {} as a directory for checkpoint saving, got a file'.format( 83 | args.checkpoint_dir)) 84 | elif not os.path.exists(args.checkpoint_dir): 85 | os.makedirs(args.checkpoint_dir) 86 | print('%s created successfully!'%args.checkpoint_dir) 87 | 88 | if os.path.exists(args.model_checkpoint_dir) and os.path.isfile(args.model_checkpoint_dir): 89 | raise IOError('Required dst path {} as a directory for checkpoint model saving, got a file'.format( 90 | 
args.model_checkpoint_dir)) 91 | elif not os.path.exists(args.model_checkpoint_dir): 92 | os.makedirs(args.model_checkpoint_dir) 93 | print('%s created successfully!'%args.model_checkpoint_dir) 94 | 95 | if os.path.exists(args.result_checkpoint_dir) and os.path.isfile(args.result_checkpoint_dir): 96 | raise IOError('Required dst path {} as a directory for checkpoint results saving, got a file'.format( 97 | args.result_checkpoint_dir)) 98 | elif not os.path.exists(args.result_checkpoint_dir): 99 | os.makedirs(args.result_checkpoint_dir) 100 | print('%s created successfully!'%args.result_checkpoint_dir) 101 | 102 | if os.path.exists(args.log_checkpoint_dir) and os.path.isfile(args.log_checkpoint_dir): 103 | raise IOError('Required dst path {} as a directory for checkpoint log saving, got a file'.format( 104 | args.log_checkpoint_dir)) 105 | elif not os.path.exists(args.log_checkpoint_dir): 106 | os.makedirs(args.log_checkpoint_dir) 107 | print('%s created successfully!'%args.log_checkpoint_dir) 108 | 109 | return args 110 | 111 | if __name__ == "__main__": 112 | args = set_args() 113 | 114 | -------------------------------------------------------------------------------- /option_train_DistillationIQA.py: -------------------------------------------------------------------------------- 1 | # import template 2 | import argparse 3 | import os 4 | """ 5 | Configuration file 6 | """ 7 | def check_args(args, rank=0): 8 | if rank == 0: 9 | with open(args.setting_file, 'w') as opt_file: 10 | opt_file.write('------------ Options -------------\n') 11 | print('------------ Options -------------') 12 | for k in args.__dict__: 13 | v = args.__dict__[k] 14 | opt_file.write('%s: %s\n' % (str(k), str(v))) 15 | print('%s: %s' % (str(k), str(v))) 16 | opt_file.write('-------------- End ----------------\n') 17 | print('------------ End -------------') 18 | 19 | return args 20 | 21 | def str2bool(v): 22 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 23 | return True 24 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 25 | return False 26 | else: 27 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 28 | 29 | def set_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--gpu_ids', type=str, default='0') 32 | parser.add_argument('--test_dataset', type=str, default='live', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k') 33 | parser.add_argument('--train_dataset', type=str, default='kadid10k', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k') 34 | parser.add_argument('--train_patch_num', type=int, default=1, help='Number of sample patches from training image') 35 | parser.add_argument('--test_patch_num', type=int, default=1, help='Number of sample patches from testing image') 36 | parser.add_argument('--lr', dest='lr', type=float, default=2e-5, help='Learning rate') 37 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay') 38 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size') 39 | parser.add_argument('--epochs', type=int, default=100, help='Epochs for training') 40 | parser.add_argument('--patch_size', type=int, default=224, help='Crop size for training & testing image patches') 41 | parser.add_argument('--self_patch_num', type=int, default=10, help='number of training & testing image self patches') 42 | parser.add_argument('--train_test_num', type=int, default=1, help='Train-test times') 43 | parser.add_argument('--update_opt_epoch', type=int, 
default=30) 44 | 45 | #Ref Img 46 | parser.add_argument('--use_refHQ', type=str2bool, default=True) 47 | parser.add_argument('--distillation_layer', type=int, default=18, help='last xth layers of HQ-MLP for distillation') 48 | 49 | parser.add_argument('--net_print', type=int, default=2000) 50 | parser.add_argument('--setting_file', type=str, default='setting.txt') 51 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_DistillationIQA/') 52 | 53 | parser.add_argument('--use_fitting_prcc_srcc', type=str2bool, default=True) 54 | parser.add_argument('--print_netC', type=str2bool, default=False) 55 | 56 | parser.add_argument('--teacherNet_model_path', type=str, default='./model_zoo/FR_teacher_cross_dataset.pth') 57 | parser.add_argument('--studentNet_model_path', type=str, default=None, help='./model_zoo/NAR_student_cross_dataset.pth') 58 | 59 | #distillation 60 | parser.add_argument('--distillation_loss', type=str, default='l1', help='mse|l1|kldiv') 61 | 62 | args = parser.parse_args() 63 | #Dataset 64 | args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file) 65 | if not os.path.exists('./dataset/'): 66 | os.mkdir('./dataset/') 67 | 68 | folder_path = { 69 | 'live': './dataset/LIVE/', 70 | 'csiq': './dataset/CSIQ/', 71 | 'tid2013': './dataset/TID2013/', 72 | 'koniq-10k': './dataset/koniq-10k/', 73 | } 74 | 75 | ref_dataset_path = './dataset/DIV2K_ref/' 76 | args.ref_train_dataset_path = ref_dataset_path + 'train_HR/' 77 | args.ref_test_dataset_path = ref_dataset_path + 'val_HR/' 78 | 79 | #checkpoint files 80 | args.model_checkpoint_dir = args.checkpoint_dir + 'models/' 81 | args.result_checkpoint_dir = args.checkpoint_dir + 'results/' 82 | args.log_checkpoint_dir = args.checkpoint_dir + 'log/' 83 | 84 | if os.path.exists(args.checkpoint_dir) and os.path.isfile(args.checkpoint_dir): 85 | raise IOError('Required dst path {} as a directory for checkpoint saving, got a file'.format( 86 | args.checkpoint_dir)) 87 | elif not os.path.exists(args.checkpoint_dir): 88 | os.makedirs(args.checkpoint_dir) 89 | print('%s created successfully!'%args.checkpoint_dir) 90 | 91 | if os.path.exists(args.model_checkpoint_dir) and os.path.isfile(args.model_checkpoint_dir): 92 | raise IOError('Required dst path {} as a directory for checkpoint model saving, got a file'.format( 93 | args.model_checkpoint_dir)) 94 | elif not os.path.exists(args.model_checkpoint_dir): 95 | os.makedirs(args.model_checkpoint_dir) 96 | print('%s created successfully!'%args.model_checkpoint_dir) 97 | 98 | if os.path.exists(args.result_checkpoint_dir) and os.path.isfile(args.result_checkpoint_dir): 99 | raise IOError('Required dst path {} as a directory for checkpoint results saving, got a file'.format( 100 | args.result_checkpoint_dir)) 101 | elif not os.path.exists(args.result_checkpoint_dir): 102 | os.makedirs(args.result_checkpoint_dir) 103 | print('%s created successfully!'%args.result_checkpoint_dir) 104 | 105 | if os.path.exists(args.log_checkpoint_dir) and os.path.isfile(args.log_checkpoint_dir): 106 | raise IOError('Required dst path {} as a directory for checkpoint log saving, got a file'.format( 107 | args.log_checkpoint_dir)) 108 | elif not os.path.exists(args.log_checkpoint_dir): 109 | os.makedirs(args.log_checkpoint_dir) 110 | print('%s created successfully!'%args.log_checkpoint_dir) 111 | 112 | return args 113 | 114 | if __name__ == "__main__": 115 | args = set_args() 116 | 117 | -------------------------------------------------------------------------------- /models/WaDIQaM.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class WaDIQaM_FR(nn.Module): 6 | """ 7 | (Wa)DIQaM-FR Model 8 | """ 9 | def __init__(self, weighted_average=True): 10 | """ 11 | :param weighted_average: weighted average or not? 12 | """ 13 | super(WaDIQaM_FR, self).__init__() 14 | self.conv1 = nn.Conv2d(3, 32, 3, padding=1) 15 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1) 16 | self.conv3 = nn.Conv2d(32, 64, 3, padding=1) 17 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1) 18 | self.conv5 = nn.Conv2d(64, 128, 3, padding=1) 19 | self.conv6 = nn.Conv2d(128, 128, 3, padding=1) 20 | self.conv7 = nn.Conv2d(128, 256, 3, padding=1) 21 | self.conv8 = nn.Conv2d(256, 256, 3, padding=1) 22 | self.conv9 = nn.Conv2d(256, 512, 3, padding=1) 23 | self.conv10 = nn.Conv2d(512, 512, 3, padding=1) 24 | self.fc1_q = nn.Linear(512*3, 512) 25 | self.fc2_q = nn.Linear(512, 1) 26 | self.fc1_w = nn.Linear(512*3, 512) 27 | self.fc2_w = nn.Linear(512, 1) 28 | self.dropout = nn.Dropout() 29 | self.weighted_average = weighted_average 30 | 31 | def extract_features(self, x): 32 | """ 33 | feature extraction 34 | :param x: the input image 35 | :return: the output feature 36 | """ 37 | h = F.relu(self.conv1(x)) 38 | h = F.relu(self.conv2(h)) 39 | h = F.max_pool2d(h, 2) 40 | 41 | h = F.relu(self.conv3(h)) 42 | h = F.relu(self.conv4(h)) 43 | h = F.max_pool2d(h, 2) 44 | 45 | h = F.relu(self.conv5(h)) 46 | h = F.relu(self.conv6(h)) 47 | h = F.max_pool2d(h, 2) 48 | 49 | h = F.relu(self.conv7(h)) 50 | h = F.relu(self.conv8(h)) 51 | h = F.max_pool2d(h, 2) 52 | 53 | h = F.relu(self.conv9(h)) 54 | h = F.relu(self.conv10(h)) 55 | h = F.max_pool2d(h, 2) 56 | 57 | h = h.view(-1, 512) 58 | 59 | return h 60 | 61 | def forward(self, x, x_ref): 62 | """ 63 | :param x: distorted patches of images 64 | :param x_ref: reference patches of images 65 | :return: quality of images/patches 66 | """ 67 | batch_size = x.size(0) 68 | n_patches = x.size(1) 69 | if self.weighted_average: 70 | q = torch.ones((batch_size, 1), device=x.device) 71 | else: 72 | q = torch.ones((batch_size * n_patches, 1), device=x.device) 73 | 74 | for i in range(batch_size): 75 | 76 | h = self.extract_features(x[i]) 77 | h_ref = self.extract_features(x_ref[i]) 78 | h = torch.cat((h - h_ref, h, h_ref), 1) 79 | 80 | h_ = h # save intermediate features 81 | 82 | h = F.relu(self.fc1_q(h_)) 83 | h = self.dropout(h) 84 | h = self.fc2_q(h) 85 | 86 | if self.weighted_average: 87 | w = F.relu(self.fc1_w(h_)) 88 | w = self.dropout(w) 89 | w = F.relu(self.fc2_w(w)) + 0.000001 # small constant 90 | q[i] = torch.sum(h * w) / torch.sum(w) 91 | else: 92 | q[i*n_patches:(i+1)*n_patches] = h 93 | 94 | return q 95 | 96 | class WaDIQaM_NR(nn.Module): 97 | """ 98 | (Wa)DIQaM-NR-NR Model 99 | """ 100 | def __init__(self, weighted_average=True): 101 | """ 102 | :param weighted_average: weighted average or not? 
103 | """ 104 | super(WaDIQaM_NR, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 32, 3, padding=1) 106 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1) 107 | self.conv3 = nn.Conv2d(32, 64, 3, padding=1) 108 | self.conv4 = nn.Conv2d(64, 64, 3, padding=1) 109 | self.conv5 = nn.Conv2d(64, 128, 3, padding=1) 110 | self.conv6 = nn.Conv2d(128, 128, 3, padding=1) 111 | self.conv7 = nn.Conv2d(128, 256, 3, padding=1) 112 | self.conv8 = nn.Conv2d(256, 256, 3, padding=1) 113 | self.conv9 = nn.Conv2d(256, 512, 3, padding=1) 114 | self.conv10 = nn.Conv2d(512, 512, 3, padding=1) 115 | self.fc1q_nr = nn.Linear(512, 512) 116 | self.fc2q_nr = nn.Linear(512, 1) 117 | self.fc1w_nr = nn.Linear(512, 512) 118 | self.fc2w_nr = nn.Linear(512, 1) 119 | self.dropout = nn.Dropout() 120 | self.weighted_average = weighted_average 121 | 122 | def extract_features(self, x): 123 | """ 124 | feature extraction 125 | :param x: the input image 126 | :return: the output feature 127 | """ 128 | h = F.relu(self.conv1(x)) 129 | h = F.relu(self.conv2(h)) 130 | h = F.max_pool2d(h, 2) 131 | 132 | h = F.relu(self.conv3(h)) 133 | h = F.relu(self.conv4(h)) 134 | h = F.max_pool2d(h, 2) 135 | 136 | h = F.relu(self.conv5(h)) 137 | h = F.relu(self.conv6(h)) 138 | h = F.max_pool2d(h, 2) 139 | 140 | h = F.relu(self.conv7(h)) 141 | h = F.relu(self.conv8(h)) 142 | h = F.max_pool2d(h, 2) 143 | 144 | h = F.relu(self.conv9(h)) 145 | h = F.relu(self.conv10(h)) 146 | h = F.max_pool2d(h, 2) 147 | 148 | h = h.view(-1,512) 149 | 150 | return h 151 | 152 | def forward(self, x): 153 | """ 154 | :param x: distorted patches of images 155 | :return: quality of images/patches 156 | """ 157 | batch_size = x.size(0) 158 | n_patches = x.size(1) 159 | if self.weighted_average: 160 | q = torch.ones((batch_size, 1), device=x.device) 161 | else: 162 | q = torch.ones((batch_size * n_patches, 1), device=x.device) 163 | 164 | for i in range(batch_size): 165 | 166 | h = self.extract_features(x[i]) 167 | 168 | h_ = h # save intermediate features 169 | 170 | h = F.relu(self.fc1q_nr(h_)) 171 | h = self.dropout(h) 172 | h = self.fc2q_nr(h) 173 | 174 | if self.weighted_average: 175 | w = F.relu(self.fc1w_nr(h_)) 176 | w = self.dropout(w) 177 | w = F.relu(self.fc2w_nr(w)) + 0.000001 # small constant 178 | q[i] = torch.sum(h * w) / torch.sum(w) 179 | else: 180 | q[i * n_patches:(i + 1) * n_patches] = h 181 | 182 | return q 183 | -------------------------------------------------------------------------------- /dataloaders/dataloader_LQ_HQ_diff_content_HQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import folders.folders_LQ_HQ_diff_content_HQ as folders 4 | 5 | class DataLoader(object): 6 | """Dataset class for IQA databases""" 7 | 8 | def __init__(self, dataset, path, ref_path, img_indx, patch_size, patch_num, batch_size=1, istrain=True, self_patch_num=10, use_HQref = True): 9 | 10 | self.batch_size = batch_size 11 | self.istrain = istrain 12 | 13 | if (dataset == 'live') | (dataset == 'csiq') | (dataset == 'tid2013') | (dataset == 'livec') | (dataset == 'kadid10k'): 14 | # Train transforms 15 | if istrain: 16 | HQ_diff_content_transform = torchvision.transforms.Compose([ 17 | torchvision.transforms.RandomCrop(size=patch_size), 18 | torchvision.transforms.RandomHorizontalFlip(), 19 | torchvision.transforms.RandomVerticalFlip(), 20 | torchvision.transforms.RandomRotation(degrees=180), 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 23 
| std=(0.229, 0.224, 0.225)) 24 | ]) 25 | # Test transforms 26 | else: 27 | HQ_diff_content_transform = torchvision.transforms.Compose([ 28 | torchvision.transforms.RandomCrop(size=patch_size), 29 | torchvision.transforms.ToTensor(), 30 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 31 | std=(0.229, 0.224, 0.225)) 32 | ]) 33 | elif dataset == 'koniq-10k': 34 | if istrain: 35 | HQ_diff_content_transform = torchvision.transforms.Compose([ 36 | torchvision.transforms.RandomCrop(size=patch_size), 37 | torchvision.transforms.RandomHorizontalFlip(), 38 | torchvision.transforms.RandomVerticalFlip(), 39 | torchvision.transforms.RandomRotation(degrees=180), 40 | torchvision.transforms.ToTensor(), 41 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 42 | std=(0.229, 0.224, 0.225))]) 43 | else: 44 | HQ_diff_content_transform = torchvision.transforms.Compose([ 45 | torchvision.transforms.RandomCrop(size=patch_size), 46 | torchvision.transforms.ToTensor(), 47 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 48 | std=(0.229, 0.224, 0.225))]) 49 | elif dataset == 'bid': 50 | if istrain: 51 | HQ_diff_content_transform = torchvision.transforms.Compose([ 52 | torchvision.transforms.Resize((512, 512)), 53 | torchvision.transforms.RandomCrop(size=patch_size), 54 | torchvision.transforms.RandomHorizontalFlip(), 55 | torchvision.transforms.RandomVerticalFlip(), 56 | torchvision.transforms.RandomRotation(degrees=180), 57 | torchvision.transforms.ToTensor(), 58 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 59 | std=(0.229, 0.224, 0.225))]) 60 | else: 61 | HQ_diff_content_transform = torchvision.transforms.Compose([ 62 | torchvision.transforms.Resize((512, 512)), 63 | torchvision.transforms.RandomCrop(size=patch_size), 64 | torchvision.transforms.ToTensor(), 65 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 66 | std=(0.229, 0.224, 0.225))]) 67 | else: 68 | HQ_diff_content_transform = torchvision.transforms.Compose([ 69 | torchvision.transforms.ToTensor(), 70 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 71 | std=(0.229, 0.224, 0.225))]) 72 | transforms = torchvision.transforms.Compose([ 73 | torchvision.transforms.ToTensor(), 74 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 75 | std=(0.229, 0.224, 0.225))]) 76 | 77 | if dataset == 'live': 78 | self.data = folders.LIVEFolder( 79 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 80 | elif dataset == 'csiq': 81 | self.data = folders.CSIQFolder( 82 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 83 | elif dataset == 'kadid10k': 84 | self.data = folders.Kadid10kFolder( 85 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 86 | elif dataset == 'tid2013': 87 | self.data = folders.TID2013Folder( 88 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 89 | elif dataset == 'koniq-10k': 90 | self.data = folders.Koniq_10kFolder( 91 | 
root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 92 | elif dataset == 'livec': 93 | self.data = folders.LIVEChallengeFolder( 94 | root=path, HQ_diff_content_root=ref_path, index=img_indx, transform=transforms, HQ_diff_content_transform=HQ_diff_content_transform, patch_num=patch_num, patch_size = patch_size, self_patch_num=self_patch_num) 95 | 96 | def get_dataloader(self): 97 | if self.istrain: 98 | dataloader = torch.utils.data.DataLoader( 99 | self.data, batch_size=self.batch_size, shuffle=True) 100 | else: 101 | dataloader = torch.utils.data.DataLoader( 102 | self.data, batch_size=self.batch_size, shuffle=False) 103 | return dataloader 104 | -------------------------------------------------------------------------------- /test_DistillationIQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from dataloaders.dataloader_LQ_HQ_diff_content_HQ import DataLoader 4 | from option_train_DistillationIQA import set_args, check_args 5 | from scipy import stats 6 | import numpy as np 7 | from tools.nonlinear_convert import convert_obj_score 8 | from models.DistillationIQA import DistillationIQANet 9 | 10 | img_num = { 11 | 'kadid10k': list(range(0,10125)), 12 | 'live': list(range(0, 29)),#ref HR image 13 | 'csiq': list(range(0, 30)),#ref HR image 14 | 'tid2013': list(range(0, 25)), 15 | 'livec': list(range(0, 1162)),# no-ref image 16 | 'koniq-10k': list(range(0, 10073)),# no-ref image 17 | 'bid': list(range(0, 586)),# no-ref image 18 | } 19 | folder_path = { 20 | 'pipal':'./dataset/PIPAL', 21 | 'live': './dataset/LIVE/', 22 | 'csiq': './dataset/CSIQ/', 23 | 'tid2013': './dataset/TID2013/', 24 | 'livec': './dataset/LIVEC/', 25 | 'koniq-10k': './dataset/koniq-10k/', 26 | 'bid': './dataset/BID/', 27 | 'kadid10k':'./dataset/kadid10k/' 28 | } 29 | 30 | 31 | class DistillationIQASolver(object): 32 | def __init__(self, config): 33 | self.config = config 34 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu') 35 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt') 36 | self.config.teacherNet_model_path = './model_zoo/FR_teacher_cross_dataset.pth' 37 | self.config.studentNet_model_path = './model_zoo/NAR_student_cross_dataset.pth' 38 | with open(self.txt_log_path,"w+") as f: 39 | f.close() 40 | 41 | #model 42 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 43 | if config.teacherNet_model_path: 44 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path)) 45 | self.teacherNet = self.teacherNet.to(self.device) 46 | self.teacherNet.train(False) 47 | self.studentNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 48 | if config.studentNet_model_path: 49 | self.studentNet._load_state_dict(torch.load(config.studentNet_model_path)) 50 | self.studentNet = self.studentNet.to(self.device) 51 | self.studentNet.train(True) 52 | 53 | #data 54 | test_loader_LIVE = DataLoader('live', folder_path['live'], config.ref_test_dataset_path, img_num['live'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 55 | test_loader_CSIQ = DataLoader('csiq', folder_path['csiq'], config.ref_test_dataset_path, img_num['csiq'], config.patch_size, config.test_patch_num, 
istrain=False, self_patch_num=config.self_patch_num) 56 | test_loader_TID = DataLoader('tid2013', folder_path['tid2013'], config.ref_test_dataset_path, img_num['tid2013'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 57 | test_loader_Koniq = DataLoader('koniq-10k', folder_path['koniq-10k'], config.ref_test_dataset_path, img_num['koniq-10k'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 58 | 59 | self.test_data_LIVE = test_loader_LIVE.get_dataloader() 60 | self.test_data_CSIQ = test_loader_CSIQ.get_dataloader() 61 | self.test_data_TID = test_loader_TID.get_dataloader() 62 | self.test_data_Koniq = test_loader_Koniq.get_dataloader() 63 | 64 | 65 | def test(self, test_data): 66 | self.studentNet.train(False) 67 | test_pred_scores, test_gt_scores = [], [] 68 | for LQ_patches, _, ref_patches, label in test_data: 69 | LQ_patches, ref_patches, label = LQ_patches.to(self.device), ref_patches.to(self.device), label.to(self.device) 70 | with torch.no_grad(): 71 | _, _, pred = self.studentNet(LQ_patches, ref_patches) 72 | test_pred_scores.append(float(pred.item())) 73 | test_gt_scores = test_gt_scores + label.cpu().tolist() 74 | if self.config.use_fitting_prcc_srcc: 75 | fitting_pred_scores = convert_obj_score(test_pred_scores, test_gt_scores) 76 | test_pred_scores = np.mean(np.reshape(np.array(test_pred_scores), (-1, self.config.test_patch_num)), axis=1) 77 | test_gt_scores = np.mean(np.reshape(np.array(test_gt_scores), (-1, self.config.test_patch_num)), axis=1) 78 | test_srcc, _ = stats.spearmanr(test_pred_scores, test_gt_scores) 79 | if self.config.use_fitting_prcc_srcc: 80 | test_plcc, _ = stats.pearsonr(fitting_pred_scores, test_gt_scores) 81 | else: 82 | test_plcc, _ = stats.pearsonr(test_pred_scores, test_gt_scores) 83 | test_krcc, _ = stats.stats.kendalltau(test_pred_scores, test_gt_scores) 84 | test_srcc, test_plcc, test_krcc = abs(test_srcc), abs(test_plcc), abs(test_krcc) 85 | self.studentNet.train(True) 86 | return test_srcc, test_plcc, test_krcc 87 | 88 | if __name__ == "__main__": 89 | config = set_args() 90 | config = check_args(config) 91 | solver = DistillationIQASolver(config=config) 92 | fold_10_test_LIVE_srcc, fold_10_test_LIVE_plcc, fold_10_test_LIVE_krcc = [], [], [] 93 | fold_10_test_CSIQ_srcc, fold_10_test_CSIQ_plcc, fold_10_test_CSIQ_krcc = [], [], [] 94 | fold_10_test_TID_srcc, fold_10_test_TID_plcc, fold_10_test_TID_krcc = [], [], [] 95 | fold_10_test_Koniq_srcc, fold_10_test_Koniq_plcc, fold_10_test_Koniq_krcc = [], [], [] 96 | 97 | for i in range(10): 98 | 99 | test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc = solver.test(solver.test_data_LIVE) 100 | print('round{} Dataset:LIVE Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc)) 101 | fold_10_test_LIVE_srcc.append(test_LIVE_srcc) 102 | fold_10_test_LIVE_plcc.append(test_LIVE_plcc) 103 | fold_10_test_LIVE_krcc.append(test_LIVE_krcc) 104 | 105 | test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc = solver.test(solver.test_data_CSIQ) 106 | print('round{} Dataset:CSIQ Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc)) 107 | fold_10_test_CSIQ_srcc.append(test_CSIQ_srcc) 108 | fold_10_test_CSIQ_plcc.append(test_CSIQ_plcc) 109 | fold_10_test_CSIQ_krcc.append(test_CSIQ_krcc) 110 | 111 | test_TID_srcc, test_TID_plcc, test_TID_krcc = solver.test(solver.test_data_TID) 112 | print('round{} Dataset:TID Test_SRCC:{} Test_PLCC:{} 
TEST_KRCC:{}\n'.format(i, test_TID_srcc, test_TID_plcc, test_TID_krcc)) 113 | fold_10_test_TID_srcc.append(test_TID_srcc) 114 | fold_10_test_TID_plcc.append(test_TID_plcc) 115 | fold_10_test_TID_krcc.append(test_TID_krcc) 116 | 117 | test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc = solver.test(solver.test_data_Koniq) 118 | print('round{} Dataset:Koniq Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(i, test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc)) 119 | fold_10_test_Koniq_srcc.append(test_Koniq_srcc) 120 | fold_10_test_Koniq_plcc.append(test_Koniq_plcc) 121 | fold_10_test_Koniq_krcc.append(test_Koniq_krcc) 122 | 123 | print('Dataset:LIVE Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_LIVE_srcc), np.mean(fold_10_test_LIVE_plcc), np.mean(fold_10_test_LIVE_krcc))) 124 | print('Dataset:CSIQ Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_CSIQ_srcc), np.mean(fold_10_test_CSIQ_plcc), np.mean(fold_10_test_CSIQ_krcc))) 125 | print('Dataset:TID Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_TID_srcc), np.mean(fold_10_test_TID_plcc), np.mean(fold_10_test_TID_krcc))) 126 | print('Dataset:Koniq Test_SRCC:{} Test_PLCC:{} TEST_KRCC:{}\n'.format(np.mean(fold_10_test_Koniq_srcc), np.mean(fold_10_test_Koniq_plcc), np.mean(fold_10_test_Koniq_krcc))) 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /train_DistillationIQA_FR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | from dataloaders.dataloader_LQ_HQ import DataLoader 5 | from option_train_DistillationIQA_FR import set_args, check_args 6 | from scipy import stats 7 | import numpy as np 8 | from tools.nonlinear_convert import convert_obj_score 9 | from models.DistillationIQA import DistillationIQANet 10 | 11 | img_num = { 12 | 'kadid10k': list(range(0,10125)), 13 | 'live': list(range(0, 29)),#ref HR image 14 | 'csiq': list(range(0, 30)),#ref HR image 15 | 'tid2013': list(range(0, 25)), 16 | 'livec': list(range(0, 1162)),# no-ref image 17 | 'koniq-10k': list(range(0, 10073)),# no-ref image 18 | 'bid': list(range(0, 586)),# no-ref image 19 | } 20 | folder_path = { 21 | 'pipal':'./dataset/PIPAL', 22 | 'live': './dataset/LIVE/', 23 | 'csiq': './dataset/CSIQ/', 24 | 'tid2013': './dataset/TID2013/', 25 | 'livec': './dataset/LIVEC/', 26 | 'koniq-10k': './dataset/koniq-10k/', 27 | 'bid': './dataset/BID/', 28 | 'kadid10k':'./dataset/kadid10k/' 29 | } 30 | 31 | 32 | class DistillationFRIQASolver(object): 33 | def __init__(self, config): 34 | self.config = config 35 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu') 36 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt') 37 | with open(self.txt_log_path,"w+") as f: 38 | f.close() 39 | 40 | #model 41 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 42 | if config.teacherNet_model_path: 43 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_pretrained_path)) 44 | self.teacherNet = self.teacherNet.to(self.device) 45 | self.teacherNet.train(True) 46 | #lr,opt,loss,epoch 47 | self.lr = config.lr 48 | self.lr_ratio = 10 49 | self.feature_loss_ratio = 0.1 50 | resnet_params = list(map(id, self.teacherNet.feature_extractor.parameters())) 51 | res_params = filter(lambda p: id(p) not in resnet_params, self.teacherNet.parameters()) 52 | paras = 
[{'params': res_params, 'lr': self.lr * self.lr_ratio }, 53 | {'params': self.teacherNet.feature_extractor.parameters(), 'lr': self.lr} 54 | ] 55 | self.optimizer = torch.optim.Adam(paras, weight_decay=config.weight_decay) 56 | self.mse_loss = torch.nn.MSELoss() 57 | self.l1_loss = torch.nn.L1Loss() 58 | self.epochs = config.epochs 59 | 60 | #data 61 | config.train_index = img_num[config.train_dataset] 62 | random.shuffle(config.train_index) 63 | train_loader = DataLoader(config.train_dataset, folder_path[config.train_dataset], config.train_index, config.patch_size, config.train_patch_num, batch_size=config.batch_size, istrain=True, self_patch_num=config.self_patch_num) 64 | test_loader_LIVE = DataLoader('live', folder_path['live'], img_num['live'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 65 | test_loader_CSIQ = DataLoader('csiq', folder_path['csiq'], img_num['csiq'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 66 | test_loader_TID = DataLoader('tid2013', folder_path['tid2013'], img_num['tid2013'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 67 | 68 | self.train_data = train_loader.get_dataloader() 69 | self.test_data_LIVE = test_loader_LIVE.get_dataloader() 70 | self.test_data_CSIQ = test_loader_CSIQ.get_dataloader() 71 | self.test_data_TID = test_loader_TID.get_dataloader() 72 | 73 | def train(self): 74 | best_srcc_LIVE, best_srcc_CSIQ, best_srcc_TID = 0.0, 0.0, 0.0 75 | best_plcc_LIVE, best_plcc_CSIQ, best_plcc_TID = 0.0, 0.0, 0.0 76 | best_krcc_LIVE, best_krcc_CSIQ, best_krcc_TID = 0.0, 0.0, 0.0 77 | 78 | print('Epoch\tTrain_Loss\tTrain_SRCC\tTest_SRCC\tTest_PLCC\tTest_KRCC') 79 | # NEW 80 | scaler = torch.cuda.amp.GradScaler() 81 | 82 | for t in range(self.epochs): 83 | epoch_loss = [] 84 | pred_scores = [] 85 | gt_scores = [] 86 | 87 | for LQ_patches, refHQ_patches, label in self.train_data: 88 | LQ_patches, refHQ_patches, label = LQ_patches.to(self.device), refHQ_patches.to(self.device), label.to(self.device) 89 | self.optimizer.zero_grad() 90 | 91 | with torch.cuda.amp.autocast(): 92 | _, _, pred = self.teacherNet(LQ_patches, refHQ_patches) 93 | 94 | pred_scores = pred_scores + pred.cpu().tolist() 95 | gt_scores = gt_scores + label.cpu().tolist() 96 | loss = self.l1_loss(pred.squeeze(), label.float().detach()) 97 | 98 | epoch_loss.append(loss.item()) 99 | scaler.scale(loss).backward() 100 | scaler.step(self.optimizer) 101 | scaler.update() 102 | 103 | train_srcc, _ = stats.spearmanr(pred_scores, gt_scores) 104 | test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc = self.test(self.test_data_LIVE) 105 | test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc = self.test(self.test_data_CSIQ) 106 | test_TID_srcc, test_TID_plcc, test_TID_krcc = self.test(self.test_data_TID) 107 | 108 | if test_LIVE_srcc + test_LIVE_plcc + test_LIVE_krcc > best_srcc_LIVE + best_plcc_LIVE + best_krcc_LIVE: 109 | best_srcc_LIVE, best_srcc_CSIQ, best_srcc_TID = test_LIVE_srcc, test_CSIQ_srcc, test_TID_srcc 110 | print('%d:live\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 111 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc)) 112 | 113 | if test_CSIQ_srcc + test_CSIQ_plcc + test_CSIQ_krcc > best_srcc_CSIQ + best_plcc_CSIQ + best_krcc_CSIQ: 114 | best_plcc_LIVE, best_plcc_CSIQ, best_plcc_TID = test_LIVE_plcc, test_CSIQ_plcc, test_TID_plcc 115 | print('%d:csiq\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 116 | 
(t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc)) 117 | 118 | if test_TID_srcc + test_TID_plcc + test_TID_krcc > best_srcc_TID + best_plcc_TID + best_krcc_TID: 119 | best_krcc_LIVE, best_krcc_CSIQ, best_krcc_TID = test_LIVE_krcc, test_CSIQ_krcc, test_TID_krcc 120 | print('%d:tid\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 121 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_TID_srcc, test_TID_plcc, test_TID_krcc)) 122 | 123 | torch.save(self.teacherNet.state_dict(), os.path.join(self.config.model_checkpoint_dir, 'FRIQA_{}_saved_model.pth'.format(t))) 124 | 125 | self.lr = self.lr / pow(10, (t // self.config.update_opt_epoch)) 126 | if t > 20: 127 | self.lr_ratio = 1 128 | resnet_params = list(map(id, self.teacherNet.feature_extractor.parameters())) 129 | rest_params = filter(lambda p: id(p) not in resnet_params, self.teacherNet.parameters()) 130 | paras = [{'params': rest_params, 'lr': self.lr * self.lr_ratio }, 131 | {'params': self.teacherNet.feature_extractor.parameters(), 'lr': self.lr} 132 | ] 133 | self.optimizer = torch.optim.Adam(paras, weight_decay=self.config.weight_decay) 134 | 135 | print('Best live test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_LIVE, best_plcc_LIVE, best_krcc_LIVE)) 136 | print('Best csiq test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_CSIQ, best_plcc_CSIQ, best_krcc_CSIQ)) 137 | print('Best tid2013 test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_TID, best_plcc_TID, best_krcc_TID)) 138 | 139 | 140 | def test(self, test_data): 141 | self.teacherNet.train(False) 142 | test_pred_scores, test_gt_scores = [], [] 143 | for LQ_patches, refHQ_patches, label in test_data: 144 | LQ_patches, refHQ_patches, label = LQ_patches.to(self.device), refHQ_patches.to(self.device), label.to(self.device) 145 | with torch.no_grad(): 146 | _, _, pred = self.teacherNet(LQ_patches, refHQ_patches) 147 | test_pred_scores.append(float(pred.item())) 148 | test_gt_scores = test_gt_scores + label.cpu().tolist() 149 | if self.config.use_fitting_prcc_srcc: 150 | fitting_pred_scores = convert_obj_score(test_pred_scores, test_gt_scores) 151 | test_pred_scores = np.mean(np.reshape(np.array(test_pred_scores), (-1, self.config.test_patch_num)), axis=1) 152 | test_gt_scores = np.mean(np.reshape(np.array(test_gt_scores), (-1, self.config.test_patch_num)), axis=1) 153 | test_srcc, _ = stats.spearmanr(test_pred_scores, test_gt_scores) 154 | if self.config.use_fitting_prcc_srcc: 155 | test_plcc, _ = stats.pearsonr(fitting_pred_scores, test_gt_scores) 156 | else: 157 | test_plcc, _ = stats.pearsonr(test_pred_scores, test_gt_scores) 158 | test_krcc, _ = stats.stats.kendalltau(test_pred_scores, test_gt_scores) 159 | test_srcc, test_plcc, test_krcc = abs(test_srcc), abs(test_plcc), abs(test_krcc) 160 | self.teacherNet.train(True) 161 | return test_srcc, test_plcc, test_krcc 162 | 163 | if __name__ == "__main__": 164 | config = set_args() 165 | config = check_args(config) 166 | solver = DistillationFRIQASolver(config=config) 167 | solver.train() 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /train_DistillationIQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | from dataloaders.dataloader_LQ_HQ_diff_content_HQ import DataLoader 5 | from option_train_DistillationIQA import set_args, check_args 6 | from scipy import stats 7 | import numpy as np 8 | from tools.nonlinear_convert import 
convert_obj_score 9 | from models.DistillationIQA import DistillationIQANet 10 | 11 | img_num = { 12 | 'kadid10k': list(range(0,10125)), 13 | 'live': list(range(0, 29)),#ref HR image 14 | 'csiq': list(range(0, 30)),#ref HR image 15 | 'tid2013': list(range(0, 25)), 16 | 'livec': list(range(0, 1162)),# no-ref image 17 | 'koniq-10k': list(range(0, 10073)),# no-ref image 18 | 'bid': list(range(0, 586)),# no-ref image 19 | } 20 | folder_path = { 21 | 'pipal':'./dataset/PIPAL', 22 | 'live': './dataset/LIVE/', 23 | 'csiq': './dataset/CSIQ/', 24 | 'tid2013': './dataset/TID2013/', 25 | 'livec': './dataset/LIVEC/', 26 | 'koniq-10k': './dataset/koniq-10k/', 27 | 'bid': './dataset/BID/', 28 | 'kadid10k':'./dataset/kadid10k/' 29 | } 30 | 31 | 32 | class DistillationIQASolver(object): 33 | def __init__(self, config): 34 | self.config = config 35 | self.device = torch.device('cuda' if config.gpu_ids is not None else 'cpu') 36 | self.txt_log_path = os.path.join(config.log_checkpoint_dir,'log.txt') 37 | with open(self.txt_log_path,"w+") as f: 38 | f.close() 39 | 40 | #model 41 | self.teacherNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 42 | if config.teacherNet_model_path: 43 | self.teacherNet._load_state_dict(torch.load(config.teacherNet_model_path)) 44 | self.teacherNet = self.teacherNet.to(self.device) 45 | self.teacherNet.train(False) 46 | self.studentNet = DistillationIQANet(self_patch_num=config.self_patch_num, distillation_layer=config.distillation_layer) 47 | if config.studentNet_model_path: 48 | self.studentNet._load_state_dict(torch.load(config.studentNet_model_path)) 49 | self.studentNet = self.studentNet.to(self.device) 50 | self.studentNet.train(True) 51 | 52 | #lr,opt,loss,epoch 53 | self.lr = config.lr 54 | self.lr_ratio = 1 55 | self.feature_loss_ratio = 1 56 | resnet_params = list(map(id, self.studentNet.feature_extractor.parameters())) 57 | res_params = filter(lambda p: id(p) not in resnet_params, self.studentNet.parameters()) 58 | paras = [{'params': res_params, 'lr': self.lr * self.lr_ratio }, 59 | {'params': self.studentNet.feature_extractor.parameters(), 'lr': self.lr} 60 | ] 61 | self.optimizer = torch.optim.Adam(paras, weight_decay=config.weight_decay) 62 | self.mse_loss = torch.nn.MSELoss() 63 | self.l1_loss = torch.nn.L1Loss() 64 | self.epochs = config.epochs 65 | 66 | #data 67 | config.train_index = img_num[config.train_dataset] 68 | random.shuffle(config.train_index) 69 | train_loader = DataLoader(config.train_dataset, folder_path[config.train_dataset], config.ref_train_dataset_path, config.train_index, config.patch_size, config.train_patch_num, batch_size=config.batch_size, istrain=True, self_patch_num=config.self_patch_num) 70 | test_loader_LIVE = DataLoader('live', folder_path['live'], config.ref_test_dataset_path, img_num['live'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 71 | test_loader_CSIQ = DataLoader('csiq', folder_path['csiq'], config.ref_test_dataset_path, img_num['csiq'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 72 | test_loader_TID = DataLoader('tid2013', folder_path['tid2013'], config.ref_test_dataset_path, img_num['tid2013'], config.patch_size, config.test_patch_num, istrain=False, self_patch_num=config.self_patch_num) 73 | test_loader_Koniq = DataLoader('koniq-10k', folder_path['koniq-10k'], config.ref_test_dataset_path, img_num['koniq-10k'], config.patch_size, config.test_patch_num, 
istrain=False, self_patch_num=config.self_patch_num) 74 | 75 | self.train_data = train_loader.get_dataloader() 76 | self.test_data_LIVE = test_loader_LIVE.get_dataloader() 77 | self.test_data_CSIQ = test_loader_CSIQ.get_dataloader() 78 | self.test_data_TID = test_loader_TID.get_dataloader() 79 | self.test_data_Koniq = test_loader_Koniq.get_dataloader() 80 | 81 | 82 | def train(self): 83 | best_srcc_LIVE, best_srcc_CSIQ, best_srcc_TID, best_srcc_Koniq = 0.0, 0.0, 0.0, 0.0 84 | best_plcc_LIVE, best_plcc_CSIQ, best_plcc_TID, best_plcc_Koniq = 0.0, 0.0, 0.0, 0.0 85 | best_krcc_LIVE, best_krcc_CSIQ, best_krcc_TID, best_krcc_Koniq = 0.0, 0.0, 0.0, 0.0 86 | 87 | print('Epoch\tTrain_Loss\tTrain_SRCC\tTest_SRCC\tTest_PLCC\tTest_KRCC') 88 | 89 | # NEW 90 | scaler = torch.cuda.amp.GradScaler() 91 | 92 | for t in range(self.epochs): 93 | epoch_loss = [] 94 | pred_scores = [] 95 | gt_scores = [] 96 | 97 | for LQ_patches, refHQ_patches, ref_patches, label in self.train_data: 98 | LQ_patches, refHQ_patches, ref_patches, label = LQ_patches.to(self.device), refHQ_patches.to(self.device), ref_patches.to(self.device), label.to(self.device) 99 | self.optimizer.zero_grad() 100 | 101 | with torch.cuda.amp.autocast(): 102 | t_encode_diff_inner_feature, t_decode_inner_feature, _ = self.teacherNet(LQ_patches, refHQ_patches) 103 | s_encode_diff_inner_feature, s_decode_inner_feature, pred = self.studentNet(LQ_patches, ref_patches) 104 | 105 | pred_scores = pred_scores + pred.cpu().tolist() 106 | gt_scores = gt_scores + label.cpu().tolist() 107 | pred_loss = self.l1_loss(pred.squeeze(), label.float().detach()) 108 | 109 | feature_loss, encode_diff_loss, decode_loss = 0.0, 0.0, 0.0 110 | for t_encode_diff_feature, s_encode_diff_feature, t_decode_feature, s_decode_feature in zip(t_encode_diff_inner_feature, s_encode_diff_inner_feature, t_decode_inner_feature, s_decode_inner_feature): 111 | #mse_loss 112 | feature_loss += self.mse_loss(t_encode_diff_feature, s_encode_diff_feature) 113 | # encode_diff_loss += self.mse_loss(t_encode_diff_feature, s_encode_diff_feature) 114 | # decode_loss += self.mse_loss(t_decode_feature, s_decode_feature) 115 | # feature_loss = encode_diff_loss + decode_loss 116 | loss = pred_loss + feature_loss*self.feature_loss_ratio 117 | epoch_loss.append(loss.item()) 118 | scaler.scale(loss).backward() 119 | scaler.step(self.optimizer) 120 | scaler.update() 121 | 122 | train_srcc, _ = stats.spearmanr(pred_scores, gt_scores) 123 | test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc = self.test(self.test_data_LIVE) 124 | test_CSIQ_srcc, test_CSIQ_plcc, test_CSIQ_krcc = self.test(self.test_data_CSIQ) 125 | test_TID_srcc, test_TID_plcc, test_TID_krcc = self.test(self.test_data_TID) 126 | test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc = self.test(self.test_data_Koniq) 127 | 128 | if test_LIVE_srcc + test_LIVE_plcc + test_LIVE_krcc > best_srcc_LIVE + best_plcc_LIVE + best_krcc_LIVE: 129 | best_srcc_LIVE, best_srcc_CSIQ, best_srcc_TID = test_LIVE_srcc, test_CSIQ_srcc, test_TID_srcc 130 | print('%d:live\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 131 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_LIVE_srcc, test_LIVE_plcc, test_LIVE_krcc)) 132 | 133 | if test_CSIQ_srcc + test_CSIQ_plcc + test_CSIQ_krcc > best_srcc_CSIQ + best_plcc_CSIQ + best_krcc_CSIQ: 134 | best_plcc_LIVE, best_plcc_CSIQ, best_plcc_TID = test_LIVE_plcc, test_CSIQ_plcc, test_TID_plcc 135 | print('%d:csiq\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 136 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_CSIQ_srcc, test_CSIQ_plcc, 
test_CSIQ_krcc)) 137 | 138 | if test_TID_srcc + test_TID_plcc + test_TID_krcc > best_srcc_TID + best_plcc_TID + best_krcc_TID: 139 | best_krcc_LIVE, best_krcc_CSIQ, best_krcc_TID = test_LIVE_krcc, test_CSIQ_krcc, test_TID_krcc 140 | print('%d:tid\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 141 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_TID_srcc, test_TID_plcc, test_TID_krcc)) 142 | 143 | if test_Koniq_srcc + test_Koniq_plcc + test_Koniq_krcc > best_srcc_Koniq + best_plcc_Koniq + best_krcc_Koniq: 144 | print('%d:koniq-10k\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f \n' % 145 | (t, sum(epoch_loss) / len(epoch_loss), train_srcc, test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc)) 146 | best_srcc_Koniq, best_plcc_Koniq, best_krcc_Koniq = test_Koniq_srcc, test_Koniq_plcc, test_Koniq_krcc 147 | 148 | torch.save(self.studentNet.state_dict(), os.path.join(self.config.model_checkpoint_dir, 'Distillation_inner_{}_saved_model.pth'.format(t))) 149 | 150 | self.lr = self.lr / pow(10, (t // self.config.update_opt_epoch)) 151 | if t > 20: 152 | self.lr_ratio = 1 153 | resnet_params = list(map(id, self.studentNet.feature_extractor.parameters())) 154 | rest_params = filter(lambda p: id(p) not in resnet_params, self.studentNet.parameters()) 155 | paras = [{'params': rest_params, 'lr': self.lr * self.lr_ratio }, 156 | {'params': self.studentNet.feature_extractor.parameters(), 'lr': self.lr} 157 | ] 158 | self.optimizer = torch.optim.Adam(paras, weight_decay=self.config.weight_decay) 159 | print('Best live test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_LIVE, best_plcc_LIVE, best_krcc_LIVE)) 160 | print('Best csiq test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_CSIQ, best_plcc_CSIQ, best_krcc_CSIQ)) 161 | print('Best tid2013 test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_TID, best_plcc_TID, best_krcc_TID)) 162 | print('Best koniq-10k test SRCC %f, PLCC %f, KRCC %f\n' % (best_srcc_Koniq, best_plcc_Koniq, best_krcc_Koniq)) 163 | 164 | 165 | def test(self, test_data): 166 | self.studentNet.train(False) 167 | test_pred_scores, test_gt_scores = [], [] 168 | for LQ_patches, _, ref_patches, label in test_data: 169 | LQ_patches, ref_patches, label = LQ_patches.to(self.device), ref_patches.to(self.device), label.to(self.device) 170 | with torch.no_grad(): 171 | _, _, pred = self.studentNet(LQ_patches, ref_patches) 172 | test_pred_scores.append(float(pred.item())) 173 | test_gt_scores = test_gt_scores + label.cpu().tolist() 174 | if self.config.use_fitting_prcc_srcc: 175 | fitting_pred_scores = convert_obj_score(test_pred_scores, test_gt_scores) 176 | test_pred_scores = np.mean(np.reshape(np.array(test_pred_scores), (-1, self.config.test_patch_num)), axis=1) 177 | test_gt_scores = np.mean(np.reshape(np.array(test_gt_scores), (-1, self.config.test_patch_num)), axis=1) 178 | test_srcc, _ = stats.spearmanr(test_pred_scores, test_gt_scores) 179 | if self.config.use_fitting_prcc_srcc: 180 | test_plcc, _ = stats.pearsonr(fitting_pred_scores, test_gt_scores) 181 | else: 182 | test_plcc, _ = stats.pearsonr(test_pred_scores, test_gt_scores) 183 | test_krcc, _ = stats.stats.kendalltau(test_pred_scores, test_gt_scores) 184 | test_srcc, test_plcc, test_krcc = abs(test_srcc), abs(test_plcc), abs(test_krcc) 185 | self.studentNet.train(True) 186 | return test_srcc, test_plcc, test_krcc 187 | 188 | if __name__ == "__main__": 189 | config = set_args() 190 | config = check_args(config) 191 | solver = DistillationIQASolver(config=config) 192 | solver.train() 193 | 194 | 195 | 196 | 197 | 198 | 
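# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): the loop in train()
# above optimizes an L1 loss on the predicted quality score plus an MSE
# feature-matching (distillation) loss between student and teacher features.
# A minimal standalone version of that objective, assuming `teacher_feats`
# and `student_feats` are equal-length lists of intermediate feature tensors:
import torch.nn.functional as F

def distillation_objective(pred, label, teacher_feats, student_feats, feature_loss_ratio=1.0):
    # L1 regression loss between predicted and ground-truth quality scores.
    pred_loss = F.l1_loss(pred.squeeze(), label.float())
    # Sum of per-layer MSE losses; teacher features are detached here so that
    # gradients only flow into the student.
    feature_loss = sum(F.mse_loss(s, t.detach()) for t, s in zip(teacher_feats, student_feats))
    return pred_loss + feature_loss_ratio * feature_loss
# --------------------------------------------------------------------------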
-------------------------------------------------------------------------------- /models/TRIQ.py: -------------------------------------------------------------------------------- 1 | import torch as torch 2 | from torch._C import device 3 | import torch.nn as nn 4 | import math 5 | from torch.nn.modules.conv import Conv2d 6 | import torch.utils.model_zoo as model_zoo 7 | from torch.nn import Dropout, Softmax, Linear, LayerNorm 8 | # from SemanticResNet50 import ResNetBackbone, Bottleneck 9 | # from models.ResNet50_MLP import ResNetBackbone 10 | # from models.MLP_return_inner_feature import MLPMixer 11 | 12 | 13 | #ResNet 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | class Bottleneck(nn.Module): 23 | expansion = 4 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(Bottleneck, self).__init__() 27 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 33 | self.bn3 = nn.BatchNorm2d(planes * 4) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv3(out) 50 | out = self.bn3(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | class ResNetBackbone(nn.Module): 61 | 62 | # def __init__(self, outc=176, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000, pretrained=True): 63 | def __init__(self, outc=2048, block=Bottleneck, layers=[3, 4, 6, 3], pretrained=True): 64 | super(ResNetBackbone, self).__init__() 65 | self.pretrained = pretrained 66 | self.inplanes = 64 67 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 68 | self.bn1 = nn.BatchNorm2d(64) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 71 | self.layer1 = self._make_layer(block, 64, layers[0]) 72 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 75 | 76 | self.lda_out_channels = int(outc // 4) 77 | 78 | # local distortion aware module 79 | self.lda1_pool = nn.Sequential( 80 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False), 81 | nn.AvgPool2d(7, stride=7), 82 | ) 83 | self.lda1_fc = nn.Linear(16 * 64, self.lda_out_channels) 84 | 85 | self.lda2_pool = nn.Sequential( 86 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False), 87 | nn.AvgPool2d(7, stride=7), 88 | ) 89 | self.lda2_fc = nn.Linear(32 * 16, self.lda_out_channels) 90 | 91 | self.lda3_pool = nn.Sequential( 92 | nn.Conv2d(1024, 64, 
kernel_size=1, stride=1, padding=0, bias=False), 93 | nn.AvgPool2d(7, stride=7), 94 | ) 95 | self.lda3_fc = nn.Linear(64 * 4, self.lda_out_channels) 96 | 97 | self.lda4_pool = nn.AvgPool2d(7, stride=7) 98 | self.lda4_fc = nn.Linear(2048, self.lda_out_channels) 99 | 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. / n)) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | if self.pretrained: 108 | self.load_resnet50_backbone() 109 | 110 | def _make_layer(self, block, planes, blocks, stride=1): 111 | downsample = None 112 | if stride != 1 or self.inplanes != planes * block.expansion: 113 | downsample = nn.Sequential( 114 | nn.Conv2d(self.inplanes, planes * block.expansion, 115 | kernel_size=1, stride=stride, bias=False), 116 | nn.BatchNorm2d(planes * block.expansion), 117 | ) 118 | 119 | layers = [] 120 | layers.append(block(self.inplanes, planes, stride, downsample)) 121 | self.inplanes = planes * block.expansion 122 | for i in range(1, blocks): 123 | layers.append(block(self.inplanes, planes)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | 128 | def forward(self, x): 129 | 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | x = self.layer1(x) 135 | x = self.layer2(x) 136 | x = self.layer3(x) 137 | x = self.layer4(x) 138 | return x 139 | 140 | def load_resnet50_backbone(self): 141 | """Constructs a ResNet-50 model_hyper. 142 | Args: 143 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet 144 | """ 145 | save_model = model_zoo.load_url(model_urls['resnet50']) 146 | model_dict = self.state_dict() 147 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 148 | model_dict.update(state_dict) 149 | self.load_state_dict(model_dict) 150 | 151 | class Mlp(nn.Module): 152 | def __init__(self): 153 | super(Mlp, self).__init__() 154 | self.hidden_size = 32 155 | self.mlp_size = 64 156 | self.dropout_rate = 0.1 157 | self.fc1 = Linear(self.hidden_size, self.mlp_size) 158 | self.fc2 = Linear(self.mlp_size, self.hidden_size) 159 | self.act_fn = nn.ReLU() 160 | self.dropout = Dropout(self.dropout_rate) 161 | 162 | self._init_weights() 163 | 164 | def _init_weights(self): 165 | nn.init.xavier_uniform_(self.fc1.weight) 166 | nn.init.xavier_uniform_(self.fc2.weight) 167 | nn.init.normal_(self.fc1.bias, std=1e-6) 168 | nn.init.normal_(self.fc2.bias, std=1e-6) 169 | 170 | def forward(self, x): 171 | x = self.fc1(x) 172 | x = self.act_fn(x) 173 | x = self.dropout(x) 174 | x = self.fc2(x) 175 | x = self.dropout(x) 176 | return x 177 | 178 | class MHA(nn.Module): 179 | def __init__(self): 180 | super(MHA, self).__init__() 181 | self.vis = True 182 | self.num_attention_heads = 8 183 | self.hidden_size = 32 184 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads) 185 | self.all_head_size = self.num_attention_heads * self.attention_head_size 186 | 187 | self.query = Linear(self.hidden_size, self.all_head_size) 188 | self.key = Linear(self.hidden_size, self.all_head_size) 189 | self.value = Linear(self.hidden_size, self.all_head_size) 190 | 191 | self.out = Linear(self.hidden_size, self.hidden_size) 192 | self.attn_dropout = Dropout(0.0) 193 | self.proj_dropout = Dropout(0.0) 194 | 195 | self.softmax = Softmax(dim=-1) 196 | 197 | def transpose_for_scores(self, x): 198 | new_x_shape = x.size()[:-1] + 
(self.num_attention_heads, self.attention_head_size) 199 | x = x.view(*new_x_shape) 200 | return x.permute(0, 2, 1, 3) 201 | 202 | def forward(self, x): 203 | mixed_key_layer = self.key(x) 204 | mixed_value_layer = self.value(x) 205 | mixed_query_layer = self.query(x) 206 | 207 | query_layer = self.transpose_for_scores(mixed_query_layer) 208 | key_layer = self.transpose_for_scores(mixed_key_layer) 209 | value_layer = self.transpose_for_scores(mixed_value_layer) 210 | 211 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 212 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 213 | attention_probs = self.softmax(attention_scores) 214 | weights = attention_probs 215 | attention_probs = self.attn_dropout(attention_probs) 216 | 217 | context_layer = torch.matmul(attention_probs, value_layer) 218 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 219 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 220 | context_layer = context_layer.view(*new_context_layer_shape) 221 | attention_output = self.out(context_layer) 222 | attention_output = self.proj_dropout(attention_output) 223 | return attention_output 224 | 225 | class TransformerBlock(nn.Module): 226 | def __init__(self): 227 | super(TransformerBlock, self).__init__() 228 | self.hidden_size = 32 229 | self.ffn_norm = LayerNorm(self.hidden_size, eps=1e-6) 230 | self.ffn = Mlp() 231 | self.attention_norm = LayerNorm(self.hidden_size, eps=1e-6) 232 | self.attn = MHA() 233 | 234 | self.dropout_rate = 0.1 235 | self.dropout1 = Dropout(self.dropout_rate) 236 | self.dropout2 = Dropout(self.dropout_rate) 237 | 238 | def forward(self, x): 239 | x1 = self.attn(x) 240 | x1 = self.dropout1(x1) 241 | x1 += x 242 | x1 = self.attention_norm(x1) 243 | 244 | x2 = self.ffn(x1) 245 | x2 = self.dropout2(x2) 246 | x2 += x1 247 | x2 = self.attention_norm(x2) 248 | 249 | return x2 250 | 251 | class RegressionFCNet(nn.Module): 252 | def __init__(self): 253 | super(RegressionFCNet, self).__init__() 254 | self.target_in_size=32 255 | self.target_fc1_size=64 256 | 257 | self.l1 = nn.Linear(self.target_in_size, self.target_fc1_size) 258 | self.relu = nn.ReLU() 259 | self.l2 = nn.Linear(self.target_fc1_size, 1) 260 | 261 | 262 | def forward(self, x): 263 | q = self.l1(x) 264 | q = self.relu(q) 265 | q = self.l2(q).squeeze() 266 | return q 267 | 268 | class TRIQ(nn.Module): 269 | def __init__(self): 270 | super(TRIQ, self).__init__() 271 | self.feature_extractor = ResNetBackbone() 272 | # for param in self.feature_extractor.parameters(): 273 | # param.requires_grad = False 274 | 275 | self.conv = Conv2d(2048,32,kernel_size=1, stride=1, padding=0) 276 | self.position_embeddings = nn.Parameter(torch.zeros(1, 49+1, 32)) 277 | self.quality_token = nn.Parameter(torch.zeros(1, 1, 32)) 278 | 279 | self.transformer_block1 = TransformerBlock() 280 | self.transformer_block2 = TransformerBlock() 281 | 282 | self.regressor = RegressionFCNet() 283 | 284 | def cal_params(self): 285 | params = list(self.parameters()) 286 | k = 0 287 | for i in params: 288 | l = 1 289 | for j in i.size(): 290 | l *= j 291 | k = k + l 292 | print("Total parameters is :" + str(k)) 293 | 294 | def forward(self, LQ): 295 | B = LQ.shape[0] 296 | feature_LQ = self.feature_extractor(LQ) 297 | feature_LQ = self.conv(feature_LQ) 298 | 299 | quality_tokens = self.quality_token.expand(B,-1,-1) 300 | 301 | flat_feature_LQ = feature_LQ.flatten(2).transpose(-1,-2) 302 | flat_feature_LQ = 
torch.cat((quality_tokens,flat_feature_LQ), dim=1) + self.position_embeddings 303 | 304 | f = self.transformer_block1(flat_feature_LQ) 305 | f = self.transformer_block2(f) 306 | 307 | y = self.regressor(f) 308 | 309 | return y 310 | 311 | ''' 312 | TEST 313 | Run this code with: 314 | ``` 315 | cd $HOME/pretrained-models.pytorch 316 | python -m pretrainedmodels.inceptionresnetv2 317 | ``` 318 | ''' 319 | if __name__ == '__main__': 320 | import time 321 | device = torch.device('cuda') 322 | LQ = torch.rand((1,3,224, 224)).to(device) 323 | net = TRIQ().to(device) 324 | net.cal_params() 325 | 326 | torch.cuda.synchronize() 327 | start = time.time() 328 | a = net(LQ) 329 | torch.cuda.synchronize() 330 | end = time.time() 331 | print("run time is :" + str(end-start)) 332 | -------------------------------------------------------------------------------- /models/HyperIQA.py: -------------------------------------------------------------------------------- 1 | import torch as torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | import math 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | class HyperNet(nn.Module): 18 | """ 19 | Hyper network for learning perceptual rules. 20 | 21 | Args: 22 | lda_out_channels: local distortion aware module output size. 23 | hyper_in_channels: input feature channels for hyper network. 24 | target_in_size: input vector size for target network. 25 | target_fc(i)_size: fully connection layer size of target network. 26 | feature_size: input feature map width/height for hyper network. 27 | 28 | Note: 29 | For size match, input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'. 
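For example, with the defaults used below (target_in_size=224, target_fc1_size=112, target_fc2_size=56, feature_size=7), the weight-generating convolutions output 224*112/7^2 = 512 and 112*56/7^2 = 128 channels, so the constraint is satisfied for each consecutive pair of layer sizes.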
30 | 31 | """ 32 | def __init__(self, lda_out_channels=16, hyper_in_channels=112, target_in_size=224, target_fc1_size=112, target_fc2_size=56, target_fc3_size=28, target_fc4_size=14, feature_size=7): 33 | super(HyperNet, self).__init__() 34 | 35 | self.hyperInChn = hyper_in_channels 36 | self.target_in_size = target_in_size 37 | self.f1 = target_fc1_size 38 | self.f2 = target_fc2_size 39 | self.f3 = target_fc3_size 40 | self.f4 = target_fc4_size 41 | self.feature_size = feature_size 42 | 43 | self.res = resnet50_backbone(lda_out_channels, target_in_size, pretrained=True) 44 | 45 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 46 | 47 | # Conv layers for resnet output features 48 | self.conv1 = nn.Sequential( 49 | nn.Conv2d(2048, 1024, 1, padding=(0, 0)), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(1024, 512, 1, padding=(0, 0)), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)), 54 | nn.ReLU(inplace=True) 55 | ) 56 | 57 | # Hyper network part, conv for generating target fc weights, fc for generating target fc biases 58 | self.fc1w_conv = nn.Conv2d(self.hyperInChn, int(self.target_in_size * self.f1 / feature_size ** 2), 3, padding=(1, 1)) 59 | self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1) 60 | 61 | self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size ** 2), 3, padding=(1, 1)) 62 | self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2) 63 | 64 | self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size ** 2), 3, padding=(1, 1)) 65 | self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3) 66 | 67 | self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size ** 2), 3, padding=(1, 1)) 68 | self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4) 69 | 70 | self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4) 71 | self.fc5b_fc = nn.Linear(self.hyperInChn, 1) 72 | 73 | # initialize 74 | for i, m_name in enumerate(self._modules): 75 | if i > 2: 76 | nn.init.kaiming_normal_(self._modules[m_name].weight.data) 77 | 78 | def forward(self, img): 79 | feature_size = self.feature_size 80 | 81 | res_out = self.res(img) 82 | 83 | # input vector for target net 84 | target_in_vec = res_out['target_in_vec'].view(-1, self.target_in_size, 1, 1) 85 | 86 | # input features for hyper net 87 | hyper_in_feat = self.conv1(res_out['hyper_in_feat']).view(-1, self.hyperInChn, feature_size, feature_size) 88 | 89 | # generating target net weights & biases 90 | target_fc1w = self.fc1w_conv(hyper_in_feat).view(-1, self.f1, self.target_in_size, 1, 1) 91 | target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f1) 92 | 93 | target_fc2w = self.fc2w_conv(hyper_in_feat).view(-1, self.f2, self.f1, 1, 1) 94 | target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f2) 95 | 96 | target_fc3w = self.fc3w_conv(hyper_in_feat).view(-1, self.f3, self.f2, 1, 1) 97 | target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f3) 98 | 99 | target_fc4w = self.fc4w_conv(hyper_in_feat).view(-1, self.f4, self.f3, 1, 1) 100 | target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f4) 101 | 102 | target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1, self.f4, 1, 1) 103 | target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1) 104 | 105 | out = {} 106 | out['target_in_vec'] = target_in_vec 107 | out['target_fc1w'] = target_fc1w 108 | out['target_fc1b'] = target_fc1b 109 | out['target_fc2w'] = target_fc2w 110 | out['target_fc2b'] = target_fc2b 
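# Each generated target_fc{i}w tensor has shape (batch, out_features, in_features, 1, 1); TargetFC below applies it as a grouped 1x1 convolution, so every image in the batch is evaluated with its own fully-connected weights and biases.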
111 | out['target_fc3w'] = target_fc3w 112 | out['target_fc3b'] = target_fc3b 113 | out['target_fc4w'] = target_fc4w 114 | out['target_fc4b'] = target_fc4b 115 | out['target_fc5w'] = target_fc5w 116 | out['target_fc5b'] = target_fc5b 117 | 118 | return out 119 | 120 | 121 | class TargetNet(nn.Module): 122 | """ 123 | Target network for quality prediction. 124 | """ 125 | def __init__(self, paras): 126 | super(TargetNet, self).__init__() 127 | self.l1 = nn.Sequential( 128 | TargetFC(paras['target_fc1w'], paras['target_fc1b']), 129 | nn.Sigmoid(), 130 | ) 131 | self.l2 = nn.Sequential( 132 | TargetFC(paras['target_fc2w'], paras['target_fc2b']), 133 | nn.Sigmoid(), 134 | ) 135 | 136 | self.l3 = nn.Sequential( 137 | TargetFC(paras['target_fc3w'], paras['target_fc3b']), 138 | nn.Sigmoid(), 139 | ) 140 | 141 | self.l4 = nn.Sequential( 142 | TargetFC(paras['target_fc4w'], paras['target_fc4b']), 143 | nn.Sigmoid(), 144 | TargetFC(paras['target_fc5w'], paras['target_fc5b']), 145 | ) 146 | 147 | def forward(self, x): 148 | q = self.l1(x) 149 | # q = F.dropout(q) 150 | q = self.l2(q) 151 | q = self.l3(q) 152 | q = self.l4(q).squeeze() 153 | return q 154 | 155 | 156 | class TargetFC(nn.Module): 157 | """ 158 | Fully connection operations for target net 159 | 160 | Note: 161 | Weights & biases are different for different images in a batch, 162 | thus here we use group convolution for calculating images in a batch with individual weights & biases. 163 | """ 164 | def __init__(self, weight, bias): 165 | super(TargetFC, self).__init__() 166 | self.weight = weight 167 | self.bias = bias 168 | 169 | def forward(self, input_): 170 | 171 | input_re = input_.view(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3]) 172 | weight_re = self.weight.view(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], self.weight.shape[3], self.weight.shape[4]) 173 | bias_re = self.bias.view(self.bias.shape[0] * self.bias.shape[1]) 174 | out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0]) 175 | 176 | return out.view(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3]) 177 | 178 | 179 | class Bottleneck(nn.Module): 180 | expansion = 4 181 | 182 | def __init__(self, inplanes, planes, stride=1, downsample=None): 183 | super(Bottleneck, self).__init__() 184 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 185 | self.bn1 = nn.BatchNorm2d(planes) 186 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 187 | padding=1, bias=False) 188 | self.bn2 = nn.BatchNorm2d(planes) 189 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 190 | self.bn3 = nn.BatchNorm2d(planes * 4) 191 | self.relu = nn.ReLU(inplace=True) 192 | self.downsample = downsample 193 | self.stride = stride 194 | 195 | def forward(self, x): 196 | residual = x 197 | 198 | out = self.conv1(x) 199 | out = self.bn1(out) 200 | out = self.relu(out) 201 | 202 | out = self.conv2(out) 203 | out = self.bn2(out) 204 | out = self.relu(out) 205 | 206 | out = self.conv3(out) 207 | out = self.bn3(out) 208 | 209 | if self.downsample is not None: 210 | residual = self.downsample(x) 211 | 212 | out += residual 213 | out = self.relu(out) 214 | 215 | return out 216 | 217 | 218 | class ResNetBackbone(nn.Module): 219 | 220 | def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000): 221 | super(ResNetBackbone, self).__init__() 222 | self.inplanes = 64 223 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 
bias=False) 224 | self.bn1 = nn.BatchNorm2d(64) 225 | self.relu = nn.ReLU(inplace=True) 226 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 227 | self.layer1 = self._make_layer(block, 64, layers[0]) 228 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 229 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 230 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 231 | 232 | # local distortion aware module 233 | self.lda1_pool = nn.Sequential( 234 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False), 235 | nn.AvgPool2d(7, stride=7), 236 | ) 237 | self.lda1_fc = nn.Linear(16 * 64, lda_out_channels) 238 | 239 | self.lda2_pool = nn.Sequential( 240 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False), 241 | nn.AvgPool2d(7, stride=7), 242 | ) 243 | self.lda2_fc = nn.Linear(32 * 16, lda_out_channels) 244 | 245 | self.lda3_pool = nn.Sequential( 246 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False), 247 | nn.AvgPool2d(7, stride=7), 248 | ) 249 | self.lda3_fc = nn.Linear(64 * 4, lda_out_channels) 250 | 251 | self.lda4_pool = nn.AvgPool2d(7, stride=7) 252 | self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3) 253 | 254 | for m in self.modules(): 255 | if isinstance(m, nn.Conv2d): 256 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 257 | m.weight.data.normal_(0, math.sqrt(2. / n)) 258 | elif isinstance(m, nn.BatchNorm2d): 259 | m.weight.data.fill_(1) 260 | m.bias.data.zero_() 261 | 262 | # initialize 263 | nn.init.kaiming_normal_(self.lda1_pool._modules['0'].weight.data) 264 | nn.init.kaiming_normal_(self.lda2_pool._modules['0'].weight.data) 265 | nn.init.kaiming_normal_(self.lda3_pool._modules['0'].weight.data) 266 | nn.init.kaiming_normal_(self.lda1_fc.weight.data) 267 | nn.init.kaiming_normal_(self.lda2_fc.weight.data) 268 | nn.init.kaiming_normal_(self.lda3_fc.weight.data) 269 | nn.init.kaiming_normal_(self.lda4_fc.weight.data) 270 | 271 | def _make_layer(self, block, planes, blocks, stride=1): 272 | downsample = None 273 | if stride != 1 or self.inplanes != planes * block.expansion: 274 | downsample = nn.Sequential( 275 | nn.Conv2d(self.inplanes, planes * block.expansion, 276 | kernel_size=1, stride=stride, bias=False), 277 | nn.BatchNorm2d(planes * block.expansion), 278 | ) 279 | 280 | layers = [] 281 | layers.append(block(self.inplanes, planes, stride, downsample)) 282 | self.inplanes = planes * block.expansion 283 | for i in range(1, blocks): 284 | layers.append(block(self.inplanes, planes)) 285 | 286 | return nn.Sequential(*layers) 287 | 288 | def forward(self, x): 289 | x = self.conv1(x) 290 | x = self.bn1(x) 291 | x = self.relu(x) 292 | x = self.maxpool(x) 293 | x = self.layer1(x) 294 | 295 | # the same effect as lda operation in the paper, but save much more memory 296 | lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1)) 297 | x = self.layer2(x) 298 | lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1)) 299 | x = self.layer3(x) 300 | lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1)) 301 | x = self.layer4(x) 302 | lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1)) 303 | 304 | vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1) 305 | 306 | out = {} 307 | out['hyper_in_feat'] = x 308 | out['target_in_vec'] = vec 309 | 310 | return out 311 | 312 | 313 | def resnet50_backbone(lda_out_channels, in_chn, pretrained=False, **kwargs): 314 | """Constructs a ResNet-50 model_hyper. 
315 | 316 | Args: 317 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet 318 | """ 319 | model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs) 320 | if pretrained: 321 | save_model = model_zoo.load_url(model_urls['resnet50']) 322 | model_dict = model.state_dict() 323 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 324 | model_dict.update(state_dict) 325 | model.load_state_dict(model_dict) 326 | else: 327 | model.apply(weights_init_xavier) 328 | return model 329 | 330 | 331 | def weights_init_xavier(m): 332 | classname = m.__class__.__name__ 333 | # print(classname) 334 | # if isinstance(m, nn.Conv2d): 335 | if classname.find('Conv') != -1: 336 | init.kaiming_normal_(m.weight.data) 337 | elif classname.find('Linear') != -1: 338 | init.kaiming_normal_(m.weight.data) 339 | elif classname.find('BatchNorm2d') != -1: 340 | init.uniform_(m.weight.data, 1.0, 0.02) 341 | init.constant_(m.bias.data, 0.0) 342 | -------------------------------------------------------------------------------- /folders/folders_LQ_HQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import scipy.io 8 | import numpy as np 9 | import csv 10 | import random 11 | from openpyxl import load_workbook 12 | import cv2 13 | from torchvision import transforms 14 | 15 | class Kadid10kFolder(data.Dataset): 16 | 17 | def __init__(self, root, index, patch_num, patch_size=224, self_patch_num=1): 18 | self.patch_size = patch_size 19 | self.self_patch_num = self_patch_num 20 | 21 | imgname = [] 22 | refimgname = [] 23 | mos_all = [] 24 | csv_file = os.path.join(root, 'dmos.csv') 25 | with open(csv_file) as f: 26 | reader = csv.DictReader(f) 27 | for row in reader: 28 | imgname.append(row['dist_img']) 29 | refimgname.append(row['ref_img']) 30 | mos = np.array(float(row['dmos'])).astype(np.float32) 31 | mos_all.append(mos) 32 | 33 | sample = [] 34 | for i, item in enumerate(index): 35 | for aug in range(patch_num): 36 | sample.append((os.path.join(root, 'images', imgname[item]),os.path.join(root, 'images', refimgname[item]), mos_all[item])) 37 | 38 | self.samples = sample 39 | self.transform = torchvision.transforms.Compose([ 40 | torchvision.transforms.ToTensor(), 41 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 42 | std=(0.229, 0.224, 0.225)) 43 | ]) 44 | 45 | def __getitem__(self, index): 46 | """ 47 | Args: 48 | index (int): Index 49 | Returns: 50 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
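LQ and HQ are each stacked into a tensor of shape [self_patch_num, 3, patch_size, patch_size].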
51 | """ 52 | LQ_path, HQ_path, target = self.samples[index] 53 | LQ = pil_loader(LQ_path) 54 | HQ = pil_loader(HQ_path) 55 | LQ_patches, HQ_patches = [], [] 56 | for _ in range(self.self_patch_num): 57 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size) 58 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch) 59 | 60 | LQ_patch = self.transform(LQ_patch) 61 | HQ_patch = self.transform(HQ_patch) 62 | 63 | LQ_patches.append(LQ_patch.unsqueeze(0)) 64 | HQ_patches.append(HQ_patch.unsqueeze(0)) 65 | #[self_patch_num, 3, patch_size, patch_size] 66 | LQ_patches = torch.cat(LQ_patches, 0) 67 | HQ_patches = torch.cat(HQ_patches, 0) 68 | 69 | return LQ_patches, HQ_patches, target 70 | 71 | def __len__(self): 72 | length = len(self.samples) 73 | return length 74 | 75 | class LIVEFolder(data.Dataset): 76 | 77 | def __init__(self, root, index, patch_num, patch_size=224, self_patch_num=1): 78 | self.patch_size =patch_size 79 | self.self_patch_num = self_patch_num 80 | 81 | refpath = os.path.join(root, 'refimgs') 82 | refname = getFileName(refpath, '.bmp') 83 | 84 | jp2kroot = os.path.join(root, 'jp2k') 85 | jp2kname = self.getDistortionTypeFileName(jp2kroot, 227) 86 | 87 | jpegroot = os.path.join(root, 'jpeg') 88 | jpegname = self.getDistortionTypeFileName(jpegroot, 233) 89 | 90 | wnroot = os.path.join(root, 'wn') 91 | wnname = self.getDistortionTypeFileName(wnroot, 174) 92 | 93 | gblurroot = os.path.join(root, 'gblur') 94 | gblurname = self.getDistortionTypeFileName(gblurroot, 174) 95 | 96 | fastfadingroot = os.path.join(root, 'fastfading') 97 | fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174) 98 | 99 | imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname 100 | 101 | dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat')) 102 | labels = dmos['dmos_new'].astype(np.float32) 103 | 104 | orgs = dmos['orgs'] 105 | refnames_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat')) 106 | refnames_all = refnames_all['refnames_all'] 107 | 108 | sample = [] 109 | for i in range(0, len(index)): 110 | train_sel = (refname[index[i]] == refnames_all) 111 | train_sel = train_sel * ~orgs.astype(np.bool_) 112 | train_sel = np.where(train_sel == True) 113 | train_sel = train_sel[1].tolist() 114 | for j, item in enumerate(train_sel): 115 | for aug in range(patch_num): 116 | LQ_path = imgpath[item] 117 | HQ_path = os.path.join(root, 'refimgs', refnames_all[0][item][0]) 118 | label = labels[0][item] 119 | sample.append((LQ_path, HQ_path, label)) 120 | 121 | self.samples = sample 122 | self.transform = torchvision.transforms.Compose([ 123 | torchvision.transforms.ToTensor(), 124 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 125 | std=(0.229, 0.224, 0.225)) 126 | ]) 127 | 128 | def __getitem__(self, index): 129 | """ 130 | Args: 131 | index (int): Index 132 | Returns: 133 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
134 | """ 135 | LQ_path, HQ_path, target = self.samples[index] 136 | LQ = pil_loader(LQ_path) 137 | HQ = pil_loader(HQ_path) 138 | LQ_patches, HQ_patches = [], [] 139 | for _ in range(self.self_patch_num): 140 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size) 141 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch) 142 | 143 | LQ_patch = self.transform(LQ_patch) 144 | HQ_patch = self.transform(HQ_patch) 145 | 146 | LQ_patches.append(LQ_patch.unsqueeze(0)) 147 | HQ_patches.append(HQ_patch.unsqueeze(0)) 148 | LQ_patches = torch.cat(LQ_patches, 0) 149 | HQ_patches = torch.cat(HQ_patches, 0) 150 | 151 | return LQ_patches, HQ_patches, target 152 | 153 | def __len__(self): 154 | length = len(self.samples) 155 | return length 156 | 157 | def getDistortionTypeFileName(self, path, num): 158 | filename = [] 159 | index = 1 160 | for i in range(0, num): 161 | name = '{:0>3d}{}'.format(index, '.bmp') 162 | filename.append(os.path.join(path, name)) 163 | index = index + 1 164 | return filename 165 | 166 | class CSIQFolder(data.Dataset): 167 | 168 | def __init__(self, root, index, patch_num, patch_size =224, self_patch_num=1): 169 | self.patch_size =patch_size 170 | self.self_patch_num = self_patch_num 171 | 172 | refpath = os.path.join(root, 'src_imgs') 173 | refname = getFileName(refpath,'.png') 174 | txtpath = os.path.join(root, 'csiq_label.txt') 175 | fh = open(txtpath, 'r') 176 | imgnames = [] 177 | target = [] 178 | refnames_all = [] 179 | for line in fh: 180 | line = line.split('\n') 181 | words = line[0].split() 182 | imgnames.append((words[0])) 183 | target.append(words[1]) 184 | ref_temp = words[0].split(".") 185 | refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) 186 | 187 | labels = np.array(target).astype(np.float32) 188 | refnames_all = np.array(refnames_all) 189 | 190 | sample = [] 191 | 192 | for i, item in enumerate(index): 193 | train_sel = (refname[index[i]] == refnames_all) 194 | train_sel = np.where(train_sel == True) 195 | train_sel = train_sel[0].tolist() 196 | for j, item in enumerate(train_sel): 197 | for aug in range(patch_num): 198 | LQ_path = os.path.join(root, 'dst_imgs_all', imgnames[item]) 199 | HQ_path = os.path.join(root, 'src_imgs', refnames_all[item]) 200 | label = labels[item] 201 | sample.append((LQ_path, HQ_path, label)) 202 | self.samples = sample 203 | self.transform = torchvision.transforms.Compose([ 204 | torchvision.transforms.ToTensor(), 205 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 206 | std=(0.229, 0.224, 0.225)) 207 | ]) 208 | 209 | def __getitem__(self, index): 210 | """ 211 | Args: 212 | index (int): Index 213 | Returns: 214 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
215 | """ 216 | LQ_path, HQ_path, target = self.samples[index] 217 | LQ = pil_loader(LQ_path) 218 | HQ = pil_loader(HQ_path) 219 | LQ_patches, HQ_patches = [], [] 220 | for _ in range(self.self_patch_num): 221 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size) 222 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch) 223 | 224 | LQ_patch = self.transform(LQ_patch) 225 | HQ_patch = self.transform(HQ_patch) 226 | 227 | LQ_patches.append(LQ_patch.unsqueeze(0)) 228 | HQ_patches.append(HQ_patch.unsqueeze(0)) 229 | LQ_patches = torch.cat(LQ_patches, 0) 230 | HQ_patches = torch.cat(HQ_patches, 0) 231 | 232 | return LQ_patches, HQ_patches, target 233 | 234 | def __len__(self): 235 | length = len(self.samples) 236 | return length 237 | 238 | class TID2013Folder(data.Dataset): 239 | 240 | def __init__(self, root, index, patch_num, patch_size=224, self_patch_num=1): 241 | self.patch_size =patch_size 242 | self.self_patch_num = self_patch_num 243 | 244 | refpath = os.path.join(root, 'reference_images') 245 | refname = self._getTIDFileName(refpath,'.bmp.BMP') 246 | txtpath = os.path.join(root, 'mos_with_names.txt') 247 | fh = open(txtpath, 'r') 248 | imgnames = [] 249 | target = [] 250 | refnames_all = [] 251 | for line in fh: 252 | line = line.split('\n') 253 | words = line[0].split() 254 | imgnames.append((words[1])) 255 | target.append(words[0]) 256 | ref_temp = words[1].split("_") 257 | refnames_all.append(ref_temp[0][1:]) 258 | labels = np.array(target).astype(np.float32) 259 | refnames_all = np.array(refnames_all) 260 | 261 | sample = [] 262 | for i, item in enumerate(index): 263 | train_sel = (refname[index[i]] == refnames_all) 264 | train_sel = np.where(train_sel == True) 265 | train_sel = train_sel[0].tolist() 266 | for j, item in enumerate(train_sel): 267 | for aug in range(patch_num): 268 | LQ_path = os.path.join(root, 'distorted_images', imgnames[item]) 269 | HQ_name = 'I' + imgnames[item].split("_")[0][1:] + '.BMP' 270 | HQ_path = os.path.join(refpath, HQ_name) 271 | label = labels[item] 272 | sample.append((LQ_path, HQ_path, label)) 273 | self.samples = sample 274 | self.transform = self.transform = torchvision.transforms.Compose([ 275 | torchvision.transforms.ToTensor(), 276 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 277 | std=(0.229, 0.224, 0.225)) 278 | ]) 279 | 280 | def _getTIDFileName(self, path, suffix): 281 | filename = [] 282 | f_list = os.listdir(path) 283 | for i in f_list: 284 | if suffix.find(os.path.splitext(i)[1]) != -1: 285 | filename.append(i[1:3]) 286 | return filename 287 | 288 | def __getitem__(self, index): 289 | """ 290 | Args: 291 | index (int): Index 292 | Returns: 293 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
294 | """ 295 | LQ_path, HQ_path, target = self.samples[index] 296 | LQ = pil_loader(LQ_path) 297 | HQ = pil_loader(HQ_path) 298 | LQ_patches, HQ_patches = [], [] 299 | for _ in range(self.self_patch_num): 300 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size) 301 | LQ_patch, HQ_patch = getPairAugment(LQ_patch, HQ_patch) 302 | 303 | LQ_patch = self.transform(LQ_patch) 304 | HQ_patch = self.transform(HQ_patch) 305 | 306 | LQ_patches.append(LQ_patch.unsqueeze(0)) 307 | HQ_patches.append(HQ_patch.unsqueeze(0)) 308 | LQ_patches = torch.cat(LQ_patches, 0) 309 | HQ_patches = torch.cat(HQ_patches, 0) 310 | 311 | return LQ_patches, HQ_patches, target 312 | 313 | def __len__(self): 314 | length = len(self.samples) 315 | return length 316 | 317 | 318 | def getFileName(path, suffix): 319 | filename = [] 320 | f_list = os.listdir(path) 321 | for i in f_list: 322 | if os.path.splitext(i)[1] == suffix: 323 | filename.append(i) 324 | return filename 325 | 326 | 327 | def getPairRandomPatch(img1, img2, crop_size=512): 328 | (iw,ih) = img1.size 329 | # print(ih,iw) 330 | 331 | ip = int(crop_size) 332 | 333 | ix = random.randrange(0, iw - ip + 1) 334 | iy = random.randrange(0, ih - ip + 1) 335 | 336 | 337 | img1_patch = img1.crop((ix, iy, ix+ip, iy+ip))#左上右下 338 | img2_patch = img2.crop((ix, iy, ix+ip, iy+ip))#左上右下 339 | 340 | return img1_patch, img2_patch 341 | 342 | def getPairAugment(img1, img2, hflip=True, vflip=True, rot=True): 343 | hflip = hflip and random.random() < 0.5 344 | vflip = vflip and random.random() < 0.5 345 | rot180 = rot and random.random() < 0.5 346 | 347 | if hflip: 348 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM) 349 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM) 350 | if vflip: 351 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT) 352 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT) 353 | if rot180: 354 | img1 = img1.transpose(Image.ROTATE_180) 355 | img2 = img2.transpose(Image.ROTATE_180) 356 | 357 | return img1, img2 358 | 359 | 360 | def getSelfPatch(img, patch_size, patch_num, is_random=True): 361 | (iw,ih) = img.size 362 | patches = [] 363 | for i in range(patch_num): 364 | if is_random: 365 | ix = random.randrange(0, iw - patch_size + 1) 366 | iy = random.randrange(0, ih - patch_size + 1) 367 | else:ix,iy=(iw - patch_size + 1)//2,(ih - patch_size + 1)//2 368 | 369 | # patch = img[iy:iy + lr_size, ix:ix + lr_size, :]#上下左右 370 | patch = img.crop((ix, iy, ix+patch_size, iy+patch_size))#左上右下 371 | patches.append(patch) 372 | 373 | return patches 374 | 375 | 376 | def pil_loader(path): 377 | with open(path, 'rb') as f: 378 | img = Image.open(f) 379 | return img.convert('RGB') 380 | -------------------------------------------------------------------------------- /models/DistillationIQA.py: -------------------------------------------------------------------------------- 1 | import torch as torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | import math 6 | import torch.utils.model_zoo as model_zoo 7 | # from SemanticResNet50 import ResNetBackbone, Bottleneck 8 | # from models.ResNet50_MLP import ResNetBackbone 9 | # from models.MLP_return_inner_feature import MLPMixer 10 | 11 | from functools import partial 12 | from einops.layers.torch import Rearrange, Reduce 13 | 14 | #ResNet 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | class PatchEmbed(nn.Module): 24 | """ Feature to Patch Embedding 25 | input : N C H W 26 | output: N num_patch P^2*C 27 | """ 28 | 29 | def __init__(self, patch_size=7, in_channels=2048): 30 | super().__init__() 31 | self.patch_size = patch_size 32 | self.dim = self.patch_size ** 2 * in_channels 33 | 34 | def forward(self, x): 35 | N, C, H, W = ori_shape = x.shape 36 | 37 | p = self.patch_size 38 | num_patches = (H // p) * (W // p) 39 | 40 | fold_out = torch.nn.functional.unfold(x, (p, p), stride=p) # B, num_dim, num_patch 41 | out = fold_out.permute(0, 2, 1 )# B, num_patch, num_dim 42 | 43 | return out 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | class ResNetBackbone(nn.Module): 84 | 85 | # def __init__(self, outc=176, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=1000, pretrained=True): 86 | def __init__(self, outc=2048, block=Bottleneck, layers=[3, 4, 6, 3], pretrained=True): 87 | super(ResNetBackbone, self).__init__() 88 | self.pretrained = pretrained 89 | self.inplanes = 64 90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 91 | self.bn1 = nn.BatchNorm2d(64) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 94 | self.layer1 = self._make_layer(block, 64, layers[0]) 95 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 96 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 98 | 99 | self.lda_out_channels = int(outc // 4) 100 | 101 | # local distortion aware module 102 | self.lda1_pool = nn.Sequential( 103 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False), 104 | nn.AvgPool2d(7, stride=7), 105 | ) 106 | self.lda1_fc = nn.Linear(16 * 64, self.lda_out_channels) 107 | 108 | self.lda2_pool = nn.Sequential( 109 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False), 110 | nn.AvgPool2d(7, stride=7), 111 | ) 112 | self.lda2_fc = nn.Linear(32 * 16, self.lda_out_channels) 113 | 114 | self.lda3_pool = nn.Sequential( 115 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False), 116 | nn.AvgPool2d(7, stride=7), 117 | ) 118 | self.lda3_fc = nn.Linear(64 * 4, self.lda_out_channels) 119 | 120 | self.lda4_pool = nn.AvgPool2d(7, 
stride=7) 121 | self.lda4_fc = nn.Linear(2048, self.lda_out_channels) 122 | 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 126 | m.weight.data.normal_(0, math.sqrt(2. / n)) 127 | elif isinstance(m, nn.BatchNorm2d): 128 | m.weight.data.fill_(1) 129 | m.bias.data.zero_() 130 | if self.pretrained: 131 | self.load_resnet50_backbone() 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | nn.Conv2d(self.inplanes, planes * block.expansion, 138 | kernel_size=1, stride=stride, bias=False), 139 | nn.BatchNorm2d(planes * block.expansion), 140 | ) 141 | 142 | layers = [] 143 | layers.append(block(self.inplanes, planes, stride, downsample)) 144 | self.inplanes = planes * block.expansion 145 | for i in range(1, blocks): 146 | layers.append(block(self.inplanes, planes)) 147 | 148 | return nn.Sequential(*layers) 149 | 150 | def cal_params(self): 151 | params = list(self.parameters()) 152 | k = 0 153 | for i in params: 154 | l = 1 155 | for j in i.size(): 156 | l *= j 157 | k = k + l 158 | print("Total parameters is :" + str(k)) 159 | 160 | def forward(self, x): 161 | 162 | x = self.conv1(x) 163 | x = self.bn1(x) 164 | x = self.relu(x) 165 | x = self.maxpool(x) 166 | x = self.layer1(x) 167 | # lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1)) 168 | lda_1 = x #[b, 256, 56, 56] 169 | x = self.layer2(x) 170 | # lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1)) 171 | lda_2 = x #[b, 512, 28, 28] 172 | x = self.layer3(x) 173 | # lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1)) 174 | lda_3 = x #[b, 1024, 14, 14] 175 | x = self.layer4(x) 176 | # lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1)) 177 | lda_4 = x #[b, 2048, 7, 7] 178 | # return x #[b, 2048, 7, 7] 179 | return [lda_1, lda_2, lda_3, lda_4] 180 | 181 | def load_resnet50_backbone(self): 182 | """Constructs a ResNet-50 model_hyper. 
183 | Args: 184 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet 185 | """ 186 | save_model = model_zoo.load_url(model_urls['resnet50']) 187 | model_dict = self.state_dict() 188 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 189 | model_dict.update(state_dict) 190 | self.load_state_dict(model_dict) 191 | 192 | def weights_init_xavier(m): 193 | classname = m.__class__.__name__ 194 | # print(classname) 195 | # if isinstance(m, nn.Conv2d): 196 | if classname.find('Conv') != -1: 197 | init.kaiming_normal_(m.weight.data) 198 | elif classname.find('Linear') != -1: 199 | init.kaiming_normal_(m.weight.data) 200 | elif classname.find('BatchNorm2d') != -1: 201 | init.uniform_(m.weight.data, 1.0, 0.02) 202 | init.constant_(m.bias.data, 0.0) 203 | 204 | #MLP 205 | class PreNormResidual(nn.Module): 206 | def __init__(self, dim, fn): 207 | super().__init__() 208 | self.fn = fn 209 | self.norm = nn.LayerNorm(dim) 210 | 211 | def forward(self, x): 212 | return self.fn(self.norm(x)) + x 213 | 214 | class MLPMixer(nn.Module): 215 | def __init__(self, image_size, channels, patch_size, dim, depth, expansion_factor = 4, dropout = 0.): 216 | super().__init__() 217 | assert (image_size % patch_size) == 0, 'image must be divisible by patch size' 218 | self.num_patches = (image_size // patch_size) ** 2 219 | self.chan_first, self.chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear 220 | 221 | self.mlp = nn.Sequential( 222 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 223 | nn.Linear((patch_size ** 2) * channels, dim), 224 | *[nn.Sequential( 225 | PreNormResidual(dim, self.FeedForward(self.num_patches, expansion_factor, dropout, self.chan_first)), 226 | PreNormResidual(dim, self.FeedForward(dim, expansion_factor, dropout, self.chan_last)) 227 | ) for _ in range(depth)], 228 | nn.LayerNorm(dim), 229 | Reduce('b n c -> b c', 'mean'), 230 | # nn.Linear(dim, num_classes) 231 | ) 232 | # print(self.mlp) 233 | 234 | def FeedForward(self, dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): 235 | return nn.Sequential( 236 | dense(dim, dim * expansion_factor), 237 | nn.GELU(), 238 | nn.Dropout(dropout), 239 | dense(dim * expansion_factor, dim), 240 | nn.Dropout(dropout) 241 | ) 242 | 243 | def forward(self, x, distillation_layer_num=None): 244 | # [3, 256*self_patch_num, 7, 7] 245 | mlp_inner_feature = [] 246 | layer_idx = 0 247 | for mlp_single in self.mlp: 248 | x = mlp_single(x) 249 | mlp_inner_feature.append(x) 250 | if distillation_layer_num: 251 | return x, mlp_inner_feature[-distillation_layer_num-2:-2] 252 | else: 253 | return x, mlp_inner_feature 254 | 255 | 256 | def initialize_weights(net_l, scale=1): 257 | if not isinstance(net_l, list): 258 | net_l = [net_l] 259 | for net in net_l: 260 | for m in net.modules(): 261 | if isinstance(m, nn.Conv2d): 262 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 263 | m.weight.data *= scale # for residual block 264 | if m.bias is not None: 265 | m.bias.data.zero_() 266 | elif isinstance(m, nn.Linear): 267 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 268 | m.weight.data *= scale 269 | if m.bias is not None: 270 | m.bias.data.zero_() 271 | elif isinstance(m, nn.BatchNorm2d): 272 | init.constant_(m.weight, 1) 273 | init.constant_(m.bias.data, 0.0) 274 | else: 275 | pass 276 | 277 | class RegressionFCNet(nn.Module): 278 | """ 279 | Target network for quality prediction. 
280 | """ 281 | def __init__(self): 282 | super(RegressionFCNet, self).__init__() 283 | self.target_in_size=512 284 | self.target_fc1_size=256 285 | 286 | self.sigmoid = nn.Sigmoid() 287 | self.l1 = nn.Linear(self.target_in_size, self.target_fc1_size) 288 | self.relu1 = nn.PReLU() 289 | self.drop1 = nn.Dropout(0.5) 290 | self.bn1 = nn.BatchNorm1d(256) 291 | 292 | self.l2 = nn.Linear(self.target_fc1_size, 1) 293 | 294 | 295 | def forward(self, x): 296 | q = self.l1(x) 297 | q = self.l2(q).squeeze() 298 | return q 299 | 300 | class DistillationIQANet(nn.Module): 301 | def __init__(self, self_patch_num=10, lda_channel=64, encode_decode_channel=64, MLP_depth=9, distillation_layer=9): 302 | super(DistillationIQANet, self).__init__() 303 | 304 | self.self_patch_num = self_patch_num 305 | self.lda_channel = lda_channel 306 | self.encode_decode_channel = encode_decode_channel 307 | self.MLP_depth = MLP_depth 308 | self.distillation_layer_num = distillation_layer 309 | 310 | self.feature_extractor = ResNetBackbone() 311 | for param in self.feature_extractor.parameters(): 312 | param.requires_grad = False 313 | 314 | self.lda1_process = nn.Sequential(nn.Conv2d(256, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7))) 315 | self.lda2_process = nn.Sequential(nn.Conv2d(512, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7))) 316 | self.lda3_process = nn.Sequential(nn.Conv2d(1024, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7))) 317 | self.lda4_process = nn.Sequential(nn.Conv2d(2048, self.lda_channel, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d((7, 7))) 318 | 319 | self.lda_process = [self.lda1_process, self.lda2_process, self.lda3_process, self.lda4_process] 320 | 321 | self.MLP_encoder_diff = MLPMixer(image_size = 7, channels = self.self_patch_num*self.lda_channel*4, patch_size = 1, dim = self.encode_decode_channel*4, depth = self.MLP_depth*2) 322 | self.MLP_encoder_lq = MLPMixer(image_size = 7, channels = self.self_patch_num*self.lda_channel*4, patch_size = 1, dim = self.encode_decode_channel*4, depth = self.MLP_depth) 323 | 324 | self.regressor = RegressionFCNet() 325 | 326 | initialize_weights(self.MLP_encoder_diff,0.1) 327 | initialize_weights(self.MLP_encoder_lq,0.1) 328 | initialize_weights(self.regressor,0.1) 329 | 330 | initialize_weights(self.lda1_process,0.1) 331 | initialize_weights(self.lda2_process,0.1) 332 | initialize_weights(self.lda3_process,0.1) 333 | initialize_weights(self.lda4_process,0.1) 334 | 335 | def forward(self, LQ_patches, refHQ_patches): 336 | device = LQ_patches.device 337 | b, p, c, h, w = LQ_patches.shape 338 | LQ_patches_reshape = LQ_patches.view(b*p, c, h, w) 339 | refHQ_patches_reshape = refHQ_patches.view(b*p, c, h, w) 340 | 341 | # [b*p, 256, 56, 56], [b*p, 512, 28, 28], [b*p, 1024, 14, 14], [b*p, 2048, 7, 7] 342 | lq_lda_features = self.feature_extractor(LQ_patches_reshape) 343 | refHQ_lda_features = self.feature_extractor(refHQ_patches_reshape) 344 | 345 | # encode_diff_feature, encode_lq_feature, feature = [], [], [] 346 | multi_scale_diff_feature, multi_scale_lq_feature, feature = [], [], [] 347 | for lq_lda_feature, refHQ_lda_feature, lda_process in zip(lq_lda_features, refHQ_lda_features, self.lda_process): 348 | # [b, p, 64, 7, 7] 349 | lq_lda_feature = lda_process(lq_lda_feature).view(b, -1, 7, 7) 350 | refHQ_lda_feature = lda_process(refHQ_lda_feature).view(b, -1, 7, 7) 351 | diff_lda_feature = refHQ_lda_feature - lq_lda_feature 352 | 353 | 
354 | multi_scale_diff_feature.append(diff_lda_feature) 355 | multi_scale_lq_feature.append(lq_lda_feature) 356 | 357 | multi_scale_lq_feature = torch.cat(multi_scale_lq_feature, 1).to(device) 358 | multi_scale_diff_feature = torch.cat(multi_scale_diff_feature, 1).to(device) 359 | encode_lq_feature, encode_lq_inner_feature = self.MLP_encoder_lq(multi_scale_lq_feature, self.distillation_layer_num) 360 | encode_diff_feature, encode_diff_inner_feature = self.MLP_encoder_diff(multi_scale_diff_feature, self.distillation_layer_num) 361 | feature = torch.cat((encode_lq_feature, encode_diff_feature), 1) 362 | 363 | pred = self.regressor(feature) 364 | return encode_diff_inner_feature, encode_lq_inner_feature, pred 365 | 366 | def _load_state_dict(self, state_dict, strict=True): 367 | own_state = self.state_dict() 368 | for name, param in state_dict.items(): 369 | if name in own_state: 370 | if isinstance(param, nn.Parameter): 371 | param = param.data 372 | try: 373 | own_state[name].copy_(param) 374 | except Exception: 375 | if name.find('tail') >= 0: 376 | print('Replace pre-trained upsampler to new one...') 377 | else: 378 | raise RuntimeError('While copying the parameter named {}, ' 379 | 'whose dimensions in the model are {} and ' 380 | 'whose dimensions in the checkpoint are {}.' 381 | .format(name, own_state[name].size(), param.size())) 382 | elif strict: 383 | if name.find('tail') == -1: 384 | raise KeyError('unexpected key "{}" in state_dict' 385 | .format(name)) 386 | 387 | if strict: 388 | missing = set(own_state.keys()) - set(state_dict.keys()) 389 | if len(missing) > 0: 390 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 391 | 392 | 393 | 394 | if __name__ == "__main__": 395 | net = ResNetBackbone() 396 | x = torch.rand((1,3,224,224)) 397 | y = net(x) 398 | print(y.shape) 399 | 400 | model = MLPMixer(image_size = 7, channels = 1280, patch_size = 1, dim = 512, depth = 12) 401 | img = torch.randn(96, 256, 7, 7) 402 | pred = model(img) # (1, 1000) 403 | print(pred.shape) 404 | 405 | m = DistillationIQANet() 406 | lq = torch.rand((3,10,3,224,224)) 407 | hq = torch.rand((3,10,3,224,224)) 408 | encode_diff_feature, encode_lq_feature, pred = m(lq, hq) 409 | print(pred.shape) 410 | -------------------------------------------------------------------------------- /folders/folders_LQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import scipy.io 8 | import numpy as np 9 | import csv 10 | import random 11 | from openpyxl import load_workbook 12 | import cv2 13 | from torchvision import transforms 14 | 15 | class Kadid10kFolder(data.Dataset): 16 | 17 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False): 18 | self.patch_size = patch_size 19 | self.self_patch_num = self_patch_num 20 | self.use_L = use_L 21 | self.transform = transform 22 | 23 | imgname = [] 24 | refimgname = [] 25 | mos_all = [] 26 | csv_file = os.path.join(root, 'dmos.csv') 27 | with open(csv_file) as f: 28 | reader = csv.DictReader(f) 29 | for row in reader: 30 | imgname.append(row['dist_img']) 31 | refimgname.append(row['ref_img']) 32 | mos = np.array(float(row['dmos'])).astype(np.float32) 33 | mos_all.append(mos) 34 | 35 | sample = [] 36 | for i, item in enumerate(index): 37 | for aug in range(patch_num): 38 | sample.append((os.path.join(root, 'images', imgname[item]),os.path.join(root, 
'images', refimgname[item]), mos_all[item])) 39 | 40 | self.samples = sample 41 | 42 | 43 | def __getitem__(self, index): 44 | """ 45 | Args: 46 | index (int): Index 47 | Returns: 48 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 49 | """ 50 | LQ_path, HQ_path, target = self.samples[index] 51 | 52 | LQ = pil_loader(LQ_path, self.use_L) 53 | LQ_patches = [] 54 | for _ in range(self.self_patch_num): 55 | LQ_patch = self.transform(LQ) 56 | LQ_patches.append(LQ_patch.unsqueeze(0)) 57 | #[self_patch_num, 3, patch_size, patch_size] 58 | LQ_patches = torch.cat(LQ_patches, 0) 59 | 60 | return LQ_patches, target 61 | 62 | def __len__(self): 63 | length = len(self.samples) 64 | return length 65 | 66 | class LIVEFolder(data.Dataset): 67 | 68 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False): 69 | self.patch_size =patch_size 70 | self.self_patch_num = self_patch_num 71 | self.transform = transform 72 | self.use_L = use_L 73 | 74 | refpath = os.path.join(root, 'refimgs') 75 | refname = getFileName(refpath, '.bmp') 76 | 77 | jp2kroot = os.path.join(root, 'jp2k') 78 | jp2kname = self.getDistortionTypeFileName(jp2kroot, 227) 79 | 80 | jpegroot = os.path.join(root, 'jpeg') 81 | jpegname = self.getDistortionTypeFileName(jpegroot, 233) 82 | 83 | wnroot = os.path.join(root, 'wn') 84 | wnname = self.getDistortionTypeFileName(wnroot, 174) 85 | 86 | gblurroot = os.path.join(root, 'gblur') 87 | gblurname = self.getDistortionTypeFileName(gblurroot, 174) 88 | 89 | fastfadingroot = os.path.join(root, 'fastfading') 90 | fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174) 91 | 92 | imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname 93 | 94 | dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat')) 95 | labels = dmos['dmos_new'].astype(np.float32) 96 | 97 | orgs = dmos['orgs'] 98 | refnames_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat')) 99 | refnames_all = refnames_all['refnames_all'] 100 | 101 | sample = [] 102 | for i in range(0, len(index)): 103 | train_sel = (refname[index[i]] == refnames_all) 104 | train_sel = train_sel * ~orgs.astype(np.bool_) 105 | train_sel = np.where(train_sel == True) 106 | train_sel = train_sel[1].tolist() 107 | for j, item in enumerate(train_sel): 108 | for aug in range(patch_num): 109 | LQ_path = imgpath[item] 110 | HQ_path = os.path.join(root, 'refimgs', refnames_all[0][item][0]) 111 | label = labels[0][item] 112 | sample.append((LQ_path, HQ_path, label)) 113 | 114 | self.samples = sample 115 | 116 | 117 | def __getitem__(self, index): 118 | """ 119 | Args: 120 | index (int): Index 121 | Returns: 122 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
123 | """ 124 | LQ_path, HQ_path, target = self.samples[index] 125 | LQ = pil_loader(LQ_path, self.use_L) 126 | LQ_patches = [] 127 | for _ in range(self.self_patch_num): 128 | LQ_patch = self.transform(LQ) 129 | LQ_patches.append(LQ_patch.unsqueeze(0)) 130 | LQ_patches = torch.cat(LQ_patches, 0) 131 | 132 | return LQ_patches, target 133 | 134 | def __len__(self): 135 | length = len(self.samples) 136 | return length 137 | 138 | def getDistortionTypeFileName(self, path, num): 139 | filename = [] 140 | index = 1 141 | for i in range(0, num): 142 | name = '{:0>3d}{}'.format(index, '.bmp') 143 | filename.append(os.path.join(path, name)) 144 | index = index + 1 145 | return filename 146 | 147 | class CSIQFolder(data.Dataset): 148 | 149 | def __init__(self, root, index, transform, patch_num, patch_size =224, self_patch_num=1, use_L=False): 150 | self.patch_size =patch_size 151 | self.self_patch_num = self_patch_num 152 | self.transform = transform 153 | self.use_L = use_L 154 | 155 | refpath = os.path.join(root, 'src_imgs') 156 | refname = getFileName(refpath,'.png') 157 | txtpath = os.path.join(root, 'csiq_label.txt') 158 | fh = open(txtpath, 'r') 159 | imgnames = [] 160 | target = [] 161 | refnames_all = [] 162 | for line in fh: 163 | line = line.split('\n') 164 | words = line[0].split() 165 | imgnames.append((words[0])) 166 | target.append(words[1]) 167 | ref_temp = words[0].split(".") 168 | refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) 169 | 170 | labels = np.array(target).astype(np.float32) 171 | refnames_all = np.array(refnames_all) 172 | 173 | sample = [] 174 | 175 | for i, item in enumerate(index): 176 | train_sel = (refname[index[i]] == refnames_all) 177 | train_sel = np.where(train_sel == True) 178 | train_sel = train_sel[0].tolist() 179 | for j, item in enumerate(train_sel): 180 | for aug in range(patch_num): 181 | LQ_path = os.path.join(root, 'dst_imgs_all', imgnames[item]) 182 | HQ_path = os.path.join(root, 'src_imgs', refnames_all[item]) 183 | label = labels[item] 184 | sample.append((LQ_path, HQ_path, label)) 185 | self.samples = sample 186 | 187 | 188 | def __getitem__(self, index): 189 | """ 190 | Args: 191 | index (int): Index 192 | Returns: 193 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
194 | """ 195 | LQ_path, HQ_path, target = self.samples[index] 196 | LQ = pil_loader(LQ_path, self.use_L) 197 | LQ_patches = [] 198 | for _ in range(self.self_patch_num): 199 | LQ_patch = self.transform(LQ) 200 | LQ_patches.append(LQ_patch.unsqueeze(0)) 201 | LQ_patches = torch.cat(LQ_patches, 0) 202 | 203 | return LQ_patches, target 204 | 205 | def __len__(self): 206 | length = len(self.samples) 207 | return length 208 | 209 | class TID2013Folder(data.Dataset): 210 | 211 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False): 212 | self.patch_size =patch_size 213 | self.self_patch_num = self_patch_num 214 | self.transform = transform 215 | self.use_L = use_L 216 | 217 | refpath = os.path.join(root, 'reference_images') 218 | refname = self._getTIDFileName(refpath,'.bmp.BMP') 219 | txtpath = os.path.join(root, 'mos_with_names.txt') 220 | fh = open(txtpath, 'r') 221 | imgnames = [] 222 | target = [] 223 | refnames_all = [] 224 | for line in fh: 225 | line = line.split('\n') 226 | words = line[0].split() 227 | imgnames.append((words[1])) 228 | target.append(words[0]) 229 | ref_temp = words[1].split("_") 230 | refnames_all.append(ref_temp[0][1:]) 231 | labels = np.array(target).astype(np.float32) 232 | refnames_all = np.array(refnames_all) 233 | 234 | sample = [] 235 | for i, item in enumerate(index): 236 | train_sel = (refname[index[i]] == refnames_all) 237 | train_sel = np.where(train_sel == True) 238 | train_sel = train_sel[0].tolist() 239 | for j, item in enumerate(train_sel): 240 | for aug in range(patch_num): 241 | LQ_path = os.path.join(root, 'distorted_images', imgnames[item]) 242 | HQ_name = 'I' + imgnames[item].split("_")[0][1:] + '.BMP' 243 | HQ_path = os.path.join(refpath, HQ_name) 244 | label = labels[item] 245 | sample.append((LQ_path, HQ_path, label)) 246 | self.samples = sample 247 | 248 | 249 | def _getTIDFileName(self, path, suffix): 250 | filename = [] 251 | f_list = os.listdir(path) 252 | for i in f_list: 253 | if suffix.find(os.path.splitext(i)[1]) != -1: 254 | filename.append(i[1:3]) 255 | return filename 256 | 257 | def __getitem__(self, index): 258 | """ 259 | Args: 260 | index (int): Index 261 | Returns: 262 | tuple: (LQ, HQ, target) where target is IQA values of the target LQ. 
263 | """ 264 | LQ_path, HQ_path, target = self.samples[index] 265 | LQ = pil_loader(LQ_path, self.use_L) 266 | LQ_patches = [] 267 | for _ in range(self.self_patch_num): 268 | LQ_patch = self.transform(LQ) 269 | LQ_patches.append(LQ_patch.unsqueeze(0)) 270 | LQ_patches = torch.cat(LQ_patches, 0) 271 | 272 | return LQ_patches, target 273 | 274 | def __len__(self): 275 | length = len(self.samples) 276 | return length 277 | 278 | class LIVEChallengeFolder(data.Dataset): 279 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False): 280 | self.patch_size =patch_size 281 | self.self_patch_num = self_patch_num 282 | self.transform = transform 283 | self.use_L = use_L 284 | 285 | LQ_pathes = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat')) 286 | LQ_pathes = LQ_pathes['AllImages_release'] 287 | LQ_pathes = LQ_pathes[7:1169] 288 | mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat')) 289 | labels = mos['AllMOS_release'].astype(np.float32) 290 | labels = labels[0][7:1169] 291 | 292 | sample = [] 293 | for _, item in enumerate(index): 294 | for _ in range(patch_num): 295 | sample.append((os.path.join(root, 'Images', LQ_pathes[item][0][0]), labels[item])) 296 | self.samples = sample 297 | 298 | def __getitem__(self, index): 299 | """ 300 | Args: 301 | index (int): Index 302 | Returns: 303 | tuple: (LQ, target) where target is IQA values of the target LQ. 304 | """ 305 | LQ_path, target = self.samples[index] 306 | LQ = pil_loader(LQ_path, self.use_L) 307 | LQ_patches = [] 308 | for _ in range(self.self_patch_num): 309 | LQ_patch = self.transform(LQ) 310 | 311 | LQ_patches.append(LQ_patch.unsqueeze(0)) 312 | #[self_patch_num, 3, patch_size, patch_size] 313 | LQ_patches = torch.cat(LQ_patches, 0) 314 | 315 | return LQ_patches, target 316 | 317 | def __len__(self): 318 | length = len(self.samples) 319 | return length 320 | 321 | class BIDChallengeFolder(data.Dataset): 322 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False): 323 | self.patch_size =patch_size 324 | self.self_patch_num = self_patch_num 325 | self.transform = transform 326 | self.use_L = use_L 327 | 328 | LQ_pathes = [] 329 | labels = [] 330 | 331 | xls_file = os.path.join(root, 'DatabaseGrades.xlsx') 332 | workbook = load_workbook(xls_file) 333 | booksheet = workbook.active 334 | rows = booksheet.rows 335 | count = 1 336 | for _ in rows: 337 | count += 1 338 | img_num = (booksheet.cell(row=count, column=1).value) 339 | img_name = "DatabaseImage%04d.JPG" % (img_num) 340 | LQ_pathes.append(img_name) 341 | mos = (booksheet.cell(row=count, column=2).value) 342 | mos = np.array(mos) 343 | mos = mos.astype(np.float32) 344 | labels.append(mos) 345 | if count == 587: 346 | break 347 | 348 | sample = [] 349 | for _, item in enumerate(index): 350 | for _ in range(patch_num): 351 | sample.append((os.path.join(root, LQ_pathes[item]), labels[item])) 352 | self.samples = sample 353 | 354 | def __getitem__(self, index): 355 | """ 356 | Args: 357 | index (int): Index 358 | Returns: 359 | tuple: (LQ, target) where target is IQA values of the target LQ. 
360 | """ 361 | LQ_path, target = self.samples[index] 362 | LQ = pil_loader(LQ_path, self.use_L) 363 | LQ_patches = [] 364 | for _ in range(self.self_patch_num): 365 | LQ_patch = self.transform(LQ) 366 | 367 | LQ_patches.append(LQ_patch.unsqueeze(0)) 368 | #[self_patch_num, 3, patch_size, patch_size] 369 | LQ_patches = torch.cat(LQ_patches, 0) 370 | 371 | return LQ_patches, target 372 | 373 | def __len__(self): 374 | length = len(self.samples) 375 | return length 376 | 377 | class Koniq_10kFolder(data.Dataset): 378 | def __init__(self, root, index, transform, patch_num, patch_size=224, self_patch_num=1, use_L=False): 379 | self.patch_size =patch_size 380 | self.self_patch_num = self_patch_num 381 | self.transform = transform 382 | self.use_L = use_L 383 | 384 | imgname = [] 385 | mos_all = [] 386 | csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv') 387 | with open(csv_file) as f: 388 | reader = csv.DictReader(f) 389 | for row in reader: 390 | imgname.append(row['image_name']) 391 | mos = np.array(float(row['MOS_zscore'])).astype(np.float32) 392 | mos_all.append(mos) 393 | 394 | sample = [] 395 | for _, item in enumerate(index): 396 | for _ in range(patch_num): 397 | sample.append((os.path.join(root, '1024x768', imgname[item]), mos_all[item])) 398 | 399 | self.samples = sample 400 | 401 | def __getitem__(self, index): 402 | """ 403 | Args: 404 | index (int): Index 405 | Returns: 406 | tuple: (LQ, target) where target is IQA values of the target LQ. 407 | """ 408 | LQ_path, target = self.samples[index] 409 | LQ = pil_loader(LQ_path, self.use_L) 410 | LQ_patches = [] 411 | for _ in range(self.self_patch_num): 412 | LQ_patch = self.transform(LQ) 413 | 414 | LQ_patches.append(LQ_patch.unsqueeze(0)) 415 | #[self_patch_num, 3, patch_size, patch_size] 416 | LQ_patches = torch.cat(LQ_patches, 0) 417 | 418 | return LQ_patches, target 419 | 420 | def __len__(self): 421 | length = len(self.samples) 422 | return length 423 | 424 | 425 | def getFileName(path, suffix): 426 | filename = [] 427 | f_list = os.listdir(path) 428 | for i in f_list: 429 | if os.path.splitext(i)[1] == suffix: 430 | filename.append(i) 431 | return filename 432 | 433 | 434 | def getPairRandomPatch(img1, img2, crop_size=512): 435 | (iw,ih) = img1.size 436 | # print(ih,iw) 437 | 438 | ip = int(crop_size) 439 | 440 | ix = random.randrange(0, iw - ip + 1) 441 | iy = random.randrange(0, ih - ip + 1) 442 | 443 | 444 | img1_patch = img1.crop((ix, iy, ix+ip, iy+ip))#左上右下 445 | img2_patch = img2.crop((ix, iy, ix+ip, iy+ip))#左上右下 446 | 447 | return img1_patch, img2_patch 448 | 449 | def getPairAugment(img1, img2, hflip=True, vflip=True, rot=True): 450 | hflip = hflip and random.random() < 0.5 451 | vflip = vflip and random.random() < 0.5 452 | rot180 = rot and random.random() < 0.5 453 | 454 | if hflip: 455 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM) 456 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM) 457 | if vflip: 458 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT) 459 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT) 460 | if rot180: 461 | img1 = img1.transpose(Image.ROTATE_180) 462 | img2 = img2.transpose(Image.ROTATE_180) 463 | 464 | return img1, img2 465 | 466 | 467 | def getSelfPatch(img, patch_size, patch_num, is_random=True): 468 | (iw,ih) = img.size 469 | patches = [] 470 | for i in range(patch_num): 471 | if is_random: 472 | ix = random.randrange(0, iw - patch_size + 1) 473 | iy = random.randrange(0, ih - patch_size + 1) 474 | else:ix,iy=(iw - patch_size + 1)//2,(ih - patch_size + 1)//2 475 | 476 | # patch = 
img[iy:iy + lr_size, ix:ix + lr_size, :]#上下左右 477 | patch = img.crop((ix, iy, ix+patch_size, iy+patch_size))#左上右下 478 | patches.append(patch) 479 | 480 | return patches 481 | 482 | 483 | def pil_loader(path, use_L=False): 484 | if use_L: 485 | return Image.open(path).convert('L') 486 | else: 487 | return Image.open(path).convert('RGB') 488 | -------------------------------------------------------------------------------- /models/IQT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._C import device 3 | import torch.nn as nn 4 | from torch.nn import Dropout, Softmax, Linear, LayerNorm 5 | import torch.utils.model_zoo as model_zoo 6 | import math 7 | import copy 8 | 9 | __all__ = ['InceptionResNetV2', 'inceptionresnetv2'] 10 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu} 11 | hidden_size = 256 12 | 13 | pretrained_settings = { 14 | 'inceptionresnetv2': { 15 | 'imagenet': { 16 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', 17 | 'input_space': 'RGB', 18 | 'input_size': [3, 299, 299], 19 | 'input_range': [0, 1], 20 | 'mean': [0.5, 0.5, 0.5], 21 | 'std': [0.5, 0.5, 0.5], 22 | 'num_classes': 1000 23 | }, 24 | 'imagenet+background': { 25 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', 26 | 'input_space': 'RGB', 27 | 'input_size': [3, 299, 299], 28 | 'input_range': [0, 1], 29 | 'mean': [0.5, 0.5, 0.5], 30 | 'std': [0.5, 0.5, 0.5], 31 | 'num_classes': 1001 32 | } 33 | } 34 | } 35 | 36 | 37 | class BasicConv2d(nn.Module): 38 | 39 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 40 | super(BasicConv2d, self).__init__() 41 | self.conv = nn.Conv2d(in_planes, out_planes, 42 | kernel_size=kernel_size, stride=stride, 43 | padding=padding, bias=False) # verify bias false 44 | self.bn = nn.BatchNorm2d(out_planes, 45 | eps=0.001, # value found in tensorflow 46 | momentum=0.1, # default pytorch value 47 | affine=True) 48 | self.relu = nn.ReLU(inplace=False) 49 | 50 | def forward(self, x): 51 | x = self.conv(x) 52 | x = self.bn(x) 53 | x = self.relu(x) 54 | return x 55 | 56 | 57 | class Mixed_5b(nn.Module): 58 | 59 | def __init__(self): 60 | super(Mixed_5b, self).__init__() 61 | 62 | self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) 63 | 64 | self.branch1 = nn.Sequential( 65 | BasicConv2d(192, 48, kernel_size=1, stride=1), 66 | BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) 67 | ) 68 | 69 | self.branch2 = nn.Sequential( 70 | BasicConv2d(192, 64, kernel_size=1, stride=1), 71 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), 72 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) 73 | ) 74 | 75 | self.branch3 = nn.Sequential( 76 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 77 | BasicConv2d(192, 64, kernel_size=1, stride=1) 78 | ) 79 | 80 | def forward(self, x): 81 | x0 = self.branch0(x) 82 | x1 = self.branch1(x) 83 | x2 = self.branch2(x) 84 | x3 = self.branch3(x) 85 | out = torch.cat((x0, x1, x2, x3), 1) 86 | return out 87 | 88 | 89 | class Block35(nn.Module): 90 | 91 | def __init__(self, scale=1.0): 92 | super(Block35, self).__init__() 93 | 94 | self.scale = scale 95 | 96 | self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) 97 | 98 | self.branch1 = nn.Sequential( 99 | BasicConv2d(320, 32, kernel_size=1, stride=1), 100 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 101 | ) 102 | 103 | self.branch2 = nn.Sequential( 104 | BasicConv2d(320, 
32, kernel_size=1, stride=1), 105 | BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), 106 | BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) 107 | ) 108 | 109 | self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) 110 | self.relu = nn.ReLU(inplace=False) 111 | 112 | def forward(self, x): 113 | x0 = self.branch0(x) 114 | x1 = self.branch1(x) 115 | x2 = self.branch2(x) 116 | out = torch.cat((x0, x1, x2), 1) 117 | out = self.conv2d(out) 118 | out = out * self.scale + x 119 | out = self.relu(out) 120 | return out 121 | 122 | 123 | class Mixed_6a(nn.Module): 124 | 125 | def __init__(self): 126 | super(Mixed_6a, self).__init__() 127 | 128 | self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) 129 | 130 | self.branch1 = nn.Sequential( 131 | BasicConv2d(320, 256, kernel_size=1, stride=1), 132 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 133 | BasicConv2d(256, 384, kernel_size=3, stride=2) 134 | ) 135 | 136 | self.branch2 = nn.MaxPool2d(3, stride=2) 137 | 138 | def forward(self, x): 139 | x0 = self.branch0(x) 140 | x1 = self.branch1(x) 141 | x2 = self.branch2(x) 142 | out = torch.cat((x0, x1, x2), 1) 143 | return out 144 | 145 | 146 | class Block17(nn.Module): 147 | 148 | def __init__(self, scale=1.0): 149 | super(Block17, self).__init__() 150 | 151 | self.scale = scale 152 | 153 | self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) 154 | 155 | self.branch1 = nn.Sequential( 156 | BasicConv2d(1088, 128, kernel_size=1, stride=1), 157 | BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)), 158 | BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)) 159 | ) 160 | 161 | self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) 162 | self.relu = nn.ReLU(inplace=False) 163 | 164 | def forward(self, x): 165 | x0 = self.branch0(x) 166 | x1 = self.branch1(x) 167 | out = torch.cat((x0, x1), 1) 168 | out = self.conv2d(out) 169 | out = out * self.scale + x 170 | out = self.relu(out) 171 | return out 172 | 173 | 174 | class Mixed_7a(nn.Module): 175 | 176 | def __init__(self): 177 | super(Mixed_7a, self).__init__() 178 | 179 | self.branch0 = nn.Sequential( 180 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 181 | BasicConv2d(256, 384, kernel_size=3, stride=2) 182 | ) 183 | 184 | self.branch1 = nn.Sequential( 185 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 186 | BasicConv2d(256, 288, kernel_size=3, stride=2) 187 | ) 188 | 189 | self.branch2 = nn.Sequential( 190 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 191 | BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), 192 | BasicConv2d(288, 320, kernel_size=3, stride=2) 193 | ) 194 | 195 | self.branch3 = nn.MaxPool2d(3, stride=2) 196 | 197 | def forward(self, x): 198 | x0 = self.branch0(x) 199 | x1 = self.branch1(x) 200 | x2 = self.branch2(x) 201 | x3 = self.branch3(x) 202 | out = torch.cat((x0, x1, x2, x3), 1) 203 | return out 204 | 205 | 206 | class Block8(nn.Module): 207 | 208 | def __init__(self, scale=1.0, noReLU=False): 209 | super(Block8, self).__init__() 210 | 211 | self.scale = scale 212 | self.noReLU = noReLU 213 | 214 | self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) 215 | 216 | self.branch1 = nn.Sequential( 217 | BasicConv2d(2080, 192, kernel_size=1, stride=1), 218 | BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), 219 | BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 220 | ) 221 | 222 | self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) 223 | if not self.noReLU: 224 | self.relu = 
nn.ReLU(inplace=False) 225 | 226 | def forward(self, x): 227 | x0 = self.branch0(x) 228 | x1 = self.branch1(x) 229 | out = torch.cat((x0, x1), 1) 230 | out = self.conv2d(out) 231 | out = out * self.scale + x 232 | if not self.noReLU: 233 | out = self.relu(out) 234 | return out 235 | 236 | 237 | class InceptionResNetV2(nn.Module): 238 | 239 | def __init__(self, num_classes=1001): 240 | super(InceptionResNetV2, self).__init__() 241 | # Special attributs 242 | self.input_space = None 243 | self.input_size = (299, 299, 3) 244 | self.mean = None 245 | self.std = None 246 | # Modules 247 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 248 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 249 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 250 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 251 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 252 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 253 | self.maxpool_5a = nn.MaxPool2d(3, stride=2) 254 | self.mixed_5b = Mixed_5b() 255 | self.repeat = nn.Sequential( 256 | Block35(scale=0.17), 257 | Block35(scale=0.17), 258 | Block35(scale=0.17), 259 | Block35(scale=0.17), 260 | Block35(scale=0.17), 261 | Block35(scale=0.17), 262 | Block35(scale=0.17), 263 | Block35(scale=0.17), 264 | Block35(scale=0.17), 265 | Block35(scale=0.17) 266 | ) 267 | self.mixed_6a = Mixed_6a() 268 | self.repeat_1 = nn.Sequential( 269 | Block17(scale=0.10), 270 | Block17(scale=0.10), 271 | Block17(scale=0.10), 272 | Block17(scale=0.10), 273 | Block17(scale=0.10), 274 | Block17(scale=0.10), 275 | Block17(scale=0.10), 276 | Block17(scale=0.10), 277 | Block17(scale=0.10), 278 | Block17(scale=0.10), 279 | Block17(scale=0.10), 280 | Block17(scale=0.10), 281 | Block17(scale=0.10), 282 | Block17(scale=0.10), 283 | Block17(scale=0.10), 284 | Block17(scale=0.10), 285 | Block17(scale=0.10), 286 | Block17(scale=0.10), 287 | Block17(scale=0.10), 288 | Block17(scale=0.10) 289 | ) 290 | self.mixed_7a = Mixed_7a() 291 | self.repeat_2 = nn.Sequential( 292 | Block8(scale=0.20), 293 | Block8(scale=0.20), 294 | Block8(scale=0.20), 295 | Block8(scale=0.20), 296 | Block8(scale=0.20), 297 | Block8(scale=0.20), 298 | Block8(scale=0.20), 299 | Block8(scale=0.20), 300 | Block8(scale=0.20) 301 | ) 302 | self.block8 = Block8(noReLU=True) 303 | self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) 304 | self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False) 305 | self.last_linear = nn.Linear(1536, num_classes) 306 | 307 | def features(self, input): 308 | device = input.device 309 | concat_x = [] 310 | x = self.conv2d_1a(input) 311 | x = self.conv2d_2a(x) 312 | x = self.conv2d_2b(x) 313 | x = self.maxpool_3a(x) 314 | x = self.conv2d_3b(x) 315 | x = self.conv2d_4a(x) 316 | x = self.maxpool_5a(x) 317 | x = self.mixed_5b(x) 318 | concat_x.append(x) 319 | for i, block in enumerate(self.repeat): 320 | x = block(x) 321 | if (i+1)%2 == 0: 322 | concat_x.append(x) 323 | concat_x = torch.cat(concat_x, 1).to(device) 324 | return concat_x 325 | 326 | def logits(self, features): 327 | x = self.avgpool_1a(features) 328 | x = x.view(x.size(0), -1) 329 | x = self.last_linear(x) 330 | return x 331 | 332 | def forward(self, input): 333 | x = self.features(input) 334 | return x 335 | 336 | def inceptionresnetv2_feature_extractor(num_classes=1000, pretrained='imagenet'): 337 | r"""InceptionResNetV2 model architecture from the 338 | `"InceptionV4, Inception-ResNet..." `_ paper. 
339 | """ 340 | if pretrained: 341 | settings = pretrained_settings['inceptionresnetv2'][pretrained] 342 | assert num_classes == settings['num_classes'], \ 343 | "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 344 | 345 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 346 | model = InceptionResNetV2(num_classes=1001) 347 | model.load_state_dict(model_zoo.load_url(settings['url'])) 348 | 349 | if pretrained == 'imagenet': 350 | new_last_linear = nn.Linear(1536, 1000) 351 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 352 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 353 | model.last_linear = new_last_linear 354 | 355 | model.input_space = settings['input_space'] 356 | model.input_size = settings['input_size'] 357 | model.input_range = settings['input_range'] 358 | 359 | model.mean = settings['mean'] 360 | model.std = settings['std'] 361 | else: 362 | model = InceptionResNetV2(num_classes=num_classes) 363 | return model 364 | 365 | class Mlp(nn.Module): 366 | def __init__(self): 367 | super(Mlp, self).__init__() 368 | self.hidden_size = 256 369 | self.mlp_size = 512 370 | self.dropout_rate = 0.1 371 | self.fc1 = Linear(self.hidden_size, self.mlp_size) 372 | self.fc2 = Linear(self.mlp_size, self.hidden_size) 373 | self.act_fn = ACT2FN["gelu"] 374 | self.dropout = Dropout(self.dropout_rate) 375 | 376 | self._init_weights() 377 | 378 | def _init_weights(self): 379 | nn.init.xavier_uniform_(self.fc1.weight) 380 | nn.init.xavier_uniform_(self.fc2.weight) 381 | nn.init.normal_(self.fc1.bias, std=1e-6) 382 | nn.init.normal_(self.fc2.bias, std=1e-6) 383 | 384 | def forward(self, x): 385 | x = self.fc1(x) 386 | x = self.act_fn(x) 387 | x = self.dropout(x) 388 | x = self.fc2(x) 389 | x = self.dropout(x) 390 | return x 391 | 392 | class MHA(nn.Module): 393 | def __init__(self): 394 | super(MHA, self).__init__() 395 | self.vis = True 396 | self.num_attention_heads = 4 397 | self.hidden_size = 256 398 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads) 399 | self.all_head_size = self.num_attention_heads * self.attention_head_size 400 | 401 | self.query = Linear(self.hidden_size, self.all_head_size) 402 | self.key = Linear(self.hidden_size, self.all_head_size) 403 | self.value = Linear(self.hidden_size, self.all_head_size) 404 | 405 | self.out = Linear(self.hidden_size, self.hidden_size) 406 | self.attn_dropout = Dropout(0.0) 407 | self.proj_dropout = Dropout(0.0) 408 | 409 | self.softmax = Softmax(dim=-1) 410 | 411 | def transpose_for_scores(self, x): 412 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 413 | x = x.view(*new_x_shape) 414 | return x.permute(0, 2, 1, 3) 415 | 416 | def forward(self, k, v, q): 417 | mixed_key_layer = self.key(k) 418 | mixed_value_layer = self.value(v) 419 | mixed_query_layer = self.query(q) 420 | 421 | query_layer = self.transpose_for_scores(mixed_query_layer) 422 | key_layer = self.transpose_for_scores(mixed_key_layer) 423 | value_layer = self.transpose_for_scores(mixed_value_layer) 424 | 425 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 426 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 427 | attention_probs = self.softmax(attention_scores) 428 | weights = attention_probs 429 | attention_probs = self.attn_dropout(attention_probs) 430 | 431 | context_layer = torch.matmul(attention_probs, value_layer) 432 | context_layer = context_layer.permute(0, 2, 1, 
3).contiguous() 433 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 434 | context_layer = context_layer.view(*new_context_layer_shape) 435 | attention_output = self.out(context_layer) 436 | attention_output = self.proj_dropout(attention_output) 437 | return attention_output 438 | 439 | class IQT_Encoder_Block(nn.Module): 440 | def __init__(self): 441 | super(IQT_Encoder_Block, self).__init__() 442 | self.hidden_size = 256 443 | self.ffn_norm = LayerNorm(self.hidden_size, eps=1e-6) 444 | self.ffn = Mlp() 445 | self.attention_norm = LayerNorm(self.hidden_size, eps=1e-6) 446 | self.attn = MHA() 447 | 448 | def forward(self, k, v, q): 449 | x = self.attn(k, v, q) 450 | x += q 451 | x1 = self.attention_norm(x) 452 | x = self.ffn(x1) 453 | x += x1 454 | y = self.ffn_norm(x) 455 | return y 456 | 457 | class IQT_Encoder(nn.Module): 458 | def __init__(self): 459 | super(IQT_Encoder, self).__init__() 460 | self.layer = nn.ModuleList() 461 | for _ in range(2): 462 | layer = IQT_Encoder_Block() 463 | self.layer.append(copy.deepcopy(layer)) 464 | 465 | def forward(self, x_diff): 466 | for encoder_block in self.layer: 467 | x_diff = encoder_block(x_diff, x_diff, x_diff) 468 | return x_diff 469 | 470 | class IQT_Decoder_Block(nn.Module): 471 | def __init__(self): 472 | super(IQT_Decoder_Block, self).__init__() 473 | self.hidden_size = 256 474 | self.ffn_norm = LayerNorm(self.hidden_size, eps=1e-6) 475 | self.ffn = Mlp() 476 | self.attention_norm = LayerNorm(self.hidden_size, eps=1e-6) 477 | self.attn = MHA() 478 | self.self_attention_norm = LayerNorm(self.hidden_size, eps=1e-6) 479 | self.self_attn = MHA() 480 | 481 | def forward(self, k, v, q, k1, v1): 482 | x = self.self_attn(k, v, q) 483 | x += q 484 | x1 = self.self_attention_norm(x) 485 | x = self.attn(k1, v1, x1) 486 | x += x1 487 | x2 = self.attention_norm(x) 488 | x = self.ffn(x2) 489 | x += x2 490 | y = self.ffn_norm(x) 491 | return y 492 | 493 | class IQT_Decoder(nn.Module): 494 | def __init__(self): 495 | super(IQT_Decoder, self).__init__() 496 | self.layer = nn.ModuleList() 497 | for _ in range(2): 498 | layer = IQT_Decoder_Block() 499 | self.layer.append(copy.deepcopy(layer)) 500 | 501 | def forward(self, x_HQ, x_diff): 502 | for decoder_block in self.layer: 503 | x_HQ = decoder_block(x_HQ, x_HQ, x_HQ, x_diff, x_diff) 504 | return x_HQ 505 | 506 | class RegressionFCNet(nn.Module): 507 | def __init__(self): 508 | super(RegressionFCNet, self).__init__() 509 | self.target_in_size=256 510 | self.target_fc1_size=512 511 | 512 | self.l1 = nn.Linear(self.target_in_size, self.target_fc1_size) 513 | self.relu = nn.ReLU() 514 | self.l2 = nn.Linear(self.target_fc1_size, 1) 515 | 516 | 517 | def forward(self, x): 518 | q = self.l1(x) 519 | q = self.relu(q) 520 | q = self.l2(q).squeeze() 521 | return q 522 | 523 | class IQT(nn.Module): 524 | 525 | def __init__(self): 526 | super(IQT, self).__init__() 527 | self.feature_extractor = inceptionresnetv2_feature_extractor(num_classes=1000, pretrained='imagenet') 528 | for param in self.feature_extractor.parameters(): 529 | param.requires_grad = False 530 | 531 | self.conv = nn.Conv2d(1920, 256, kernel_size=1, stride=1, padding=0) 532 | 533 | self.position_embeddings = nn.Parameter(torch.zeros(1, 625+1, 256)) 534 | self.quality_token = nn.Parameter(torch.zeros(1, 1, 256)) 535 | 536 | self.encoder = IQT_Encoder() 537 | self.decoder = IQT_Decoder() 538 | 539 | self.regressor = RegressionFCNet() 540 | 541 | def cal_params(self): 542 | params = list(self.parameters()) 543 | k = 0 544 | 
for i in params: 545 | l = 1 546 | for j in i.size(): 547 | l *= j 548 | k = k + l 549 | print("Total parameters is :" + str(k)) 550 | 551 | def forward(self, LQ_patch, HQ_patch): 552 | B = LQ.shape[0] 553 | feature_LQ = self.feature_extractor(LQ_patch) 554 | feature_HQ = self.feature_extractor(HQ_patch) 555 | 556 | feature_diff = self.conv(feature_HQ-feature_LQ) 557 | feature_HQ = self.conv(feature_HQ) 558 | 559 | quality_tokens = self.quality_token.expand(B,-1,-1) 560 | 561 | flat_feature_HQ = feature_HQ.flatten(2).transpose(-1,-2) 562 | flat_feature_HQ = torch.cat((quality_tokens,flat_feature_HQ), dim=1) + self.position_embeddings 563 | 564 | flat_feature_diff = feature_diff.flatten(2).transpose(-1,-2) 565 | flat_feature_diff = torch.cat((quality_tokens,flat_feature_diff), dim=1) + self.position_embeddings 566 | 567 | flat_feature_diff = self.encoder(flat_feature_diff) 568 | f = self.decoder(flat_feature_diff, flat_feature_HQ) 569 | y = self.regressor(f[:,0]) 570 | 571 | return y 572 | 573 | ''' 574 | TEST 575 | Run this code with: 576 | ``` 577 | cd $HOME/pretrained-models.pytorch 578 | python -m pretrainedmodels.inceptionresnetv2 579 | ``` 580 | ''' 581 | if __name__ == '__main__': 582 | import time 583 | device = torch.device('cuda') 584 | LQ = torch.rand((1,3,224, 224)).to(device) 585 | HQ = torch.rand((1,3,224, 224)).to(device) 586 | net = IQT().to(device) 587 | net.cal_params() 588 | 589 | torch.cuda.synchronize() 590 | start = time.time() 591 | a = net(LQ, HQ) 592 | torch.cuda.synchronize() 593 | end = time.time() 594 | print("run time is :" + str(end-start)) 595 | -------------------------------------------------------------------------------- /folders/folders_LQ_HQ_diff_content_HQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import scipy.io 8 | import numpy as np 9 | import csv 10 | import random 11 | from openpyxl import load_workbook 12 | import cv2 13 | from torchvision import transforms 14 | 15 | class Kadid10kFolder(data.Dataset): 16 | def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10): 17 | self.patch_size = patch_size 18 | self.self_patch_num = self_patch_num 19 | self.HQ_diff_content_root = HQ_diff_content_root 20 | 21 | LQ_paths = [] 22 | HQ_paths = [] 23 | mos_all = [] 24 | csv_file = os.path.join(root, 'dmos.csv') 25 | with open(csv_file) as f: 26 | reader = csv.DictReader(f) 27 | for row in reader: 28 | LQ_paths.append(row['dist_img']) 29 | HQ_paths.append(row['ref_img']) 30 | mos = np.array(float(row['dmos'])).astype(np.float32) 31 | mos_all.append(mos) 32 | 33 | sample = [] 34 | for _, item in enumerate(index): 35 | for _ in range(patch_num): 36 | sample.append((os.path.join(root, 'images', LQ_paths[item]),os.path.join(root, 'images', HQ_paths[item]), mos_all[item])) 37 | 38 | self.HQ_diff_content_paths = [] 39 | for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root): 40 | if HQ_diff_content_img_path[-3:] == 'png' or HQ_diff_content_img_path[-3:] == 'jpg' or HQ_diff_content_img_path[-3:] == 'bmp': 41 | self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path)) 42 | 43 | self.samples = sample 44 | self.transform = transform 45 | self.HQ_diff_content_transform = HQ_diff_content_transform 46 | 47 | def __getitem__(self, index): 48 | """ 49 | Args: 50 | index 
(int): Index 51 | Returns: 52 | tuple: (LQ, HQ, HQ_diff_content, target) where target is IQA values of the target LQ. 53 | """ 54 | LQ_path, HQ_path, target = self.samples[index] 55 | HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths)-1)] 56 | LQ = pil_loader(LQ_path) 57 | HQ_diff_content = pil_loader(HQ_diff_content_path) 58 | HQ = pil_loader(HQ_path) 59 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], [] 60 | for _ in range(self.self_patch_num): 61 | LQ_patch, HQ_patch = getPairRandomPatch(LQ,HQ, crop_size=self.patch_size) 62 | 63 | LQ_patch = self.transform(LQ_patch) 64 | HQ_patch = self.transform(HQ_patch) 65 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content) 66 | 67 | LQ_patches.append(LQ_patch.unsqueeze(0)) 68 | HQ_patches.append(HQ_patch.unsqueeze(0)) 69 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0)) 70 | #[self_patch_num, 3, patch_size, patch_size] 71 | LQ_patches = torch.cat(LQ_patches, 0) 72 | HQ_patches = torch.cat(HQ_patches, 0) 73 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0) 74 | 75 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target 76 | 77 | def __len__(self): 78 | length = len(self.samples) 79 | return length 80 | 81 | class LIVEFolder(data.Dataset): 82 | 83 | def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10): 84 | self.patch_size =patch_size 85 | self.self_patch_num = self_patch_num 86 | self.root = root 87 | self.HQ_diff_content_root = HQ_diff_content_root 88 | 89 | refpath = os.path.join(root, 'refimgs') 90 | refname = getFileName(refpath, '.bmp') 91 | 92 | jp2kroot = os.path.join(root, 'jp2k') 93 | jp2kname = self.getDistortionTypeFileName(jp2kroot, 227) 94 | 95 | jpegroot = os.path.join(root, 'jpeg') 96 | jpegname = self.getDistortionTypeFileName(jpegroot, 233) 97 | 98 | wnroot = os.path.join(root, 'wn') 99 | wnname = self.getDistortionTypeFileName(wnroot, 174) 100 | 101 | gblurroot = os.path.join(root, 'gblur') 102 | gblurname = self.getDistortionTypeFileName(gblurroot, 174) 103 | 104 | fastfadingroot = os.path.join(root, 'fastfading') 105 | fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174) 106 | 107 | imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname 108 | 109 | dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat')) 110 | labels = dmos['dmos_new'].astype(np.float32) 111 | 112 | orgs = dmos['orgs'] 113 | refpaths_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat')) 114 | refpaths_all = refpaths_all['refnames_all'] 115 | 116 | sample = [] 117 | for i in range(0, len(index)): 118 | train_sel = (refname[index[i]] == refpaths_all) 119 | train_sel = train_sel * ~orgs.astype(np.bool_) 120 | train_sel = np.where(train_sel == True) 121 | train_sel = train_sel[1].tolist() 122 | for j, item in enumerate(train_sel): 123 | for aug in range(patch_num): 124 | LQ_path = imgpath[item] 125 | HQ_path = os.path.join(root, 'refimgs', refpaths_all[0][item][0]) 126 | label = labels[0][item] 127 | sample.append((LQ_path, HQ_path, label)) 128 | # print(self.imgpath[item]) 129 | 130 | self.HQ_diff_content_path = [] 131 | for HQ_diff_content_img_name in os.listdir(HQ_diff_content_root): 132 | if HQ_diff_content_img_name[-3:] == 'png' or HQ_diff_content_img_name[-3:] == 'jpg' or HQ_diff_content_img_name[-3:] == 'bmp': 133 | self.HQ_diff_content_path.append(os.path.join(HQ_diff_content_root, 
HQ_diff_content_img_name)) 134 | 135 | self.samples = sample 136 | self.transform = transform 137 | self.HQ_diff_content_transform = HQ_diff_content_transform 138 | 139 | def __getitem__(self, index): 140 | """ 141 | Args: 142 | index (int): Index 143 | Returns: 144 | tuple: (LQ, HQ, HQ_diff_content, target) where target is IQA values of the target LQ. 145 | """ 146 | LQ_path, HQ_path, target = self.samples[index] 147 | HQ_diff_content_path = self.HQ_diff_content_path[random.randint(0, len(self.HQ_diff_content_path)-1)] 148 | LQ = pil_loader(LQ_path) 149 | HQ = pil_loader(HQ_path) 150 | HQ_diff_content = pil_loader(HQ_diff_content_path) 151 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], [] 152 | for _ in range(self.self_patch_num): 153 | LQ_patch, HQ_patch = getPairRandomPatch(LQ, HQ, crop_size=self.patch_size) 154 | 155 | LQ_patch = self.transform(LQ_patch) 156 | HQ_patch = self.transform(HQ_patch) 157 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content) 158 | 159 | LQ_patches.append(LQ_patch.unsqueeze(0)) 160 | HQ_patches.append(HQ_patch.unsqueeze(0)) 161 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0)) 162 | #[self_patch_num, 3, patch_size, patch_size] 163 | LQ_patches = torch.cat(LQ_patches, 0) 164 | HQ_patches = torch.cat(HQ_patches, 0) 165 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0) 166 | 167 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target 168 | 169 | def __len__(self): 170 | length = len(self.samples) 171 | return length 172 | 173 | def getDistortionTypeFileName(self, path, num): 174 | filename = [] 175 | index = 1 176 | for i in range(0, num): 177 | name = '{:0>3d}{}'.format(index, '.bmp') 178 | filename.append(os.path.join(path, name)) 179 | index = index + 1 180 | return filename 181 | 182 | class CSIQFolder(data.Dataset): 183 | 184 | def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size =224, self_patch_num=10): 185 | self.patch_size =patch_size 186 | self.self_patch_num = self_patch_num 187 | 188 | refpath = os.path.join(root, 'src_imgs') 189 | refname = getFileName(refpath,'.png') 190 | txtpath = os.path.join(root, 'csiq_label.txt') 191 | fh = open(txtpath, 'r') 192 | LQ_pathes = [] 193 | target = [] 194 | refpaths_all = [] 195 | for line in fh: 196 | line = line.split('\n') 197 | words = line[0].split() 198 | LQ_pathes.append((words[0])) 199 | target.append(words[1]) 200 | ref_temp = words[0].split(".") 201 | refpaths_all.append(ref_temp[0] + '.' 
+ ref_temp[-1]) 202 | 203 | labels = np.array(target).astype(np.float32) 204 | refpaths_all = np.array(refpaths_all) 205 | 206 | sample = [] 207 | 208 | for i, item in enumerate(index): 209 | train_sel = (refname[index[i]] == refpaths_all) 210 | train_sel = np.where(train_sel == True) 211 | train_sel = train_sel[0].tolist() 212 | for j, item in enumerate(train_sel): 213 | for aug in range(patch_num): 214 | LQ_path = os.path.join(root, 'dst_imgs_all', LQ_pathes[item]) 215 | HQ_path = os.path.join(root, 'src_imgs', refpaths_all[item]) 216 | label = labels[item] 217 | sample.append((LQ_path, HQ_path, label)) 218 | 219 | self.HQ_diff_content = [] 220 | for HQ_diff_content_img_name in os.listdir(HQ_diff_content_root): 221 | if HQ_diff_content_img_name[-3:] == 'png' or HQ_diff_content_img_name[-3:] == 'jpg' or HQ_diff_content_img_name[-3:] == 'bmp': 222 | self.HQ_diff_content.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_name)) 223 | 224 | self.samples = sample 225 | self.transform = transform 226 | self.HQ_diff_content_transform = HQ_diff_content_transform 227 | 228 | def __getitem__(self, index): 229 | """ 230 | Args: 231 | index (int): Index 232 | Returns: 233 | tuple: (LQ, HQ, HQ_diff_content, target) where target is IQA values of the target LQ. 234 | """ 235 | LQ_path, HQ_path, target = self.samples[index] 236 | HQ_diff_content_path = self.HQ_diff_content[random.randint(0, len(self.HQ_diff_content)-1)] 237 | LQ = pil_loader(LQ_path) 238 | HQ = pil_loader(HQ_path) 239 | HQ_diff_content = pil_loader(HQ_diff_content_path) 240 | LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], [] 241 | for _ in range(self.self_patch_num): 242 | LQ_patch, HQ_patch = getPairRandomPatch(LQ, HQ, crop_size=self.patch_size) 243 | 244 | LQ_patch = self.transform(LQ_patch) 245 | HQ_patch = self.transform(HQ_patch) 246 | HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content) 247 | 248 | LQ_patches.append(LQ_patch.unsqueeze(0)) 249 | HQ_patches.append(HQ_patch.unsqueeze(0)) 250 | HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0)) 251 | #[self_patch_num, 3, patch_size, patch_size] 252 | LQ_patches = torch.cat(LQ_patches, 0) 253 | HQ_patches = torch.cat(HQ_patches, 0) 254 | HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0) 255 | 256 | return LQ_patches, HQ_patches, HQ_diff_content_patches, target 257 | 258 | def __len__(self): 259 | length = len(self.samples) 260 | return length 261 | 262 | class TID2013Folder(data.Dataset): 263 | 264 | def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size =224, self_patch_num=10): 265 | self.patch_size =patch_size 266 | self.self_patch_num = self_patch_num 267 | 268 | refpath = os.path.join(root, 'reference_images') 269 | refname = self._getTIDFileName(refpath,'.bmp.BMP') 270 | txtpath = os.path.join(root, 'mos_with_names.txt') 271 | fh = open(txtpath, 'r') 272 | LQ_pathes = [] 273 | target = [] 274 | refpaths_all = [] 275 | for line in fh: 276 | line = line.split('\n') 277 | words = line[0].split() 278 | LQ_pathes.append((words[1])) 279 | target.append(words[0]) 280 | ref_temp = words[1].split("_") 281 | refpaths_all.append(ref_temp[0][1:]) 282 | labels = np.array(target).astype(np.float32) 283 | refpaths_all = np.array(refpaths_all) 284 | 285 | sample = [] 286 | for i, item in enumerate(index): 287 | train_sel = (refname[index[i]] == refpaths_all) 288 | train_sel = np.where(train_sel == True) 289 | train_sel = train_sel[0].tolist() 290 | for j, item in 
            for j, item in enumerate(train_sel):
                for aug in range(patch_num):
                    LQ_path = os.path.join(root, 'distorted_images', LQ_pathes[item])
                    refHQ_name = 'I' + LQ_pathes[item].split("_")[0][1:] + '.BMP'
                    HQ_path = os.path.join(refpath, refHQ_name)
                    label = labels[item]
                    sample.append((LQ_path, HQ_path, label))
        self.HQ_diff_content = []
        for HQ_diff_content_img_name in os.listdir(HQ_diff_content_root):
            if HQ_diff_content_img_name[-3:] in ('png', 'jpg', 'bmp'):
                self.HQ_diff_content.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_name))

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _getTIDFileName(self, path, suffix):
        filename = []
        f_list = os.listdir(path)
        for i in f_list:
            if suffix.find(os.path.splitext(i)[1]) != -1:
                filename.append(i[1:3])
        return filename

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (LQ, HQ, HQ_diff_content, target) where target is the IQA value of the target LQ.
        """
        LQ_path, HQ_path, target = self.samples[index]
        HQ_diff_content_path = self.HQ_diff_content[random.randint(0, len(self.HQ_diff_content) - 1)]
        LQ = pil_loader(LQ_path)
        HQ = pil_loader(HQ_path)
        HQ_diff_content = pil_loader(HQ_diff_content_path)
        LQ_patches, HQ_patches, HQ_diff_content_patches = [], [], []
        for _ in range(self.self_patch_num):
            LQ_patch, HQ_patch = getPairRandomPatch(LQ, HQ, crop_size=self.patch_size)

            LQ_patch = self.transform(LQ_patch)
            HQ_patch = self.transform(HQ_patch)
            HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)

            LQ_patches.append(LQ_patch.unsqueeze(0))
            HQ_patches.append(HQ_patch.unsqueeze(0))
            HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
        # [self_patch_num, 3, patch_size, patch_size]
        LQ_patches = torch.cat(LQ_patches, 0)
        HQ_patches = torch.cat(HQ_patches, 0)
        HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)

        return LQ_patches, HQ_patches, HQ_diff_content_patches, target

    def __len__(self):
        length = len(self.samples)
        return length


class LIVEChallengeFolder(data.Dataset):
    def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
        self.patch_size = patch_size
        self.self_patch_num = self_patch_num

        LQ_pathes = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat'))
        LQ_pathes = LQ_pathes['AllImages_release']
        LQ_pathes = LQ_pathes[7:1169]
        mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat'))
        labels = mos['AllMOS_release'].astype(np.float32)
        labels = labels[0][7:1169]

        sample = []
        for _, item in enumerate(index):
            for _ in range(patch_num):
                sample.append((os.path.join(root, 'Images', LQ_pathes[item][0][0]), labels[item]))

        self.HQ_diff_content_paths = []
        for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
            if HQ_diff_content_img_path[-3:] in ('png', 'jpg', 'bmp'):
                self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (LQ, _, HQ_diff_content, target) where target is the IQA value of the target LQ.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths) - 1)]
        LQ = pil_loader(LQ_path)
        HQ_diff_content = pil_loader(HQ_diff_content_path)
        LQ_patches, HQ_diff_content_patches = [], []
        for _ in range(self.self_patch_num):
            LQ_patch = self.HQ_diff_content_transform(LQ)
            HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)

            LQ_patches.append(LQ_patch.unsqueeze(0))
            HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
        # [self_patch_num, 3, patch_size, patch_size]
        LQ_patches = torch.cat(LQ_patches, 0)
        HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)

        # No pristine reference exists for this authentic-distortion dataset,
        # so the second slot is only a placeholder (the leftover loop variable).
        return LQ_patches, _, HQ_diff_content_patches, target

    def __len__(self):
        length = len(self.samples)
        return length


class BIDChallengeFolder(data.Dataset):
    def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
        self.patch_size = patch_size
        self.self_patch_num = self_patch_num

        LQ_pathes = []
        labels = []

        xls_file = os.path.join(root, 'DatabaseGrades.xlsx')
        workbook = load_workbook(xls_file)
        booksheet = workbook.active
        rows = booksheet.rows
        count = 1
        for _ in rows:
            # start reading from row 2 to skip the header row
            count += 1
            img_num = booksheet.cell(row=count, column=1).value
            img_name = "DatabaseImage%04d.JPG" % (img_num)
            LQ_pathes.append(img_name)
            mos = booksheet.cell(row=count, column=2).value
            mos = np.array(mos)
            mos = mos.astype(np.float32)
            labels.append(mos)
            if count == 587:
                break

        sample = []
        for _, item in enumerate(index):
            for _ in range(patch_num):
                sample.append((os.path.join(root, LQ_pathes[item]), labels[item]))

        self.HQ_diff_content_paths = []
        for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
            if HQ_diff_content_img_path[-3:] in ('png', 'jpg', 'bmp'):
                self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (LQ, _, HQ_diff_content, target) where target is the IQA value of the target LQ.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths) - 1)]
        LQ = pil_loader(LQ_path)
        HQ_diff_content = pil_loader(HQ_diff_content_path)
        LQ_patches, HQ_diff_content_patches = [], []
        for _ in range(self.self_patch_num):
            LQ_patch = self.HQ_diff_content_transform(LQ)
            HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)

            LQ_patches.append(LQ_patch.unsqueeze(0))
            HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
        # [self_patch_num, 3, patch_size, patch_size]
        LQ_patches = torch.cat(LQ_patches, 0)
        HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)

        # No pristine reference exists for this authentic-distortion dataset,
        # so the second slot is only a placeholder (the leftover loop variable).
        return LQ_patches, _, HQ_diff_content_patches, target

    def __len__(self):
        length = len(self.samples)
        return length


class Koniq_10kFolder(data.Dataset):
    def __init__(self, root, HQ_diff_content_root, index, transform, HQ_diff_content_transform, patch_num, patch_size=224, self_patch_num=10):
        self.patch_size = patch_size
        self.self_patch_num = self_patch_num

        imgname = []
        mos_all = []
        csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv')
        with open(csv_file) as f:
            reader = csv.DictReader(f)
            for row in reader:
                imgname.append(row['image_name'])
                mos = np.array(float(row['MOS_zscore'])).astype(np.float32)
                mos_all.append(mos)

        sample = []
        for _, item in enumerate(index):
            for _ in range(patch_num):
                sample.append((os.path.join(root, '1024x768', imgname[item]), mos_all[item]))

        self.HQ_diff_content_paths = []
        for HQ_diff_content_img_path in os.listdir(HQ_diff_content_root):
            if HQ_diff_content_img_path[-3:] in ('png', 'jpg', 'bmp'):
                self.HQ_diff_content_paths.append(os.path.join(HQ_diff_content_root, HQ_diff_content_img_path))

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (LQ, _, HQ_diff_content, target) where target is the IQA value of the target LQ.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = self.HQ_diff_content_paths[random.randint(0, len(self.HQ_diff_content_paths) - 1)]
        LQ = pil_loader(LQ_path)
        HQ_diff_content = pil_loader(HQ_diff_content_path)
        LQ_patches, HQ_diff_content_patches = [], []
        for _ in range(self.self_patch_num):
            LQ_patch = self.HQ_diff_content_transform(LQ)
            HQ_diff_content_patch = self.HQ_diff_content_transform(HQ_diff_content)

            LQ_patches.append(LQ_patch.unsqueeze(0))
            HQ_diff_content_patches.append(HQ_diff_content_patch.unsqueeze(0))
        # [self_patch_num, 3, patch_size, patch_size]
        LQ_patches = torch.cat(LQ_patches, 0)
        HQ_diff_content_patches = torch.cat(HQ_diff_content_patches, 0)

        # No pristine reference exists for this authentic-distortion dataset,
        # so the second slot is only a placeholder (the leftover loop variable).
        return LQ_patches, _, HQ_diff_content_patches, target

    def __len__(self):
        length = len(self.samples)
        return length


def getFileName(path, suffix):
    filename = []
    f_list = os.listdir(path)
    for i in f_list:
        if os.path.splitext(i)[1] == suffix:
            filename.append(i)
    return filename


def getPairRandomPatch(img1, img2, crop_size=512):
    (iw, ih) = img1.size
    # print(ih, iw)

    ip = int(crop_size)

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    # crop box is (left, upper, right, lower); the same box is used for both
    # images so the LQ/HQ patches stay spatially aligned
    img1_patch = img1.crop((ix, iy, ix + ip, iy + ip))
    img2_patch = img2.crop((ix, iy, ix + ip, iy + ip))

    return img1_patch, img2_patch


def getPairAugment(img1, img2, hflip=True, vflip=False, rot=False):
    hflip = hflip and random.random() < 0.5
    vflip = vflip and random.random() < 0.5
    rot180 = rot and random.random() < 0.5

    if hflip:  # horizontal flip (left-right)
        img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
        img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
    if vflip:  # vertical flip (top-bottom)
        img1 = img1.transpose(Image.FLIP_TOP_BOTTOM)
        img2 = img2.transpose(Image.FLIP_TOP_BOTTOM)
    if rot180:
        img1 = img1.transpose(Image.ROTATE_180)
        img2 = img2.transpose(Image.ROTATE_180)

    return img1, img2


def getSelfPatch(img, patch_size, patch_num, is_random=True):
    (iw, ih) = img.size
    patches = []
    for i in range(patch_num):
        if is_random:
            ix = random.randrange(0, iw - patch_size + 1)
            iy = random.randrange(0, ih - patch_size + 1)
        else:
            ix, iy = (iw - patch_size + 1) // 2, (ih - patch_size + 1) // 2

        # patch = img[iy:iy + lr_size, ix:ix + lr_size, :]  # (rows, then columns)
        # crop box is (left, upper, right, lower)
        patch = img.crop((ix, iy, ix + patch_size, iy + patch_size))
        patches.append(patch)

    return patches


def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

--------------------------------------------------------------------------------
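Usage note (not part of the repository): a minimal sketch of how one of the Folder datasets above could be wrapped in a standard PyTorch DataLoader. It assumes this file is importable as folders.folders_LQ_HQ_diff_content_HQ; the dataset roots, the index split, and the transforms below are illustrative placeholders, not values shipped with the code.

    import torch
    import torchvision.transforms as transforms
    from folders.folders_LQ_HQ_diff_content_HQ import CSIQFolder  # assumed module path

    patch_size = 224
    # per-patch transform for the paired LQ/HQ crops
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    # the non-aligned HQ reference is resized as a whole image before tensor conversion
    HQ_diff_content_transform = transforms.Compose([
        transforms.Resize((patch_size, patch_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    dataset = CSIQFolder(
        root='/data/CSIQ',                      # hypothetical CSIQ root
        HQ_diff_content_root='/data/HQ_refs',   # hypothetical folder of HQ reference images
        index=list(range(0, 24)),               # indices into the CSIQ reference images
        transform=transform,
        HQ_diff_content_transform=HQ_diff_content_transform,
        patch_num=1,
        patch_size=patch_size,
        self_patch_num=10,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    # Each batch: LQ/HQ/HQ_diff_content patches have shape
    # [batch, self_patch_num, 3, patch_size, patch_size]; target is the per-image label.
    for LQ_patches, HQ_patches, HQ_diff_content_patches, target in loader:
        print(LQ_patches.shape, HQ_patches.shape, HQ_diff_content_patches.shape, target.shape)
        break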