├── .gitignore ├── README.md ├── checkpoint └── README.md ├── demo.py ├── demoImg ├── I_01.jpg ├── I_02.jpg ├── I_03.jpg └── I_04.jpg ├── model ├── studentNetwork.py └── teacherNetwork.py ├── pic └── framework.jpg └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TeacherIQA 2 | 3 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=chencn2020/TeacherIQA) 4 | [![Open issue](https://img.shields.io/github/issues/chencn2020/TeacherIQA)](https://github.com/chencn2020/TeacherIQA/issues) 5 | [![Closed issue](https://img.shields.io/github/issues-closed/chencn2020/TeacherIQA)](https://github.com/chencn2020/TeacherIQA/issues) 6 | [![GitHub Stars](https://img.shields.io/github/stars/chencn2020/TeacherIQA?style=social)](https://github.com/chencn2020/TeacherIQA) 7 | 8 | This repository is the source code for the paper "[Teacher-Guided Learning for Blind Image Quality Assessment](https://openaccess.thecvf.com/content/ACCV2022/html/Chen_Teacher-Guided_Learning_for_Blind_Image_Quality_Assessment_ACCV_2022_paper.html)". 9 | 10 | ![Framework](./pic/framework.jpg) 11 | 12 | ## Dependencies 13 | 14 | - matplotlib==3.2.2 15 | - numpy==1.22.3 16 | - Pillow==9.2.0 17 | - torch==1.11.0 18 | - torchvision==0.11.2+cu113 19 | 20 | ## Usage For Testing 21 | 22 | 23 | You can predict the image quality score for any image with our model, which is trained on the KonIQ-10k dataset. 24 | 25 | The pre-trained model can be downloaded from [Google drive](https://drive.google.com/file/d/1iNhJQpUWSAkwSfDbfXzu834gm7NoT3m0/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1aE8_stfHexjzPECyk1YlwA) (password: b86d). 26 | 27 | Please put the pre-trained model into the **'./checkpoint'** folder. Then run: 28 | 29 | ``` 30 | python3 demo.py --input_image ./demoImg/I_02.jpg --pre_train_model ./checkpoint/koniq_teacher_iqa.pkl --crop_times 25 31 | ``` 32 | 33 | The input image will be randomly cropped into 25 patches of size 224 × 224, and the IQA model will predict one score for each patch. 34 | 35 | Finally, you will get an average quality score, typically ranging from 0 to 100 (in a few cases the value may fall outside this range). The higher the value, the better the image quality. 36 | 37 | 38 | ## Citation 39 | If our work is useful to your research, we would be grateful if you cite our paper: 40 | ``` 41 | @InProceedings{Chen_2022_ACCV, 42 | author = {Chen, Zewen and Wang, Juan and Li, Bing and Yuan, Chunfeng and Xiong, Weihua and Cheng, Rui and Hu, Weiming}, 43 | title = {Teacher-Guided Learning for Blind Image Quality Assessment}, 44 | booktitle = {Proceedings of the Asian Conference on Computer Vision (ACCV)}, 45 | month = {December}, 46 | year = {2022}, 47 | pages = {2457-2474} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /checkpoint/README.md: -------------------------------------------------------------------------------- 1 | Please download the checkpoint trained on the KonIQ-10k dataset from [Google drive](https://drive.google.com/file/d/1iNhJQpUWSAkwSfDbfXzu834gm7NoT3m0/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1aE8_stfHexjzPECyk1YlwA) (password: b86d).
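2 |
3 | As a quick sanity check after downloading, you can verify that the checkpoint loads into the model (a minimal sketch based on `demo.py`; it assumes the file is saved as `./checkpoint/koniq_teacher_iqa.pkl`, the name used in the demo command, that you run it from the repository root, and that a CUDA device is available):
4 |
5 | ```python
6 | import torch
7 |
8 | from model import studentNetwork as IQAModel
9 |
10 | # Build the student network and load the downloaded checkpoint weights.
11 | model = IQAModel.StudentNetwork().cuda()
12 | model.load_state_dict(torch.load('./checkpoint/koniq_teacher_iqa.pkl'))
13 | model.eval()
14 | print('Checkpoint loaded successfully.')
15 | ```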
-------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from PIL import Image 4 | import numpy as np 5 | from model import studentNetwork as IQAModel 6 | import argparse 7 | import warnings 8 | import os 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 11 | warnings.filterwarnings("ignore") 12 | 13 | def pil_loader(path): 14 | with open(path, 'rb') as f: 15 | img = Image.open(f) 16 | return img.convert('RGB') 17 | 18 | def predict_IQA_Score(config): 19 | im_path = config.input_image 20 | 21 | # load the model 22 | model_hyper = IQAModel.StudentNetwork().cuda() 23 | model_hyper.load_state_dict((torch.load(config.pre_train_model))) 24 | model_hyper.eval() 25 | 26 | # define the way of transforming. 27 | transforms = torchvision.transforms.Compose([ 28 | torchvision.transforms.RandomCrop(size=224), 29 | torchvision.transforms.ToTensor()]) 30 | 31 | img_or = pil_loader(im_path) 32 | pred_scores = [] 33 | 34 | # crop the image 25 times 35 | for time in range(config.crop_times): 36 | img = transforms(img_or) 37 | img = torch.tensor(img.cuda()).unsqueeze(0) 38 | pred = model_hyper(img) 39 | pred_scores.append(float(pred.item())) 40 | 41 | # calculate the average score. 42 | score = np.mean(pred_scores) 43 | 44 | return score 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--input_image', dest='input_image', type=str, required=True) 49 | parser.add_argument('--pre_train_model', dest='pre_train_model', type=str, required=True) 50 | parser.add_argument('--crop_times', dest='crop_times', type=int, default=25) 51 | config = parser.parse_args() 52 | 53 | score = predict_IQA_Score(config) 54 | print('Final Average Predicted Quality Score: {}'.format(round(score, 2))) 55 | -------------------------------------------------------------------------------- /demoImg/I_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chencn2020/TeacherIQA/d9c5cf22fdd6cc70ed5a15cff4e048a55607ce07/demoImg/I_01.jpg -------------------------------------------------------------------------------- /demoImg/I_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chencn2020/TeacherIQA/d9c5cf22fdd6cc70ed5a15cff4e048a55607ce07/demoImg/I_02.jpg -------------------------------------------------------------------------------- /demoImg/I_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chencn2020/TeacherIQA/d9c5cf22fdd6cc70ed5a15cff4e048a55607ce07/demoImg/I_03.jpg -------------------------------------------------------------------------------- /demoImg/I_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chencn2020/TeacherIQA/d9c5cf22fdd6cc70ed5a15cff4e048a55607ce07/demoImg/I_04.jpg -------------------------------------------------------------------------------- /model/studentNetwork.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from model import teacherNetwork as TN 4 | 5 | 6 | class selfAttention(nn.Module): 7 | """ Self attention Layer""" 8 | 9 | def __init__(self, in_dim): 10 | super(selfAttention, self).__init__() 11 | 12 | self.qConv = 
nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 13 | self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 14 | self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 15 | self.gamma = nn.Parameter(torch.zeros(1)) 16 | 17 | self.softmax = nn.Softmax(dim=-1) 18 | 19 | def forward(self, inFeature): 20 | bs, C, w, h = inFeature.size() 21 | 22 | proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1) 23 | proj_key = self.kConv(inFeature).view(bs, -1, w * h) 24 | energy = torch.bmm(proj_query, proj_key) 25 | attention = self.softmax(energy) 26 | proj_value = self.vConv(inFeature).view(bs, -1, w * h) 27 | 28 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 29 | out = out.view(bs, C, w, h) 30 | 31 | out = self.gamma * out + inFeature 32 | 33 | return out 34 | 35 | class enc(nn.Module): 36 | def __init__(self, ch): 37 | super(enc, self).__init__() 38 | self.KI = selfAttention(ch) 39 | 40 | def forward(self, KL, distortionKL): 41 | fusionKL = torch.cat((KL, distortionKL), dim=1) 42 | KIRes = self.KI(fusionKL) 43 | return KIRes 44 | 45 | class StudentNetwork(nn.Module): 46 | def __init__(self): 47 | super(StudentNetwork, self).__init__() 48 | 49 | self.tn = TN.TeacherNetwork(3).cuda() 50 | 51 | self.bottomConv = nn.Sequential( 52 | nn.Conv2d(2048, 1024, kernel_size=1), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(1024, 512, kernel_size=1), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(512, 256, kernel_size=1), 57 | nn.ReLU(inplace=True), 58 | ) 59 | 60 | self.upMC1Conv = nn.Sequential( 61 | nn.Conv2d(1024, 256, kernel_size=1), 62 | nn.ReLU(inplace=True), 63 | nn.AvgPool2d(2, 2), 64 | ) 65 | 66 | self.upMC2Conv = nn.Sequential( 67 | nn.Conv2d(512, 512, kernel_size=1), 68 | nn.ReLU(inplace=True), 69 | nn.AvgPool2d(4, 4), 70 | ) 71 | 72 | self.upMC3Conv = nn.Sequential( 73 | nn.Conv2d(256, 1024, kernel_size=1), 74 | nn.ReLU(inplace=True), 75 | nn.AvgPool2d(8, 8), 76 | ) 77 | 78 | self.enc1 = enc(512) 79 | self.enc2 = enc(1024) 80 | self.enc3 = enc(2048) 81 | 82 | self.iqaScore = nn.Sequential( 83 | nn.Conv2d(2048, 1024, kernel_size=1, stride=1), # 24 24 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(1024, 512, kernel_size=1, stride=1), # 1 8 8 86 | nn.ReLU(inplace=True), 87 | nn.Conv2d(512, 256, kernel_size=1, stride=1), # 1 8 8 88 | nn.ReLU(inplace=True), 89 | nn.Conv2d(256, 32, kernel_size=1, stride=1), # 1 8 8 90 | nn.ReLU(inplace=True), 91 | nn.Conv2d(32, 1, kernel_size=7, stride=1), # 1 10 10 92 | ) 93 | 94 | def forward(self, img): 95 | _, semanticKL, [distortionKL1, distortionKL2, distortionKL3] = self.tn(img) 96 | resNetBottomFeature = self.bottomConv(semanticKL) # n, 32, 28, 28 97 | 98 | distortionKL1 = self.upMC1Conv(distortionKL1) # n, 32, 28, 28 99 | distortionKL2 = self.upMC2Conv(distortionKL2) # n, 32, 28, 28 100 | distortionKL3 = self.upMC3Conv(distortionKL3) # n, 32, 28, 28 101 | 102 | attention1 = self.enc1(resNetBottomFeature, distortionKL1) 103 | attention2 = self.enc2(attention1, distortionKL2) 104 | attention3 = self.enc3(attention2, distortionKL3) 105 | 106 | return self.iqaScore(attention3).view(img.shape[0]) 107 | 108 | if __name__ == '__main__': 109 | net = StudentNetwork().cuda() 110 | inputs = torch.zeros((1, 3, 224, 224), dtype=torch.float32).cuda() 111 | output = net(inputs) 112 | print(output.size()) 113 | -------------------------------------------------------------------------------- /model/teacherNetwork.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | model_urls = { 7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 12 | } 13 | 14 | 15 | class Bottleneck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(Bottleneck, self).__init__() 20 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 25 | self.bn3 = nn.BatchNorm2d(planes * 4) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv3(out) 42 | out = self.bn3(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | def resnet50_backbone(pretrained=False, **kwargs): 54 | """Constructs a ResNet-50 model_hyper. 
55 | 56 | Args: 57 | pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet 58 | """ 59 | model = ResNetBackbone(Bottleneck, [3, 4, 6, 3], **kwargs) 60 | if pretrained: 61 | save_model = model_zoo.load_url(model_urls['resnet50']) 62 | model_dict = model.state_dict() 63 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 64 | model_dict.update(state_dict) 65 | model.load_state_dict(model_dict) 66 | return model 67 | 68 | 69 | class ResNetBackbone(nn.Module): 70 | 71 | def __init__(self, block, layers, num_classes=1000): 72 | super(ResNetBackbone, self).__init__() 73 | self.inplanes = 64 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 78 | self.layer1 = self._make_layer(block, 64, layers[0]) 79 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 82 | 83 | def _make_layer(self, block, planes, blocks, stride=1): 84 | downsample = None 85 | if stride != 1 or self.inplanes != planes * block.expansion: 86 | downsample = nn.Sequential( 87 | nn.Conv2d(self.inplanes, planes * block.expansion, 88 | kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(planes * block.expansion), 90 | ) 91 | 92 | layers = [] 93 | layers.append(block(self.inplanes, planes, stride, downsample)) 94 | self.inplanes = planes * block.expansion 95 | for i in range(1, blocks): 96 | layers.append(block(self.inplanes, planes)) 97 | 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | 102 | x = self.conv1(x) 103 | x = self.bn1(x) 104 | x = self.relu(x) 105 | x0 = self.maxpool(x) 106 | x1 = self.layer1(x0) 107 | x2 = self.layer2(x1) 108 | x3 = self.layer3(x2) 109 | bottom = self.layer4(x3) 110 | 111 | return x, x1, x2, x3, bottom 112 | 113 | 114 | class TripleConv(nn.Module): 115 | def __init__(self, in_ch, out_ch): 116 | super(TripleConv, self).__init__() 117 | hide_ch = out_ch // 2 118 | self.TripleConv = nn.Sequential( 119 | nn.Conv2d(in_ch, hide_ch, 3, padding=1, groups=hide_ch), 120 | nn.BatchNorm2d(hide_ch), 121 | nn.ReLU(inplace=True), 122 | nn.Conv2d(hide_ch, hide_ch, 3, padding=1, groups=hide_ch), 123 | nn.BatchNorm2d(hide_ch), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(hide_ch, out_ch, 3, padding=1, groups=hide_ch), 126 | ) 127 | 128 | def forward(self, x): 129 | return self.TripleConv(x) 130 | 131 | 132 | class Inception(nn.Module): 133 | def __init__(self, in_ch, out_ch): 134 | super(Inception, self).__init__() 135 | 136 | out = out_ch // 4 137 | hide_ch = out // 2 138 | 139 | self.p1 = nn.Sequential( 140 | nn.Conv2d(in_ch, out, kernel_size=1), 141 | ) 142 | 143 | self.p2 = nn.Sequential( 144 | nn.Conv2d(in_ch, hide_ch, 1), 145 | nn.BatchNorm2d(hide_ch), 146 | nn.ReLU(inplace=True), 147 | nn.Conv2d(hide_ch, out, 5, padding=2, groups=hide_ch), 148 | ) 149 | 150 | self.p3 = nn.Sequential( 151 | nn.Conv2d(in_ch, hide_ch, 1), 152 | nn.BatchNorm2d(hide_ch), 153 | nn.ReLU(inplace=True), 154 | nn.Conv2d(hide_ch, hide_ch, 3, padding=1, groups=hide_ch), 155 | nn.BatchNorm2d(hide_ch), 156 | nn.ReLU(inplace=True), 157 | nn.Conv2d(hide_ch, out, 1) 158 | ) 159 | self.p4 = nn.Sequential( 160 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 161 | nn.Conv2d(in_ch, out, kernel_size=1), 162 | ) 163 | 164 | def forward(self, x): 165 | 
p1 = self.p1(x) 166 | p2 = self.p2(x) 167 | p3 = self.p3(x) 168 | p4 = self.p4(x) 169 | 170 | return torch.cat((p1, p2, p3, p4), dim=1) 171 | 172 | 173 | class InceptionConv(nn.Module): 174 | def __init__(self, in_ch, out_ch): 175 | super(InceptionConv, self).__init__() 176 | self.inception = nn.Sequential( 177 | Inception(in_ch, out_ch), 178 | nn.BatchNorm2d(out_ch), 179 | nn.ReLU(inplace=True), 180 | ) 181 | 182 | self.tripleConv = nn.Sequential( 183 | TripleConv(out_ch, out_ch), 184 | nn.BatchNorm2d(out_ch), 185 | nn.ReLU(inplace=True) 186 | ) 187 | 188 | def forward(self, x): 189 | inceptionRes = self.inception(x) 190 | return self.tripleConv(inceptionRes) 191 | 192 | 193 | class MC(nn.Module): 194 | def __init__(self, in_ch, out_ch): 195 | super(MC, self).__init__() 196 | self.MCConv = InceptionConv(in_ch, out_ch) 197 | 198 | def pad(self, x1, x2): 199 | diffY = x2.size()[2] - x1.size()[2] 200 | diffX = x2.size()[3] - x1.size()[3] 201 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 202 | diffY // 2, diffY - diffY // 2]) 203 | return x1 204 | 205 | def forward(self, encoderFeature, decoderFeature): 206 | encoderFeature = self.pad(encoderFeature, decoderFeature) 207 | merge = torch.cat([encoderFeature, decoderFeature], dim=1) 208 | return self.MCConv(merge) 209 | 210 | 211 | class TeacherNetwork(nn.Module): 212 | def __init__(self, out_ch, pretrainedResnet=True): 213 | super(TeacherNetwork, self).__init__() 214 | 215 | self.resNet = resnet50_backbone(pretrained=pretrainedResnet) 216 | 217 | self.bottom = InceptionConv(2048, 2048) 218 | 219 | # up 220 | self.up1 = nn.ConvTranspose2d(2048, 1024, 2, 2) 221 | self.upMC1 = MC(1024 * 2, 1024) # 14 * 14 222 | 223 | self.up2 = nn.ConvTranspose2d(1024, 512, 2, 2) 224 | self.upMC2 = MC(512 * 2, 512) # 28 * 28 225 | 226 | self.up3 = nn.ConvTranspose2d(512, 256, 2, 2) 227 | self.upMC3 = MC(256 * 2, 256) # 56 * 56 228 | 229 | self.up4 = nn.ConvTranspose2d(256, 64, 2, 2) # 112 * 112 230 | self.upMC4 = MC(64 * 2, 64) # 56 * 56 231 | 232 | self.up5 = nn.ConvTranspose2d(64, 64, 2, 2) # 224 * 224 233 | self.out = nn.Conv2d(64, out_ch, 1) 234 | 235 | def forward(self, x): 236 | # down 237 | encoderFeature0, encoderFeature1, encoderFeature2, encoderFeature3, bottom = self.resNet(x) 238 | bottom = self.bottom(bottom) 239 | 240 | # up 241 | decoderFeature1 = self.up1(bottom) # 14 * 14 242 | upMC1 = self.upMC1(decoderFeature1, encoderFeature3) 243 | 244 | decoderFeature2 = self.up2(upMC1) 245 | upMC2 = self.upMC2(decoderFeature2, encoderFeature2) # 28 * 28 246 | 247 | decoderFeature3 = self.up3(upMC2) 248 | upMC3 = self.upMC3(decoderFeature3, encoderFeature1) # 56 * 56 249 | 250 | decoderFeature4 = self.up4(upMC3) 251 | upMC4 = self.upMC4(decoderFeature4, encoderFeature0) # 56 * 56 252 | 253 | up5 = self.up5(upMC4) 254 | out = self.out(up5) 255 | 256 | return nn.Sigmoid()(out), bottom, [upMC1, upMC2, upMC3] 257 | 258 | 259 | if __name__ == '__main__': 260 | net = TeacherNetwork(3) 261 | inputs = torch.zeros((1, 3, 224, 224), dtype=torch.float32) 262 | output = net(inputs) 263 | print(output[0].size()) 264 | -------------------------------------------------------------------------------- /pic/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chencn2020/TeacherIQA/d9c5cf22fdd6cc70ed5a15cff4e048a55607ce07/pic/framework.jpg -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | matplotlib==3.2.2 2 | numpy==1.22.3 3 | Pillow==9.2.0 4 | torch==1.11.0 5 | torchvision==0.11.2+cu113 6 | --------------------------------------------------------------------------------