├── demo └── 8.jpg ├── figures ├── macs.PNG ├── framework_UHD_IQA.PNG └── UHD_Image_Preprecessing.PNG ├── utils.py ├── models └── UIQA.py ├── README.md ├── test_single_image.py ├── train_AVA.py ├── train.py └── IQADataset.py /demo/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwei925/UIQA/HEAD/demo/8.jpg -------------------------------------------------------------------------------- /figures/macs.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwei925/UIQA/HEAD/figures/macs.PNG -------------------------------------------------------------------------------- /figures/framework_UHD_IQA.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwei925/UIQA/HEAD/figures/framework_UHD_IQA.PNG -------------------------------------------------------------------------------- /figures/UHD_Image_Preprecessing.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwei925/UIQA/HEAD/figures/UHD_Image_Preprecessing.PNG -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import curve_fit 3 | from scipy import stats 4 | import torch 5 | 6 | def logistic_func(X, bayta1, bayta2, bayta3, bayta4): 7 | logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4)))) 8 | yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart) 9 | return yhat 10 | 11 | def fit_function(y_label, y_output): 12 | beta = [np.max(y_label), np.min(y_label), np.mean(y_output), 0.5] 13 | popt, _ = curve_fit(logistic_func, y_output, \ 14 | y_label, p0=beta, maxfev=100000000) 15 | y_output_logistic = logistic_func(y_output, *popt) 16 | 17 | return y_output_logistic 18 | 19 | def fit_function_regression_values(y_label, y_output): 20 | beta = [np.max(y_label), np.min(y_label), np.mean(y_output), 0.5] 21 | popt, _ = curve_fit(logistic_func, y_output, \ 22 | y_label, p0=beta, maxfev=100000000) 23 | y_output_logistic = logistic_func(y_output, *popt) 24 | 25 | return y_output_logistic, popt 26 | 27 | 28 | def performance_fit(y_label, y_output): 29 | y_output_logistic, popt = fit_function_regression_values(y_label, y_output) 30 | PLCC = stats.pearsonr(y_output_logistic, y_label)[0] 31 | SRCC = stats.spearmanr(y_output, y_label)[0] 32 | KRCC = stats.stats.kendalltau(y_output, y_label)[0] 33 | RMSE = np.sqrt(((y_output_logistic-y_label) ** 2).mean()) 34 | MAE = np.absolute((y_output_logistic-y_label)).mean() 35 | 36 | return PLCC, SRCC, KRCC, RMSE, MAE, popt 37 | 38 | EPS = 1e-2 39 | esp = 1e-8 40 | 41 | class Fidelity_Loss(torch.nn.Module): 42 | 43 | def __init__(self): 44 | super(Fidelity_Loss, self).__init__() 45 | 46 | def forward(self, p, g): 47 | g = g.view(-1, 1) 48 | p = p.view(-1, 1) 49 | loss = 1 - (torch.sqrt(p * g + esp) + torch.sqrt((1 - p) * (1 - g) + esp)) 50 | 51 | return torch.mean(loss) -------------------------------------------------------------------------------- /models/UIQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | 7 | 8 | 9 | 10 | class Identity(nn.Module): 11 | def __init__(self): 12 | super(Identity, self).__init__() 13 | 14 | 
def forward(self, x): 15 | return x 16 | 17 | class Model_SwinT(nn.Module): 18 | def __init__(self): 19 | super(Model_SwinT, self).__init__() 20 | 21 | model = models.swin_t(weights='Swin_T_Weights.DEFAULT') 22 | model.head = Identity() 23 | 24 | # spatial quality analyzer 25 | self.feature_extraction = model 26 | 27 | # quality regressor 28 | self.quality = self.quality_regression(768, 128, 1) 29 | 30 | def quality_regression(self,in_channels, middle_channels, out_channels): 31 | regression_block = nn.Sequential( 32 | nn.Linear(in_channels, middle_channels), 33 | nn.Linear(middle_channels, out_channels), 34 | ) 35 | 36 | return regression_block 37 | 38 | def forward(self, x): 39 | 40 | x = self.feature_extraction(x) 41 | x = self.quality(x) 42 | 43 | return x 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | class UIQA_Model(torch.nn.Module): 69 | def __init__(self, pretrained_path=None): 70 | 71 | super(UIQA_Model, self).__init__() 72 | # Aesthetics feature extractor 73 | swin_t_aesthetics = Model_SwinT() 74 | if pretrained_path!=None: 75 | print('load aesthetics model') 76 | swin_t_aesthetics.load_state_dict(torch.load(pretrained_path)) 77 | 78 | # Distortion feature extractor 79 | swin_t_distortion = Model_SwinT() 80 | if pretrained_path!=None: 81 | print('load distortion model') 82 | swin_t_distortion.load_state_dict(torch.load(pretrained_path)) 83 | 84 | # Salient image feature extractor 85 | swin_t_salient = Model_SwinT() 86 | if pretrained_path!=None: 87 | print('load saliency model') 88 | swin_t_salient.load_state_dict(torch.load(pretrained_path)) 89 | 90 | self.aesthetics_feature_extraction = swin_t_aesthetics.feature_extraction 91 | self.distortion_feature_extraction = swin_t_distortion.feature_extraction 92 | self.saliency_feature_extraction = swin_t_salient.feature_extraction 93 | self.quality = self.quality_regression(768+768+768, 128, 1) 94 | 95 | def quality_regression(self,in_channels, middle_channels, out_channels): 96 | regression_block = nn.Sequential( 97 | nn.Linear(in_channels, middle_channels), 98 | nn.Linear(middle_channels, out_channels), 99 | ) 100 | 101 | return regression_block 102 | 103 | def forward(self, x_aesthetics, x_distortion, x_saliency): 104 | 105 | x_aesthetics = self.aesthetics_feature_extraction(x_aesthetics) 106 | x_distortion = self.distortion_feature_extraction(x_distortion) 107 | x_saliency = self.saliency_feature_extraction(x_saliency) 108 | # fuse the features 109 | x = torch.cat((x_aesthetics, x_distortion, x_saliency), dim = 1) 110 | 111 | x = self.quality(x) 112 | 113 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UIQA 2 | 3 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=sunwei925/UIQA) [![](https://img.shields.io/github/stars/sunwei925/UIQA)](https://github.com/sunwei925/UIQA) 4 | [![Pytorch](https://img.shields.io/badge/PyTorch-1.13%2B-brightgree?logo=PyTorch)](https://pytorch.org/) 5 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/sunwei925/UIQA) 6 | [![arXiv](https://img.shields.io/badge/build-paper-red?logo=arXiv&label=arXiv)](https://arxiv.org/abs/2409.00749) 7 | 8 | 🏆 🥇 **Winner solution for [ECCV AIM 2024 UHD-IQA Challenge: Pushing the Boundaries of Blind Photo Quality Assessment](https://codalab.lisn.upsaclay.fr/competitions/19335) at the [AIM 
2024](https://cvlai.net/aim/2024/) workshop @ ECCV 2024** 9 | 10 | Official Code for **[Assessing UHD Image Quality from Aesthetics, Distortions, and Saliency](https://arxiv.org/abs/2409.00749)** 11 | 12 | ## Introduction 13 | > **UHD images**, typically with resolutions equal to or higher than 4K, pose a significant challenge for efficient image quality assessment (IQA) algorithms, since adopting full-resolution images as inputs leads to overwhelming computational complexity, while commonly used pre-processing methods such as resizing or cropping may cause a substantial loss of detail. To address this problem, we design a multi-branch deep neural network (DNN) to assess the quality of UHD images from three perspectives: **global aesthetic characteristics, local technical distortions, and salient content perception**. Specifically, *aesthetic features are extracted from low-resolution images downsampled from the UHD ones*, which lose high-frequency texture information but still preserve the global aesthetic characteristics. *Technical distortions are measured using a fragment image composed of mini-patches cropped from UHD images based on the grid mini-patch sampling strategy*. *The salient content of UHD images is detected and cropped to extract quality-aware features from the salient regions*. We adopt Swin Transformer Tiny as the backbone network of each branch to extract features from these three perspectives. The extracted features are concatenated and regressed into quality scores by a two-layer multi-layer perceptron (MLP). We employ the mean square error (MSE) loss to optimize prediction accuracy and the fidelity loss to optimize prediction monotonicity. Experimental results show that the proposed model achieves the best performance on the UHD-IQA dataset while maintaining the lowest computational complexity, demonstrating its effectiveness and efficiency. Moreover, the proposed model won **first prize in the ECCV AIM 2024 UHD-IQA Challenge**. 14 | 15 | ## Image Pre-processing 16 | ![Image Pre-processing Figure](./figures/UHD_Image_Preprecessing.PNG) 17 | 18 | > The different image pre-processing methods for UHD images. (a) is the proposed method, which utilizes the resized image, the fragment image, and the salient patch to extract aesthetic, distortion, and salient-content features. (b) samples all non-overlapping image patches for feature extraction. (c) selects three representative patches with the highest texture complexity for feature extraction. 19 | 20 | ## Model 21 | ![Model Figure](./figures/framework_UHD_IQA.PNG) 22 | 23 | > The diagram of the proposed model. It consists of three modules: the image pre-processing module, the feature extraction module, and the quality regression module. We assess the quality of UHD images from three perspectives: global aesthetic characteristics, local technical distortions, and salient content perception, which are evaluated by the aesthetic assessment branch, the distortion measurement branch, and the salient content perception branch, respectively.
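
For a quick, self-contained view of how a UHD image is turned into the three inputs described above and scored, here is a minimal sketch (it mirrors the pre-processing in `test_single_image.py` and `IQADataset.py`; the 512/480/15 settings follow the defaults used in the commands below, and trained weights still need to be loaded for meaningful predictions):

```python
import torch
from PIL import Image
from torchvision import transforms

import models.UIQA as UIQA
from IQADataset import get_spatial_fragments

norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# resized image for the aesthetic branch, center patch for the saliency branch
to_aesthetics = transforms.Compose([transforms.Resize(512), transforms.CenterCrop(480),
                                    transforms.ToTensor(), norm])
to_saliency = transforms.Compose([transforms.CenterCrop(480), transforms.ToTensor(), norm])

img = Image.open('demo/8.jpg').convert('RGB')
x_aes = to_aesthetics(img).unsqueeze(0)
x_sal = to_saliency(img).unsqueeze(0)

# fragment image for the distortion branch: a 15x15 grid of 32x32 mini-patches -> 480x480
x_dis = get_spatial_fragments(transforms.ToTensor()(img).unsqueeze(1),
                              fragments_h=15, fragments_w=15).squeeze(1)
x_dis = norm(x_dis).unsqueeze(0)

model = UIQA.UIQA_Model().eval()   # load trained weights (see Test UIQA below) for real use
with torch.no_grad():
    score = model(x_aes, x_dis, x_sal).item()
```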
24 | 25 | 27 | 28 | ## Performance 29 | #### Compared with state-of-the-art IQA methods 30 | - Performance on the validation set of the UHD-IQA dataset 31 | 32 | | Methods | SRCC | PLCC | KRCC | RMSE | MAE | 33 | | :---: | :---:| :---:|:---: |:---: |:---: | 34 | |HyperIQA|0.524|0.182| 0.359|0.087| 0.055| 35 | |Effnet-2C-MLSP|0.615| 0.627|0.445|0.060|0.050| 36 | |CONTRIQUE|0.716| 0.712|0.521|0.049|0.038| 37 | |ARNIQA|0.718|0.717|0.523| 0.050|0.039| 38 | |CLIP-IQA+|0.743|0.732| 0.546| 0.108|0.087| 39 | |QualiCLIP|0.757|0.752|0.557|0.079|0.064| 40 | |**UIQA**|**0.817**| **0.823**| **0.625**|**0.040**| **0.032**| 41 | 42 | - Performance on the test set of the UHD-IQA dataset 43 | 44 | | Methods | SRCC | PLCC | KRCC | RMSE | MAE | 45 | | :---: | :---:| :---:|:---: |:---: |:---: | 46 | |HyperIQA|0.553| 0.103| 0.389|0.118|0.070 | 47 | |Effnet-2C-MLSP|0.675|0.641 | 0.491|0.074|0.059| 48 | |CONTRIQUE|0.732| 0.678|0.532| 0.073|0.052| 49 | |ARNIQA|0.739|0.694|0.544| 0.052|0.739| 50 | |CLIP-IQA+|0.747| 0.709| 0.551| 0.111| 0.089| 51 | |QualiCLIP|0.770|0.725|0.570|0.083|0.066| 52 | |**UIQA**|**0.846**| **0.798**|**0.657**|**0.061**| **0.042**| 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | #### Performance on ECCV AIM 2024 UHD-IQA Challenge 64 | | Team | SRCC | PLCC | KRCC | RMSE | MAE | 65 | | :---: | :---:| :---:|:---: |:---: |:---: | 66 | | **SJTU MMLab (ours)** | **0.846** | 0.798 | **0.657** | **0.061** | **0.042** | 67 | | CIPLAB | 0.835 | **0.800** | 0.642 | 0.064 | 0.044 | 68 | | ZX AIE Vector | 0.795 | 0.768 | 0.605 | 0.062 | 0.044 | 69 | | I2Group | 0.788 | 0.756 | 0.598 | 0.066 | 0.046 | 70 | | Dominator | 0.731 | 0.712 | 0.539 | 0.072 | 0.052 | 71 | |ICL|0.517| 0.521|0.361| 0.136| 0.115| 72 | 73 | - for more results on the ECCV AIM UHD IQA challenge, please refer to the [challenge report](https://arxiv.org/abs/2409.16271). 74 | 75 | ## Usage 76 | ### Environments 77 | - Requirements: 78 | ``` 79 | torch(>=1.13), torchvision, pandas, ptflops, numpy, Pillow 80 | ``` 81 | - Create a new environment 82 | ``` 83 | conda create -n UIQA python=3.8 84 | conda activate UIQA 85 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia # this command install pytorch version of 2.40, you can install pytorch >=1.13 86 | pip install pandas ptflops numpy 87 | ``` 88 | 89 | ### Dataset 90 | Download the [UHD-IQA dataset](https://database.mmsp-kn.de/uhd-iqa-benchmark-database.html). 91 | 92 | ### Train UIQA 93 | 94 | Download the [pre-trained model](https://www.dropbox.com/scl/fi/dk6co7hqquxpuq1nh04gf/Model_SwinT_AVA_epoch_10.pth?rlkey=tp13fdewe7hdosc3dja6al2dx&st=rg7tsy3t&dl=0) on AVA. 
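
The command below optimizes the objective described in the paper: an MSE term for prediction accuracy plus a pairwise fidelity term for prediction monotonicity. For reference, a minimal sketch of this objective (it mirrors `train.py` and `utils.Fidelity_Loss`; the default weights match `--lr_weight_pair 1` and `--lr_weight_L2 0.1` from the command):

```python
import torch
import torch.nn.functional as F

def fidelity_loss(p_pred, p_gt, eps=1e-8):
    # utils.Fidelity_Loss: 1 - (sqrt(p*g) + sqrt((1-p)*(1-g)))
    return torch.mean(1 - (torch.sqrt(p_pred * p_gt + eps)
                           + torch.sqrt((1 - p_pred) * (1 - p_gt) + eps)))

def uiqa_loss(q_a, q_b, mos_a, mos_b, w_pair=1.0, w_mse=0.1):
    # pairwise preference probability from the score difference via the Gaussian CDF
    c = torch.sqrt(torch.tensor(2.0, device=q_a.device))
    p_pred = 0.5 * (1 + torch.erf((q_a - q_b) / c))
    p_gt = 0.5 * (1 + torch.erf((mos_a - mos_b) / c))
    return (w_pair * fidelity_loss(p_pred, p_gt.detach())
            + w_mse * (F.mse_loss(q_a, mos_a) + F.mse_loss(q_b, mos_b)))
```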
95 | 96 | ``` 97 | CUDA_VISIBLE_DEVICES=0,1 python -u train.py \ 98 | --num_epochs 100 \ 99 | --batch_size 12 \ 100 | --n_fragment 15 \ 101 | --resize 512 \ 102 | --crop_size 480 \ 103 | --salient_patch_dimension 480 \ 104 | --lr 0.00001 \ 105 | --lr_weight_L2 0.1 \ 106 | --lr_weight_pair 1 \ 107 | --decay_ratio 0.9 \ 108 | --decay_interval 10 \ 109 | --random_seed 1000 \ 110 | --snapshot ckpts \ 111 | --pretrained_path ckpts/Model_SwinT_AVA_size_480_epoch_10.pth \ 112 | --database_dir UHDIQA/challenge/training/ \ 113 | --model UIQA \ 114 | --multi_gpu True \ 115 | --print_samples 20 \ 116 | --database UHD_IQA \ 117 | >> logfiles/train_UIQA.log 118 | ``` 119 | 120 | ### Test UIQA 121 | Put your trained model in the ckpts folder, or download the provided trained model ([model weights](https://www.dropbox.com/scl/fi/mgvvt902zehhmo6drnxve/UIQA.pth?rlkey=413edq08c8qnxbrlgclnq0va6&st=yzcskkus&dl=0), [quality alignment profile file](https://www.dropbox.com/scl/fi/1st2jjga6ssirvsex5oo6/UIQA.npy?rlkey=6mbf2utiz1t3dlm5nl635cvmz&st=n0tvbqv9&dl=0)) on the UHD-IQA dataset into the ckpts folder. 122 | 123 | ``` 124 | CUDA_VISIBLE_DEVICES=0 python -u test_single_image.py \ 125 | --model_path ckpts/ \ 126 | --trained_model_file UIQA.pth \ 127 | --popt_file UIQA.npy \ 128 | --image_path demo/8.jpg \ 129 | --resize 512 \ 130 | --crop_size 480 \ 131 | --n_fragment 15 \ 132 | --salient_patch_dimension 480 \ 133 | --model UIQA 134 | ``` 135 | 136 | ## Citation 137 | **If you find this code is useful for your research, please cite**: 138 | 139 | ```latex 140 | @inproceedings{sun2025assessing, 141 | title={Assessing uhd image quality from aesthetics, distortions, and saliency}, 142 | author={Sun, Wei and Zhang, Weixia and Cao, Yuqin and Cao, Linhan and Jia, Jun and Chen, Zijian and Zhang, Zicheng and Min, Xiongkuo and Zhai, Guangtao}, 143 | booktitle={European Conference on Computer Vision}, 144 | pages={109--126}, 145 | year={2025} 146 | } 147 | ``` 148 | 149 | ## Acknowledgement 150 | 151 | 1. 152 | 2. 153 | 3. 
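
A note on the quality alignment profile used in the test command above: `UIQA.npy` stores the four parameters of the logistic mapping fitted by `utils.fit_function_regression_values` during training. `test_single_image.py` prints the raw model output; if you want that output mapped onto the MOS scale, you can apply the stored parameters yourself, e.g. (a small sketch; the `ckpts/` path and the raw score value are placeholders):

```python
import numpy as np
from utils import logistic_func   # 4-parameter logistic defined in utils.py

popt = np.load('ckpts/UIQA.npy')  # quality alignment profile
raw_score = 0.5                   # raw prediction from the model (placeholder)
aligned_score = logistic_func(raw_score, *popt)
print(aligned_score)
```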
154 | -------------------------------------------------------------------------------- /test_single_image.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import numpy as np 3 | import torch 4 | from torchvision import transforms 5 | 6 | import models.UIQA as UIQA 7 | from PIL import Image 8 | 9 | from scipy.optimize import curve_fit 10 | from scipy import stats 11 | import pandas as pd 12 | import random 13 | 14 | def logistic_func(X, bayta1, bayta2, bayta3, bayta4): 15 | logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4)))) 16 | yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart) 17 | return yhat 18 | 19 | def fit_function(y_label, y_output): 20 | beta = [np.max(y_label), np.min(y_label), np.mean(y_output), 0.5] 21 | popt, _ = curve_fit(logistic_func, y_output, \ 22 | y_label, p0=beta, maxfev=100000000) 23 | y_output_logistic = logistic_func(y_output, *popt) 24 | 25 | return y_output_logistic 26 | 27 | 28 | def performance_fit(y_label, y_output): 29 | y_output_logistic = fit_function(y_label, y_output) 30 | PLCC = stats.pearsonr(y_output_logistic, y_label)[0] 31 | SRCC = stats.spearmanr(y_output, y_label)[0] 32 | KRCC = stats.stats.kendalltau(y_output, y_label)[0] 33 | RMSE = np.sqrt(((y_output_logistic-y_label) ** 2).mean()) 34 | 35 | return PLCC, SRCC, KRCC, RMSE 36 | 37 | def get_spatial_fragments( 38 | video, 39 | fragments_h=7, 40 | fragments_w=7, 41 | fsize_h=32, 42 | fsize_w=32, 43 | aligned=32, 44 | nfrags=1, 45 | random=False, 46 | random_upsample=False, 47 | fallback_type="upsample", 48 | **kwargs, 49 | ): 50 | size_h = fragments_h * fsize_h 51 | size_w = fragments_w * fsize_w 52 | ## video: [C,T,H,W] 53 | ## situation for images 54 | if video.shape[1] == 1: 55 | aligned = 1 56 | 57 | dur_t, res_h, res_w = video.shape[-3:] 58 | ratio = min(res_h / size_h, res_w / size_w) 59 | if fallback_type == "upsample" and ratio < 1: 60 | 61 | ovideo = video 62 | video = torch.nn.functional.interpolate( 63 | video / 255.0, scale_factor=1 / ratio, mode="bilinear" 64 | ) 65 | video = (video * 255.0).type_as(ovideo) 66 | 67 | if random_upsample: 68 | 69 | randratio = random.random() * 0.5 + 1 70 | video = torch.nn.functional.interpolate( 71 | video / 255.0, scale_factor=randratio, mode="bilinear" 72 | ) 73 | video = (video * 255.0).type_as(ovideo) 74 | 75 | 76 | 77 | assert dur_t % aligned == 0, "Please provide match vclip and align index" 78 | size = size_h, size_w 79 | 80 | ## make sure that sampling will not run out of the picture 81 | hgrids = torch.LongTensor( 82 | [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] 83 | ) 84 | wgrids = torch.LongTensor( 85 | [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] 86 | ) 87 | hlength, wlength = res_h // fragments_h, res_w // fragments_w 88 | 89 | if random: 90 | print("This part is deprecated. 
Please remind that.") 91 | if res_h > fsize_h: 92 | rnd_h = torch.randint( 93 | res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 94 | ) 95 | else: 96 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 97 | if res_w > fsize_w: 98 | rnd_w = torch.randint( 99 | res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 100 | ) 101 | else: 102 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 103 | else: 104 | if hlength > fsize_h: 105 | rnd_h = torch.randint( 106 | hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 107 | ) 108 | else: 109 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 110 | if wlength > fsize_w: 111 | rnd_w = torch.randint( 112 | wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 113 | ) 114 | else: 115 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 116 | 117 | target_video = torch.zeros(video.shape[:-2] + size).to(video.device) 118 | # target_videos = [] 119 | 120 | for i, hs in enumerate(hgrids): 121 | for j, ws in enumerate(wgrids): 122 | for t in range(dur_t // aligned): 123 | t_s, t_e = t * aligned, (t + 1) * aligned 124 | h_s, h_e = i * fsize_h, (i + 1) * fsize_h 125 | w_s, w_e = j * fsize_w, (j + 1) * fsize_w 126 | if random: 127 | h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h 128 | w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w 129 | else: 130 | h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h 131 | w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w 132 | target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ 133 | :, t_s:t_e, h_so:h_eo, w_so:w_eo 134 | ] 135 | # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) 136 | # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) 137 | # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments 138 | return target_video 139 | 140 | def parse_args(): 141 | """Parse input arguments. 
""" 142 | parser = argparse.ArgumentParser(description="Authentic Image Quality Assessment") 143 | parser.add_argument('--model_path', help='Path of model snapshot.', type=str) 144 | parser.add_argument('--trained_model_file', type=str) 145 | parser.add_argument('--popt_file', type=str) 146 | parser.add_argument('--model', type=str) 147 | parser.add_argument('--n_fragment', type=int, default=12) 148 | parser.add_argument('--image_path', type=str) 149 | parser.add_argument('--resize', type=int) 150 | parser.add_argument('--salient_patch_dimension', type=int, default=384) 151 | parser.add_argument('--crop_size', help='crop_size.',type=int) 152 | parser.add_argument('--gpu_ids', type=list, default=None) 153 | 154 | 155 | args = parser.parse_args() 156 | 157 | return args 158 | 159 | 160 | if __name__ == '__main__': 161 | 162 | random_seed = 2 163 | torch.manual_seed(random_seed) # 164 | torch.backends.cudnn.deterministic = True 165 | torch.backends.cudnn.benchmark = False 166 | np.random.seed(random_seed) 167 | random.seed(random_seed) 168 | 169 | args = parse_args() 170 | 171 | model_path = args.model_path 172 | popt_file = args.popt_file 173 | 174 | 175 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 176 | 177 | # load the network 178 | model_file = args.trained_model_file 179 | 180 | if args.model == 'UIQA': 181 | model = UIQA.UIQA_Model() 182 | 183 | # model = torch.nn.DataParallel(model) 184 | model = model.to(device) 185 | model.load_state_dict(torch.load(os.path.join(model_path, model_file))) 186 | popt = np.load(os.path.join(model_path, popt_file)) 187 | model.eval() 188 | 189 | 190 | transform_asethetics = transforms.Compose([transforms.Resize(args.resize), 191 | transforms.CenterCrop(args.crop_size), 192 | transforms.ToTensor(), 193 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 194 | 195 | 196 | 197 | 198 | transform_distortion = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 199 | transform_distortion_preprocessing = transforms.Compose([transforms.ToTensor()]) 200 | 201 | 202 | transform_saliency = transforms.Compose([ 203 | transforms.CenterCrop(args.salient_patch_dimension), 204 | transforms.ToTensor(), 205 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 206 | ]) 207 | 208 | 209 | test_image = Image.open(os.path.join(args.image_path)) 210 | 211 | test_image = test_image.convert('RGB') 212 | test_image_aesthetics = transform_asethetics(test_image) 213 | test_image_saliency = transform_saliency(test_image) 214 | 215 | test_image_distortion = transform_distortion_preprocessing(test_image) 216 | test_image_distortion = test_image_distortion.unsqueeze(1) 217 | test_image_distortion = get_spatial_fragments( 218 | test_image_distortion, 219 | fragments_h=args.n_fragment, 220 | fragments_w=args.n_fragment, 221 | fsize_h=32, 222 | fsize_w=32, 223 | aligned=32, 224 | nfrags=1, 225 | random=False, 226 | random_upsample=False, 227 | fallback_type="upsample" 228 | ) 229 | test_image_distortion = test_image_distortion.squeeze(1) 230 | test_image_distortion = transform_distortion(test_image_distortion) 231 | 232 | test_image_aesthetics = test_image_aesthetics.unsqueeze(0) 233 | test_image_distortion = test_image_distortion.unsqueeze(0) 234 | test_image_saliency = test_image_saliency.unsqueeze(0) 235 | 236 | with torch.no_grad(): 237 | test_image_aesthetics = test_image_aesthetics.to(device) 238 | test_image_saliency = test_image_saliency.to(device) 239 | 
test_image_distortion = test_image_distortion.to(device) 240 | outputs = model(test_image_aesthetics, test_image_distortion, test_image_saliency) 241 | score = outputs.item() 242 | print('The quality of the test image is {:.4f}'.format(score)) -------------------------------------------------------------------------------- /train_AVA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | 10 | import torch.backends.cudnn as cudnn 11 | 12 | import IQADataset 13 | import models.UIQA as UIQA 14 | from utils import performance_fit 15 | from utils import Fidelity_Loss 16 | 17 | import random 18 | 19 | 20 | 21 | def parse_args(): 22 | """Parse input arguments. """ 23 | parser = argparse.ArgumentParser(description="Image Aesthetics Assessment") 24 | parser.add_argument('--gpu', help="GPU device id to use [0]", default=0, type=int) 25 | parser.add_argument('--num_epochs', help='Maximum number of training epochs.', default=30, type=int) 26 | parser.add_argument('--batch_size', help='Batch size.', default=40, type=int) 27 | parser.add_argument('--resize', help='resize.', type=int) 28 | parser.add_argument('--crop_size', help='crop_size.',type=int) 29 | parser.add_argument('--lr', type=float, default=0.00001) 30 | parser.add_argument('--lr_weight_L2', type=float, default=1) 31 | parser.add_argument('--lr_weight_pair', type=float, default=1) 32 | parser.add_argument('--decay_ratio', type=float, default=0.9) 33 | parser.add_argument('--decay_interval', type=float, default=10) 34 | parser.add_argument('--random_seed', type=int, default=0) 35 | parser.add_argument('--snapshot', default='', type=str) 36 | parser.add_argument('--results_path', type=str) 37 | parser.add_argument('--database_dir', type=str) 38 | parser.add_argument('--model', type=str) 39 | parser.add_argument('--multi_gpu', type=bool, default=False) 40 | parser.add_argument('--print_samples', type=int, default = 50) 41 | parser.add_argument('--database', type=str) 42 | 43 | 44 | args = parser.parse_args() 45 | 46 | return args 47 | 48 | 49 | if __name__ == '__main__': 50 | args = parse_args() 51 | 52 | 53 | 54 | 55 | 56 | 57 | torch.manual_seed(args.random_seed) # 58 | torch.backends.cudnn.deterministic = True 59 | torch.backends.cudnn.benchmark = False 60 | np.random.seed(args.random_seed) 61 | random.seed(args.random_seed) 62 | 63 | gpu = args.gpu 64 | cudnn.enabled = True 65 | num_epochs = args.num_epochs 66 | batch_size = args.batch_size 67 | lr = args.lr 68 | decay_interval = args.decay_interval 69 | decay_ratio = args.decay_ratio 70 | snapshot = args.snapshot 71 | database = args.database 72 | print_samples = args.print_samples 73 | results_path = args.results_path 74 | database_dir = args.database_dir 75 | resize = args.resize 76 | crop_size = args.crop_size 77 | 78 | 79 | 80 | 81 | 82 | seed = args.random_seed 83 | if not os.path.exists(snapshot): 84 | os.makedirs(snapshot) 85 | 86 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 87 | 88 | if database == 'AVA': 89 | filename_list = 'csvfiles/ground_truth_dataset.csv' 90 | 91 | # load the network 92 | if args.model == 'Model_SwinT': 93 | model = UIQA.Model_SwinT() 94 | 95 | 96 | transforms_train = transforms.Compose([transforms.Resize(resize), 97 | transforms.RandomCrop(crop_size), 98 | transforms.ToTensor(), 99 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 
std=[0.229, 0.224, 0.225])]) 100 | transforms_test = transforms.Compose([transforms.Resize(resize), 101 | transforms.CenterCrop(crop_size), 102 | transforms.ToTensor(), 103 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 104 | 105 | 106 | 107 | 108 | if database == 'AVA': 109 | train_dataset = IQADataset.AVA_dataloader_pair(database_dir, 110 | filename_list, 111 | transforms_train, 112 | database+'_train', seed) 113 | test_dataset = IQADataset.AVA_dataloader(database_dir, 114 | filename_list, 115 | transforms_test, 116 | database+'_test', seed) 117 | 118 | 119 | 120 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 121 | batch_size=batch_size, 122 | shuffle=True, 123 | num_workers=8) 124 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 125 | batch_size=1, 126 | shuffle=False, 127 | num_workers=8) 128 | 129 | 130 | if args.multi_gpu: 131 | model = torch.nn.DataParallel(model) 132 | model = model.to(device) 133 | else: 134 | model = model.to(device) 135 | 136 | criterion = Fidelity_Loss() 137 | criterion2 = nn.MSELoss().to(device) 138 | 139 | 140 | param_num = 0 141 | for param in model.parameters(): 142 | param_num += int(np.prod(param.shape)) 143 | print('Trainable params: %.2f million' % (param_num / 1e6)) 144 | 145 | 146 | optimizer = torch.optim.Adam(model.parameters(), 147 | lr=lr, 148 | weight_decay=0.0000001) 149 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 150 | step_size=decay_interval, 151 | gamma=decay_ratio) 152 | 153 | 154 | print("Ready to train network") 155 | 156 | best_test_criterion = -1 # SROCC min 157 | best = np.zeros(5) 158 | 159 | n_train = len(train_dataset) 160 | n_test = len(test_dataset) 161 | 162 | 163 | for epoch in range(num_epochs): 164 | # train 165 | model.train() 166 | 167 | batch_losses = [] 168 | batch_losses_each_disp = [] 169 | session_start_time = time.time() 170 | for i, (image, mos, image_second, mos_second) in enumerate(train_loader): 171 | image = image.to(device) 172 | mos = mos[:,np.newaxis] 173 | mos = mos.to(device) 174 | 175 | image_second = image_second.to(device) 176 | mos_second = mos_second[:,np.newaxis] 177 | mos_second = mos_second.to(device) 178 | 179 | mos_output = model(image) 180 | mos_output_second = model(image_second) 181 | mos_output_diff = mos_output- mos_output_second 182 | constant =torch.sqrt(torch.Tensor([2])).to(device) 183 | p_output = 0.5 * (1 + torch.erf(mos_output_diff / constant)) 184 | mos_diff = mos - mos_second 185 | p = 0.5 * (1 + torch.erf(mos_diff / constant)) 186 | optimizer.zero_grad() 187 | loss = args.lr_weight_pair*criterion(p_output, p.detach()) + \ 188 | args.lr_weight_L2*criterion2(mos_output, mos) + \ 189 | args.lr_weight_L2*criterion2(mos_output_second, mos_second) 190 | 191 | batch_losses.append(loss.item()) 192 | batch_losses_each_disp.append(loss.item()) 193 | 194 | loss.backward() 195 | optimizer.step() 196 | 197 | if (i+1) % print_samples == 0: 198 | session_end_time = time.time() 199 | avg_loss_epoch = sum(batch_losses_each_disp) / print_samples 200 | print('Epoch: {:d}/{:d} | Step: {:d}/{:d} | Training loss: {:.4f}'.format(epoch + 1, 201 | num_epochs, 202 | i + 1, 203 | len(train_dataset)//batch_size, 204 | avg_loss_epoch)) 205 | batch_losses_each_disp = [] 206 | print('CostTime: {:.4f}'.format(session_end_time - session_start_time)) 207 | session_start_time = time.time() 208 | 209 | avg_loss = sum(batch_losses) / (len(train_dataset) // batch_size) 210 | print('Epoch {:d} averaged training loss: {:.4f}'.format(epoch + 
1, avg_loss)) 211 | 212 | scheduler.step() 213 | lr_current = scheduler.get_last_lr() 214 | print('The current learning rate is {:.06f}'.format(lr_current[0])) 215 | 216 | # Test 217 | model.eval() 218 | y_output = np.zeros(n_test) 219 | y_test = np.zeros(n_test) 220 | 221 | with torch.no_grad(): 222 | for i, (image, mos) in enumerate(test_loader): 223 | image = image.to(device) 224 | y_test[i] = mos.item() 225 | mos = mos.to(device) 226 | outputs = model(image) 227 | y_output[i] = outputs.item() 228 | 229 | test_PLCC, test_SRCC, test_KRCC, test_RMSE, test_MAE, popt = performance_fit(y_test, y_output) 230 | print("Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}, MAE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE, test_MAE)) 231 | 232 | if test_SRCC > best_test_criterion: 233 | if epoch > 0: 234 | if os.path.exists(old_save_name): 235 | os.remove(old_save_name) 236 | if os.path.exists(old_save_name_popt): 237 | os.remove(old_save_name_popt) 238 | 239 | save_model_name = os.path.join(args.snapshot, 240 | args.model + '_' + args.database + '_' + '_NR_' + 'epoch_%d_SRCC_%f.pth' % (epoch + 1, test_SRCC)) 241 | save_popt_name = os.path.join(args.snapshot, 242 | args.model + '_' + args.database + '_' + '_NR_' + 'epoch_%d_SRCC_%f.npy' % (epoch + 1, test_SRCC)) 243 | print("Update best model using best_val_criterion ") 244 | torch.save(model.module.state_dict(), save_model_name) 245 | np.save(save_popt_name, popt) 246 | old_save_name = save_model_name 247 | old_save_name_popt = save_popt_name 248 | best[0:5] = [test_SRCC, test_KRCC, test_PLCC, test_RMSE, test_MAE] 249 | best_popt = popt 250 | best_test_criterion = test_SRCC # update best val SROCC 251 | 252 | print("The best Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}, MAE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE, test_MAE)) 253 | 254 | print(database) 255 | print("The best Val results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}, MAE={:.4f}".format(best[0], best[1], best[2], best[3], best[4])) 256 | print('*************************************************************************************************************************') 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | 10 | import torch.backends.cudnn as cudnn 11 | 12 | import IQADataset 13 | import models.UIQA as UIQA 14 | from utils import performance_fit 15 | from utils import Fidelity_Loss 16 | 17 | import random 18 | 19 | 20 | 21 | def parse_args(): 22 | """Parse input arguments. 
""" 23 | parser = argparse.ArgumentParser(description="UHD Image Quality Assessment") 24 | parser.add_argument('--gpu', help="GPU device id to use [0]", default=0, type=int) 25 | parser.add_argument('--n_fragment', type=int, default=12) 26 | parser.add_argument('--fragments_h', type=int, default=2) 27 | parser.add_argument('--fragments_w', type=int, default=4) 28 | parser.add_argument('--num_epochs', help='Maximum number of training epochs.', default=30, type=int) 29 | parser.add_argument('--batch_size', help='Batch size.', default=40, type=int) 30 | parser.add_argument('--resize', type=int) 31 | parser.add_argument('--salient_patch_dimension', type=int, default=480) 32 | parser.add_argument('--crop_size', type=int) 33 | parser.add_argument('--lr', type=float, default=0.00001) 34 | parser.add_argument('--lr_weight_L2', type=float, default=1) 35 | parser.add_argument('--lr_weight_pair', type=float, default=1) 36 | parser.add_argument('--decay_ratio', type=float, default=0.9) 37 | parser.add_argument('--decay_interval', type=float, default=10) 38 | parser.add_argument('--random_seed', type=int, default=0) 39 | parser.add_argument('--snapshot', default='', type=str) 40 | parser.add_argument('--pretrained_path', type=str, default=None) 41 | parser.add_argument('--results_path', type=str) 42 | parser.add_argument('--database_dir', type=str) 43 | parser.add_argument('--model', default='UIQA', type=str) 44 | parser.add_argument('--multi_gpu', type=bool, default=False) 45 | parser.add_argument('--print_samples', type=int, default = 50) 46 | parser.add_argument('--database', default='UHD_IQA', type=str) 47 | 48 | 49 | args = parser.parse_args() 50 | 51 | return args 52 | 53 | 54 | if __name__ == '__main__': 55 | args = parse_args() 56 | 57 | 58 | 59 | 60 | 61 | 62 | torch.manual_seed(args.random_seed) # 63 | torch.backends.cudnn.deterministic = True 64 | torch.backends.cudnn.benchmark = False 65 | np.random.seed(args.random_seed) 66 | random.seed(args.random_seed) 67 | 68 | gpu = args.gpu 69 | cudnn.enabled = True 70 | num_epochs = args.num_epochs 71 | batch_size = args.batch_size 72 | lr = args.lr 73 | decay_interval = args.decay_interval 74 | decay_ratio = args.decay_ratio 75 | snapshot = args.snapshot 76 | database = args.database 77 | print_samples = args.print_samples 78 | results_path = args.results_path 79 | database_dir = args.database_dir 80 | resize = args.resize 81 | crop_size = args.crop_size 82 | n_fragment = args.n_fragment 83 | fragments_h = args.fragments_h 84 | fragments_w = args.fragments_w 85 | salient_patch_dimension = args.salient_patch_dimension 86 | 87 | seed = args.random_seed 88 | if not os.path.exists(snapshot): 89 | os.makedirs(snapshot) 90 | 91 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 92 | 93 | if database == 'UHD_IQA': 94 | filename_list = 'csvfiles/uhd-iqa-training-metadata.csv' 95 | 96 | print(filename_list) 97 | 98 | # load the network 99 | if args.model == 'UIQA': 100 | model = UIQA.UIQA_Model(pretrained_path = args.pretrained_path) 101 | 102 | 103 | transforms_train = transforms.Compose([transforms.Resize(resize), 104 | transforms.RandomCrop(crop_size), 105 | transforms.ToTensor(), 106 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 107 | transforms_test = transforms.Compose([transforms.Resize(resize), 108 | transforms.CenterCrop(crop_size), 109 | transforms.ToTensor(), 110 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 111 | 112 | 113 | 114 | 115 | 116 | train_dataset = 
IQADataset.UIQA_dataloader_pair(database_dir, 117 | filename_list, 118 | transforms_train, 119 | database+'_train', 120 | n_fragment, 121 | salient_patch_dimension, 122 | seed) 123 | test_dataset = IQADataset.UIQA_dataloader(database_dir, 124 | filename_list, 125 | transforms_test, 126 | database+'_test', 127 | n_fragment, 128 | salient_patch_dimension, 129 | seed) 130 | 131 | 132 | 133 | 134 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 135 | batch_size=batch_size, 136 | shuffle=True, 137 | num_workers=8) 138 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 139 | batch_size=1, 140 | shuffle=False, 141 | num_workers=8) 142 | 143 | 144 | if args.multi_gpu: 145 | model = torch.nn.DataParallel(model) 146 | model = model.to(device) 147 | else: 148 | model = model.to(device) 149 | 150 | 151 | criterion = Fidelity_Loss() 152 | criterion2 = nn.MSELoss().to(device) 153 | 154 | 155 | param_num = 0 156 | for param in model.parameters(): 157 | param_num += int(np.prod(param.shape)) 158 | print('Trainable params: %.2f million' % (param_num / 1e6)) 159 | 160 | 161 | optimizer = torch.optim.Adam(model.parameters(), 162 | lr=lr, 163 | weight_decay=0.0000001) 164 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 165 | step_size=decay_interval, 166 | gamma=decay_ratio) 167 | 168 | 169 | print("Ready to train network") 170 | 171 | best_test_criterion = -1 # SROCC min 172 | best = np.zeros(5) 173 | 174 | n_train = len(train_dataset) 175 | n_test = len(test_dataset) 176 | 177 | 178 | for epoch in range(num_epochs): 179 | # train 180 | model.train() 181 | 182 | batch_losses = [] 183 | batch_losses_each_disp = [] 184 | session_start_time = time.time() 185 | for i, data_train in enumerate(train_loader): 186 | img_aesthetics = data_train['img_aesthetics'].to(device) 187 | img_distortion = data_train['img_distortion'].to(device) 188 | img_saliency = data_train['img_saliency'].to(device) 189 | mos = data_train['y_label'][:,np.newaxis] 190 | mos = mos.to(device) 191 | 192 | img_second_aesthetics = data_train['img_second_aesthetics'].to(device) 193 | img_second_distortion = data_train['img_second_distortion'].to(device) 194 | img_second_saliency = data_train['img_second_saliency'].to(device) 195 | mos_second = data_train['y_label_second'][:,np.newaxis] 196 | mos_second = mos_second.to(device) 197 | 198 | 199 | mos_output = model(img_aesthetics, img_distortion, img_saliency) 200 | mos_output_second = model(img_second_aesthetics, img_second_distortion, img_second_saliency) 201 | 202 | mos_output_diff = mos_output- mos_output_second 203 | constant =torch.sqrt(torch.Tensor([2])).to(device) 204 | p_output = 0.5 * (1 + torch.erf(mos_output_diff / constant)) 205 | mos_diff = mos - mos_second 206 | p = 0.5 * (1 + torch.erf(mos_diff / constant)) 207 | 208 | optimizer.zero_grad() 209 | loss = args.lr_weight_pair*criterion(p_output, p.detach()) + \ 210 | args.lr_weight_L2*criterion2(mos_output, mos) + \ 211 | args.lr_weight_L2*criterion2(mos_output_second, mos_second) 212 | 213 | 214 | batch_losses.append(loss.item()) 215 | batch_losses_each_disp.append(loss.item()) 216 | 217 | loss.backward() 218 | optimizer.step() 219 | 220 | if (i+1) % print_samples == 0: 221 | session_end_time = time.time() 222 | avg_loss_epoch = sum(batch_losses_each_disp) / print_samples 223 | # print('Epoch: %d/%d | Step: %d/%d | Training loss: %.4f' % (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, avg_loss_epoch)) 224 | print('Epoch: {:d}/{:d} | Step: {:d}/{:d} | Training loss: 
{:.4f}'.format(epoch + 1, 225 | num_epochs, 226 | i + 1, 227 | len(train_dataset)//batch_size, 228 | avg_loss_epoch)) 229 | batch_losses_each_disp = [] 230 | print('CostTime: {:.4f}'.format(session_end_time - session_start_time)) 231 | session_start_time = time.time() 232 | 233 | avg_loss = sum(batch_losses) / (len(train_dataset) // batch_size) 234 | # print('Epoch %d averaged training loss: %.4f' % (epoch + 1, avg_loss)) 235 | print('Epoch {:d} averaged training loss: {:.4f}'.format(epoch + 1, avg_loss)) 236 | 237 | scheduler.step() 238 | lr_current = scheduler.get_last_lr() 239 | print('The current learning rate is {:.06f}'.format(lr_current[0])) 240 | 241 | # Test 242 | model.eval() 243 | y_output = np.zeros(n_test) 244 | y_test = np.zeros(n_test) 245 | 246 | with torch.no_grad(): 247 | for i, data_test in enumerate(test_loader): 248 | 249 | img_aesthetics = data_test['img_aesthetics'].to(device) 250 | img_distortion = data_test['img_distortion'].to(device) 251 | img_saliency = data_test['img_saliency'].to(device) 252 | y_test[i] = data_test['y_label'].item() 253 | mos = mos.to(device) 254 | outputs = model(img_aesthetics, img_distortion, img_saliency) 255 | y_output[i] = outputs.item() 256 | 257 | test_PLCC, test_SRCC, test_KRCC, test_RMSE, test_MAE, popt = performance_fit(y_test, y_output) 258 | print("Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}, MAE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE, test_MAE)) 259 | 260 | if test_SRCC + test_PLCC + test_KRCC - test_RMSE - test_MAE > best_test_criterion: 261 | if epoch > 0: 262 | if os.path.exists(old_save_name): 263 | os.remove(old_save_name) 264 | if os.path.exists(old_save_name_popt): 265 | os.remove(old_save_name_popt) 266 | 267 | save_model_name = os.path.join(args.snapshot, 268 | args.model + '_' + args.database + '_' + '_NR_' + 'epoch_%d_SRCC_%f.pth' % (epoch + 1, test_SRCC)) 269 | save_popt_name = os.path.join(args.snapshot, 270 | args.model + '_' + args.database + '_' + '_NR_' + 'epoch_%d_SRCC_%f.npy' % (epoch + 1, test_SRCC)) 271 | print("Update best model using best_val_criterion ") 272 | torch.save(model.module.state_dict(), save_model_name) 273 | np.save(save_popt_name, popt) 274 | old_save_name = save_model_name 275 | old_save_name_popt = save_popt_name 276 | best[0:5] = [test_SRCC, test_KRCC, test_PLCC, test_RMSE, test_MAE] 277 | best_popt = popt 278 | best_test_criterion = test_SRCC + test_PLCC + test_KRCC - test_RMSE - test_MAE # update best val SROCC 279 | 280 | print("The best Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}, MAE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE, test_MAE)) 281 | 282 | print(database) 283 | print("The best Val results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}, MAE={:.4f}".format(best[0], best[1], best[2], best[3], best[4])) 284 | print('*************************************************************************************************************************') 285 | 286 | -------------------------------------------------------------------------------- /IQADataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import torch 6 | from torch.utils.data.dataset import Dataset 7 | 8 | from PIL import Image 9 | from PIL import ImageFile 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | from torchvision import transforms 12 | import random 13 | import cv2 14 | 15 | 16 | 17 | 18 | def get_spatial_fragments( 19 | video, 20 
| fragments_h=7, 21 | fragments_w=7, 22 | fsize_h=32, 23 | fsize_w=32, 24 | aligned=32, 25 | nfrags=1, 26 | random=False, 27 | random_upsample=False, 28 | fallback_type="upsample", 29 | **kwargs, 30 | ): 31 | size_h = fragments_h * fsize_h 32 | size_w = fragments_w * fsize_w 33 | ## video: [C,T,H,W] 34 | ## situation for images 35 | if video.shape[1] == 1: 36 | aligned = 1 37 | 38 | dur_t, res_h, res_w = video.shape[-3:] 39 | ratio = min(res_h / size_h, res_w / size_w) 40 | if fallback_type == "upsample" and ratio < 1: 41 | 42 | ovideo = video 43 | video = torch.nn.functional.interpolate( 44 | video / 255.0, scale_factor=1 / ratio, mode="bilinear" 45 | ) 46 | video = (video * 255.0).type_as(ovideo) 47 | 48 | if random_upsample: 49 | 50 | randratio = random.random() * 0.5 + 1 51 | video = torch.nn.functional.interpolate( 52 | video / 255.0, scale_factor=randratio, mode="bilinear" 53 | ) 54 | video = (video * 255.0).type_as(ovideo) 55 | 56 | 57 | 58 | assert dur_t % aligned == 0, "Please provide match vclip and align index" 59 | size = size_h, size_w 60 | 61 | ## make sure that sampling will not run out of the picture 62 | hgrids = torch.LongTensor( 63 | [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] 64 | ) 65 | wgrids = torch.LongTensor( 66 | [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] 67 | ) 68 | hlength, wlength = res_h // fragments_h, res_w // fragments_w 69 | 70 | if random: 71 | print("This part is deprecated. Please remind that.") 72 | if res_h > fsize_h: 73 | rnd_h = torch.randint( 74 | res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 75 | ) 76 | else: 77 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 78 | if res_w > fsize_w: 79 | rnd_w = torch.randint( 80 | res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 81 | ) 82 | else: 83 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 84 | else: 85 | if hlength > fsize_h: 86 | rnd_h = torch.randint( 87 | hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) 88 | ) 89 | else: 90 | rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 91 | if wlength > fsize_w: 92 | rnd_w = torch.randint( 93 | wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) 94 | ) 95 | else: 96 | rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() 97 | 98 | target_video = torch.zeros(video.shape[:-2] + size).to(video.device) 99 | # target_videos = [] 100 | 101 | for i, hs in enumerate(hgrids): 102 | for j, ws in enumerate(wgrids): 103 | for t in range(dur_t // aligned): 104 | t_s, t_e = t * aligned, (t + 1) * aligned 105 | h_s, h_e = i * fsize_h, (i + 1) * fsize_h 106 | w_s, w_e = j * fsize_w, (j + 1) * fsize_w 107 | if random: 108 | h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h 109 | w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w 110 | else: 111 | h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h 112 | w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w 113 | target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ 114 | :, t_s:t_e, h_so:h_eo, w_so:w_eo 115 | ] 116 | # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) 117 | # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) 118 | # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments 119 | return target_video 120 | 121 | 122 | 123 | class 
AVA_dataloader_pair(Dataset): 124 | def __init__(self, data_dir, csv_path, transform, database, seed): 125 | self.database = database 126 | 127 | 128 | tmp_df = pd.read_csv(csv_path) 129 | image_name = tmp_df['image_num'].to_list() 130 | image_score = np.zeros([len(image_name)]) 131 | for i_vote in range(1,11): 132 | image_score += i_vote * tmp_df['vote_'+str(i_vote)].to_numpy() 133 | 134 | n_images = len(image_name) 135 | random.seed(seed) 136 | np.random.seed(seed) 137 | index_rd = np.random.permutation(n_images) 138 | 139 | if 'train' in database: 140 | index_subset = index_rd[ : int(n_images * 0.8)] 141 | self.X_train = [str(image_name[i])+'.jpg' for i in index_subset] 142 | self.Y_train = [image_score[i] for i in index_subset] 143 | elif 'test' in database: 144 | index_subset = index_rd[int(n_images * 0.8) : ] 145 | self.X_train = [str(image_name[i])+'.jpg' for i in index_subset] 146 | self.Y_train = [image_score[i] for i in index_subset] 147 | else: 148 | raise ValueError(f"Unsupported subset database name: {database}") 149 | print(self.X_train) 150 | 151 | 152 | 153 | self.data_dir = data_dir 154 | self.transform = transform 155 | self.length = len(self.X_train) 156 | 157 | def __getitem__(self, index): 158 | 159 | index_second = random.randint(0, self.length - 1) 160 | if index == index_second: 161 | index_second = (index_second + 1) % self.length 162 | while self.Y_train[index] == self.Y_train[index_second]: 163 | index_second = random.randint(0, self.length - 1) 164 | if index == index_second: 165 | index_second = (index_second + 1) % self.length 166 | 167 | path = os.path.join(self.data_dir,self.X_train[index]) 168 | path_second = os.path.join(self.data_dir,self.X_train[index_second]) 169 | 170 | img = Image.open(path) 171 | img = img.convert('RGB') 172 | 173 | 174 | img_second = Image.open(path_second) 175 | img_second = img_second.convert('RGB') 176 | 177 | img_overall = self.transform(img) 178 | img_second_overall = self.transform(img_second) 179 | 180 | y_mos = self.Y_train[index] 181 | y_label = torch.FloatTensor(np.array(float(y_mos))) 182 | 183 | 184 | y_mos_second = self.Y_train[index_second] 185 | y_label_second = torch.FloatTensor(np.array(float(y_mos_second))) 186 | 187 | return img_overall, y_label, img_second_overall, y_label_second 188 | 189 | 190 | 191 | 192 | class AVA_dataloader(Dataset): 193 | def __init__(self, data_dir, csv_path, transform, database, seed): 194 | self.database = database 195 | 196 | 197 | tmp_df = pd.read_csv(csv_path) 198 | image_name = tmp_df['image_num'].to_list() 199 | image_score = np.zeros([len(image_name)]) 200 | for i_vote in range(1,11): 201 | image_score += i_vote * tmp_df['vote_'+str(i_vote)].to_numpy() 202 | 203 | n_images = len(image_name) 204 | random.seed(seed) 205 | np.random.seed(seed) 206 | index_rd = np.random.permutation(n_images) 207 | 208 | if 'train' in database: 209 | index_subset = index_rd[ : int(n_images * 0.8)] 210 | self.X_train = [str(image_name[i])+'.jpg' for i in index_subset] 211 | self.Y_train = [image_score[i] for i in index_subset] 212 | elif 'test' in database: 213 | index_subset = index_rd[int(n_images * 0.8) : ] 214 | self.X_train = [str(image_name[i])+'.jpg' for i in index_subset] 215 | self.Y_train = [image_score[i] for i in index_subset] 216 | else: 217 | raise ValueError(f"Unsupported subset database name: {database}") 218 | print(self.X_train) 219 | 220 | 221 | 222 | self.data_dir = data_dir 223 | self.transform = transform 224 | self.length = len(self.X_train) 225 | 226 | def __getitem__(self, 
index): 227 | 228 | path = os.path.join(self.data_dir,self.X_train[index]) 229 | 230 | img = Image.open(path) 231 | img = img.convert('RGB') 232 | 233 | img_overall = self.transform(img) 234 | 235 | y_mos = self.Y_train[index] 236 | y_label = torch.FloatTensor(np.array(float(y_mos))) 237 | 238 | return img_overall, y_label 239 | 240 | 241 | def __len__(self): 242 | return self.length 243 | 244 | class UIQA_dataloader_pair(Dataset): 245 | def __init__(self, data_dir, csv_path, transform, database, n_fragment=12, salient_patch_dimension=448, seed=0): 246 | self.database = database 247 | self.salient_patch_dimension = salient_patch_dimension 248 | self.n_fragment = n_fragment 249 | 250 | tmp_df = pd.read_csv(csv_path) 251 | image_name = tmp_df['image_name'].to_list() 252 | mos = tmp_df['quality_mos'].to_list() 253 | 254 | n_images = len(image_name) 255 | random.seed(seed) 256 | np.random.seed(seed) 257 | index_rd = np.random.permutation(n_images) 258 | 259 | if 'train' in database: 260 | index_subset = index_rd[ : int(n_images * 0.8)] 261 | self.X_train = [image_name[i] for i in index_subset] 262 | self.Y_train = [mos[i] for i in index_subset] 263 | elif 'test' in database: 264 | index_subset = index_rd[int(n_images * 0.8) : ] 265 | self.X_train = [image_name[i] for i in index_subset] 266 | self.Y_train = [mos[i] for i in index_subset] 267 | elif 'all' in database: 268 | index_subset = index_rd 269 | self.X_train = [image_name[i] for i in index_subset] 270 | self.Y_train = [mos[i] for i in index_subset] 271 | else: 272 | raise ValueError(f"Unsupported subset database name: {database}") 273 | print(self.X_train) 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | self.data_dir = data_dir 284 | self.transform_aesthetics = transform 285 | self.transform_distortion = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 286 | self.transform_distortion_preprocessing = transforms.Compose([transforms.ToTensor()]) 287 | self.transform_saliency = transforms.Compose([ 288 | transforms.CenterCrop(self.salient_patch_dimension), 289 | transforms.ToTensor(), 290 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 291 | ]) 292 | self.length = len(self.X_train) 293 | 294 | def __getitem__(self, index): 295 | 296 | index_second = random.randint(0, self.length - 1) 297 | if index == index_second: 298 | index_second = (index_second + 1) % self.length 299 | while self.Y_train[index] == self.Y_train[index_second]: 300 | index_second = random.randint(0, self.length - 1) 301 | if index == index_second: 302 | index_second = (index_second + 1) % self.length 303 | 304 | path = os.path.join(self.data_dir, self.X_train[index]) 305 | path_second = os.path.join(self.data_dir, self.X_train[index_second]) 306 | 307 | img = Image.open(path) 308 | img = img.convert('RGB') 309 | 310 | 311 | img_second = Image.open(path_second) 312 | img_second = img_second.convert('RGB') 313 | 314 | img_aesthetics = self.transform_aesthetics(img) 315 | img_second_aesthetics = self.transform_aesthetics(img_second) 316 | 317 | img_saliency = self.transform_saliency(img) 318 | img_second_saliency = self.transform_saliency(img_second) 319 | 320 | 321 | img_distortion = self.transform_distortion_preprocessing(img) 322 | img_second_distortion = self.transform_distortion_preprocessing(img_second) 323 | 324 | img_distortion = img_distortion.unsqueeze(1) 325 | img_second_distortion = img_second_distortion.unsqueeze(1) 326 | 327 | img_distortion = get_spatial_fragments( 328 | 
img_distortion, 329 | fragments_h=self.n_fragment, 330 | fragments_w=self.n_fragment, 331 | fsize_h=32, 332 | fsize_w=32, 333 | aligned=32, 334 | nfrags=1, 335 | random=False, 336 | random_upsample=False, 337 | fallback_type="upsample" 338 | ) 339 | img_second_distortion = get_spatial_fragments( 340 | img_second_distortion, 341 | fragments_h=self.n_fragment, 342 | fragments_w=self.n_fragment, 343 | fsize_h=32, 344 | fsize_w=32, 345 | aligned=32, 346 | nfrags=1, 347 | random=False, 348 | random_upsample=False, 349 | fallback_type="upsample" 350 | ) 351 | 352 | img_distortion = img_distortion.squeeze(1) 353 | img_second_distortion = img_second_distortion.squeeze(1) 354 | 355 | img_distortion = self.transform_distortion(img_distortion) 356 | img_second_distortion = self.transform_distortion(img_second_distortion) 357 | 358 | y_mos = self.Y_train[index] 359 | 360 | y_label = torch.FloatTensor(np.array(float(y_mos))) 361 | 362 | 363 | y_mos_second = self.Y_train[index_second] 364 | 365 | y_label_second = torch.FloatTensor(np.array(float(y_mos_second))) 366 | 367 | 368 | data = {'img_aesthetics': img_aesthetics, 369 | 'img_distortion': img_distortion, 370 | 'img_saliency': img_saliency, 371 | 'y_label': y_label, 372 | 'img_second_aesthetics': img_second_aesthetics, 373 | 'img_second_distortion': img_second_distortion, 374 | 'img_second_saliency': img_second_saliency, 375 | 'y_label_second': y_label_second} 376 | 377 | return data 378 | 379 | 380 | def __len__(self): 381 | return self.length 382 | 383 | 384 | class UIQA_dataloader(Dataset): 385 | def __init__(self, data_dir, csv_path, transform, database, n_fragment=12, salient_patch_dimension=448, seed=0): 386 | self.database = database 387 | self.salient_patch_dimension = salient_patch_dimension 388 | self.n_fragment = n_fragment 389 | 390 | 391 | tmp_df = pd.read_csv(csv_path) 392 | image_name = tmp_df['image_name'].to_list() 393 | mos = tmp_df['quality_mos'].to_list() 394 | 395 | n_images = len(image_name) 396 | random.seed(seed) 397 | np.random.seed(seed) 398 | index_rd = np.random.permutation(n_images) 399 | 400 | if 'train' in database: 401 | index_subset = index_rd[ : int(n_images * 0.8)] 402 | self.X_train = [image_name[i] for i in index_subset] 403 | self.Y_train = [mos[i] for i in index_subset] 404 | elif 'test' in database: 405 | index_subset = index_rd[int(n_images * 0.8) : ] 406 | self.X_train = [image_name[i] for i in index_subset] 407 | self.Y_train = [mos[i] for i in index_subset] 408 | elif 'all' in database: 409 | index_subset = index_rd 410 | self.X_train = [image_name[i] for i in index_subset] 411 | self.Y_train = [mos[i] for i in index_subset] 412 | else: 413 | raise ValueError(f"Unsupported subset database name: {database}") 414 | print(self.X_train) 415 | 416 | 417 | 418 | self.data_dir = data_dir 419 | self.transform_aesthetics = transform 420 | self.transform_distortion = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 421 | self.transform_distortion_preprocessing = transforms.Compose([transforms.ToTensor()]) 422 | self.transform_saliency = transforms.Compose([ 423 | transforms.CenterCrop(self.salient_patch_dimension), 424 | transforms.ToTensor(), 425 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 426 | ]) 427 | self.length = len(self.X_train) 428 | 429 | def __getitem__(self, index): 430 | path = os.path.join(self.data_dir,self.X_train[index]) 431 | 432 | img = Image.open(path) 433 | img = img.convert('RGB') 434 | 435 | img_aesthetics = 
self.transform_aesthetics(img) 436 | img_saliency = self.transform_saliency(img) 437 | 438 | 439 | img_distortion = self.transform_distortion_preprocessing(img) 440 | img_distortion = img_distortion.unsqueeze(1) 441 | img_distortion = get_spatial_fragments( 442 | img_distortion, 443 | fragments_h=self.n_fragment, 444 | fragments_w=self.n_fragment, 445 | fsize_h=32, 446 | fsize_w=32, 447 | aligned=32, 448 | nfrags=1, 449 | random=False, 450 | random_upsample=False, 451 | fallback_type="upsample" 452 | ) 453 | img_distortion = img_distortion.squeeze(1) 454 | img_distortion = self.transform_distortion(img_distortion) 455 | 456 | y_mos = self.Y_train[index] 457 | 458 | y_label = torch.FloatTensor(np.array(float(y_mos))) 459 | 460 | data = {'img_aesthetics': img_aesthetics, 461 | 'img_distortion': img_distortion, 462 | 'img_saliency': img_saliency, 463 | 'y_label': y_label} 464 | 465 | return data 466 | 467 | 468 | def __len__(self): 469 | return self.length 470 | --------------------------------------------------------------------------------
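
A closing illustration of the grid mini-patch sampling implemented by `get_spatial_fragments` above: a single image enters as a `[C, 1, H, W]` tensor and comes out as a mosaic of `fragments_h x fragments_w` patches of `fsize_h x fsize_w` pixels each. A small sanity-check sketch (the 4K input size and the fragment count are assumptions matching the `--n_fragment 15` setting in the README):

```python
import torch
from IQADataset import get_spatial_fragments

dummy = torch.rand(3, 1, 2160, 3840)   # [C, T=1, H, W] dummy UHD image
frag = get_spatial_fragments(dummy, fragments_h=15, fragments_w=15,
                             fsize_h=32, fsize_w=32)
print(frag.shape)                      # torch.Size([3, 1, 480, 480])
```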