├── models
│   ├── stairIQA_mobilenet.py
│   ├── stairIQA_resnet.py
│   └── ResNet_staircase.py
├── train.sh
├── config.yaml
├── IQADataset.py
├── README.md
├── utils.py
├── test_staircase.py
├── test_staircase_ensemble.py
├── train_single_database.py
└── train_imdt.py

--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
 1 | CUDA_VISIBLE_DEVICES=0 python -u train_single_database.py \
 2 | --num_epochs 100 \
 3 | --batch_size 30 \
 4 | --resize 384 \
 5 | --crop_size 320 \
 6 | --lr 0.00005 \
 7 | --decay_ratio 0.9 \
 8 | --decay_interval 10 \
 9 | --snapshot /data/sunwei_data/ModelFolder/StairIQA/ \
10 | --database_dir /data/sunwei_data/BID/ImageDatabase/ImageDatabase/ \
11 | --model stairIQA_resnet \
12 | --multi_gpu False \
13 | --print_samples 20 \
14 | --database BID \
15 | --test_method five \
16 | >> logfiles/train_BID_stairIQA_resnet.log
17 | 
18 | 
19 | 
20 | 
21 | CUDA_VISIBLE_DEVICES=0 python -u train_imdt.py \
22 | --num_epochs 3 \
23 | --batch_size 30 \
24 | --lr 0.00001 \
25 | --decay_ratio 0.9 \
26 | --decay_interval 1 \
27 | --snapshot /data/sunwei_data/ModelFolder/StairIQA/ \
28 | --model stairIQA_resnet \
29 | --multi_gpu False \
30 | --print_samples 100 \
31 | --test_method five \
32 | --results_path results \
33 | --exp_id 0 \
34 | >> logfiles/train_stairIQA_resnet_imdt_exp_id_0.log
35 | 
36 | 

--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
 1 | BID:
 2 |   database_name: BID
 3 |   database_dir: /data/sunwei_data/BID/ImageDatabase/ImageDatabase/
 4 |   filename_dir: csvfiles/
 5 |   train_filename: BID_train
 6 |   test_filename: BID_test
 7 |   resize: 384
 8 |   crop_size: 320
 9 |   n_epochs: 8
10 | 
11 | LIVE_challenge:
12 |   database_name: LIVE_challenge
13 |   database_dir: /data/sunwei_data/ChallengeDB_release/Images/
14 |   filename_dir: csvfiles/
15 |   train_filename: LIVE_challenge_train
16 |   test_filename: LIVE_challenge_test
17 |   resize: 384
18 |   crop_size: 320
19 |   n_epochs: 8
20 | 
21 | 
22 | Koniq10k:
23 |   database_name: Koniq10k
24 |   database_dir: /data/sunwei_data/koniq10k_1024x768/1024x768/
25 |   filename_dir: csvfiles/
26 |   train_filename: Koniq10k_train
27 |   test_filename: Koniq10k_test
28 |   resize: 384
29 |   crop_size: 320
30 |   n_epochs: 12
31 | 
32 | SPAQ:
33 |   database_name: SPAQ
34 |   database_dir: /data/sunwei_data/SPAQ_resize/
35 |   filename_dir: csvfiles/
36 |   train_filename: SPQA_train
37 |   test_filename: SPQA_test
38 |   resize: 384
39 |   crop_size: 320
40 |   n_epochs: 12
41 | 
42 | FLIVE:
43 |   database_name: FLIVE
44 |   database_dir: /data/sunwei_data/FLIVEDatabase_WP/
45 |   filename_dir: csvfiles/
46 |   train_filename: FLIVE_train
47 |   test_filename: FLIVE_test
48 |   resize: 340
49 |   crop_size: 320
50 |   n_epochs: 3
51 | 
52 | FLIVE_patch:
53 |   database_name: FLIVE_patch
54 |   database_dir: /data/sunwei_data/FLIVEDatabase_WPPatch/
55 |   filename_dir: csvfiles/
56 |   train_filename: FLIVE_patch_train
57 |   test_filename: FLIVE_patch_test
58 |   resize: 256
59 |   crop_size: 224
60 |   n_epochs: 1

--------------------------------------------------------------------------------
/IQADataset.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import pandas as pd
 4 | 
 5 | import torch
 6 | from torch.utils.data.dataset import Dataset
 7 | 
 8 | from PIL import Image
 9 | 
10 | 
11 | class IQA_dataloader(Dataset):
12 |     def __init__(self, data_dir, csv_path, transform, database):
13 |         self.database = database
14 |         if self.database == 'Koniq10k':
15 |             column_names = ['image_name','c1','c2','c3','c4','c5','c_total','MOS','SD','MOS_zscore']
16 |             tmp_df = pd.read_csv(csv_path, header=0, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
17 |             self.X_train = tmp_df[['image_name']]
18 |             self.Y_train = tmp_df['MOS_zscore']
19 | 
20 |         elif self.database == 'FLIVE' or self.database == 'FLIVE_patch':
21 |             column_names = ['name','mos']
22 |             tmp_df = pd.read_csv(csv_path, header=0, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
23 |             self.X_train = tmp_df[['name']]
24 |             self.Y_train = tmp_df['mos']
25 | 
26 |         elif self.database == 'LIVE_challenge':
27 |             column_names = ['image','mos','std']
28 |             tmp_df = pd.read_csv(csv_path, header=0, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
29 |             self.X_train = tmp_df[['image']]
30 |             self.Y_train = tmp_df['mos']
31 | 
32 |         elif self.database == 'SPAQ':
33 |             column_names = ['name','mos','brightness','colorfulness','contrast','noisiness','sharpness']
34 |             tmp_df = pd.read_csv(csv_path, header=0, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
35 |             self.X_train = tmp_df[['name']]
36 |             self.Y_train = tmp_df['mos']
37 | 
38 |         elif self.database == 'BID':
39 |             column_names = ['name','mos']
40 |             tmp_df = pd.read_csv(csv_path, header=0, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
41 |             self.X_train = tmp_df[['name']]
42 |             self.Y_train = tmp_df['mos']
43 | 
44 |         self.data_dir = data_dir
45 |         self.transform = transform
46 |         self.length = len(self.X_train)
47 | 
48 |     def __getitem__(self, index):
49 |         path = os.path.join(self.data_dir, self.X_train.iloc[index,0])
50 | 
51 |         img = Image.open(path)
52 |         img = img.convert('RGB')
53 | 
54 |         if self.transform is not None:
55 |             img = self.transform(img)
56 | 
57 |         y_mos = self.Y_train.iloc[index]
58 |         if self.database == 'BID':
59 |             y_label = torch.FloatTensor(np.array(float(y_mos*20)))  # rescale BID MOS (0-5) to a 0-100 range
60 |         elif self.database == 'FLIVE' or self.database == 'FLIVE_patch':
61 |             y_label = torch.FloatTensor(np.array(float(y_mos-50)*2))  # linearly rescale FLIVE scores to a common range
62 |         else:
63 |             y_label = torch.FloatTensor(np.array(float(y_mos)))
64 | 
65 |         return img, y_label
66 | 
67 | 
68 |     def __len__(self):
69 |         return self.length

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # StairIQA
 2 | This is a repository for the models proposed in the paper "Blind Quality Assessment for in-the-Wild Images via Hierarchical Feature Fusion and Iterative Mixed Database Training" [JSTSP version](https://ieeexplore.ieee.org/abstract/document/10109108) [Arxiv version](https://arxiv.org/abs/2105.14550).
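 3 | 
 4 | As a minimal quick-start sketch (assuming a trained staircase checkpoint; `my_model.pkl` and `my_image.jpg` below are placeholder names), the network in models/ResNet_staircase.py returns one quality score per training database, in the order FLIVE, FLIVE_patch, LIVE_challenge, Koniq10k, SPAQ, BID:
 5 | ```
 6 | import torch
 7 | from torchvision import transforms
 8 | from PIL import Image
 9 | import models.ResNet_staircase as ResNet_staircase
10 | 
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 | model = ResNet_staircase.resnet50(pretrained=False)
13 | model = torch.nn.DataParallel(model).to(device)  # test_staircase.py wraps the model before loading the checkpoints
14 | model.load_state_dict(torch.load('my_model.pkl', map_location=device))  # placeholder checkpoint path
15 | model.eval()
16 | 
17 | # center-crop preprocessing used by test_staircase.py for Koniq10k/SPAQ/LIVE_challenge/BID models
18 | preprocess = transforms.Compose([transforms.Resize(384), transforms.CenterCrop(320), transforms.ToTensor(),
19 |     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
20 | 
21 | image = preprocess(Image.open('my_image.jpg').convert('RGB')).unsqueeze(0).to(device)
22 | with torch.no_grad():
23 |     scores = model(image)  # tuple of six scores, one per database-specific regression head
24 | print([s.item() for s in scores])
25 | ```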
26 | 
27 | ## Usage
28 | ### Download csv files
29 | The train and test split files can be downloaded from [Google drive](https://drive.google.com/file/d/121evqfjcsUwb014sOhl0mq7gMmaPuzpu/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/17zOm49cxZzhSqQCcsDv4jQ) (extraction code: y4be)
30 | 
31 | ### Train
32 | 
33 | Train on a single database (e.g. BID)
34 | ```
35 | CUDA_VISIBLE_DEVICES=0 python -u train_single_database.py \
36 | --num_epochs 100 \
37 | --batch_size 30 \
38 | --resize 384 \
39 | --crop_size 320 \
40 | --lr 0.00005 \
41 | --decay_ratio 0.9 \
42 | --decay_interval 10 \
43 | --snapshot /data/sunwei_data/ModelFolder/StairIQA/ \
44 | --database_dir /data/sunwei_data/BID/ImageDatabase/ImageDatabase/ \
45 | --model stairIQA_resnet \
46 | --multi_gpu False \
47 | --print_samples 20 \
48 | --database BID \
49 | --test_method five \
50 | >> logfiles/train_BID_stairIQA_resnet.log
51 | ```
52 | 
53 | Train on multiple databases
54 | ```
55 | CUDA_VISIBLE_DEVICES=0 python -u train_imdt.py \
56 | --num_epochs 3 \
57 | --batch_size 30 \
58 | --lr 0.00001 \
59 | --decay_ratio 0.9 \
60 | --decay_interval 1 \
61 | --snapshot /data/sunwei_data/ModelFolder/StairIQA/ \
62 | --model stairIQA_resnet \
63 | --multi_gpu False \
64 | --print_samples 100 \
65 | --test_method five \
66 | --results_path results \
67 | --exp_id 0 \
68 | >> logfiles/train_stairIQA_resnet_imdt_exp_id_0.log
69 | ```
70 | 
71 | The database information used by train_imdt.py can be edited in the config.yaml file; each entry follows the schema sketched below.
72 | 
73 | 
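74 | Each config.yaml entry uses the fields below (the values here are illustrative placeholders). Note that train_imdt.py keys its per-database regression heads to the six database names above, so adding a brand-new database also requires changes to the training code.
75 | ```
76 | MyDatabase:
77 |   database_name: MyDatabase
78 |   database_dir: /path/to/image/folder/
79 |   filename_dir: csvfiles/
80 |   train_filename: MyDatabase_train  # resolved as csvfiles/MyDatabase_train_<exp_id>.csv
81 |   test_filename: MyDatabase_test
82 |   resize: 384
83 |   crop_size: 320
84 |   n_epochs: 8  # inner training epochs each time this database is visited
85 | ```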
 86 | 
 87 | 
 88 | 
 89 | ### Test
 90 | Download the trained models:
 91 | 
 92 | Koniq10k: [Google drive](https://drive.google.com/file/d/1fKENSrGXao8po7R4yzK_YvxDxvzV4wTs/view?usp=sharing)
 93 | 
 94 | SPAQ: [Google drive](https://drive.google.com/file/d/1Px-PJE-08BPCfhP_gB7B78Z4Sp7dAi-4/view?usp=sharing)
 95 | 
 96 | BID: [Google drive](https://drive.google.com/file/d/1u6SfaXg1TaMDx7TmkKGfjOC163jmNg3J/view?usp=sharing)
 97 | 
 98 | LIVE_challenge: [Google drive](https://drive.google.com/file/d/1da4Aoe-zGvVljuB1PKOi4bnW2kBsnrP3/view?usp=sharing)
 99 | 
100 | FLIVE: [Google drive](https://drive.google.com/file/d/14nV0R4AONnRD9EfgTD3uPVt_b1RgNZoN/view?usp=sharing)
101 | 
102 | FLIVE_patch: [Google drive](https://drive.google.com/file/d/1s5UyerDfjvGxE34OvMPOO7zBm_UO3BpZ/view?usp=sharing)
103 | 
104 | Test an image with the model whose regressor is trained on a single database (e.g. Koniq10k):
105 | ```
106 | CUDA_VISIBLE_DEVICES=0 python -u test_staircase.py \
107 | --test_image_name image_name \
108 | --model_path model_file \
109 | --trained_database Koniq10k \
110 | --test_method five \
111 | --output_name output.txt
112 | ```
113 | 
114 | 
115 | Test an image with the ensemble model:
116 | ```
117 | CUDA_VISIBLE_DEVICES=1 python -u test_staircase_ensemble.py \
118 | --test_image_name image_name \
119 | --test_method five \
120 | --output_name output.txt
121 | ```
122 | 
123 | ## Citation
124 | If you find this code useful for your research, please cite:
125 | ```
126 | @article{sun2023blind,
127 |   title={Blind quality assessment for in-the-wild images via hierarchical feature fusion and iterative mixed database training},
128 |   author={Sun, Wei and Min, Xiongkuo and Tu, Danyang and Ma, Siwei and Zhai, Guangtao},
129 |   journal={IEEE Journal of Selected Topics in Signal Processing},
130 |   year={2023},
131 |   publisher={IEEE}
132 | }
133 | ```
134 | 

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from scipy.optimize import curve_fit
 3 | from scipy import stats
 4 | import torch
 5 | import torch.nn.functional as F
 6 | 
 7 | def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
 8 |     logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
 9 |     yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
10 |     return yhat
11 | 
12 | def fit_function(y_label, y_output):
13 |     beta = [np.max(y_label), np.min(y_label), np.mean(y_output), 0.5]
14 |     popt, _ = curve_fit(logistic_func, y_output, \
15 |         y_label, p0=beta, maxfev=100000000)
16 |     y_output_logistic = logistic_func(y_output, *popt)
17 | 
18 |     return y_output_logistic
19 | 
20 | 
21 | def performance_fit(y_label, y_output):
22 |     y_output_logistic = fit_function(y_label, y_output)
23 |     PLCC = stats.pearsonr(y_output_logistic, y_label)[0]
24 |     SRCC = stats.spearmanr(y_output, y_label)[0]
25 |     KRCC = stats.kendalltau(y_output, y_label)[0]
26 |     RMSE = np.sqrt(((y_output_logistic-y_label) ** 2).mean())
27 | 
28 |     return PLCC, SRCC, KRCC, RMSE
29 | 
30 | 
31 | def performance_no_fit(y_label, y_output):
32 |     PLCC = stats.pearsonr(y_output, y_label)[0]
33 |     SRCC = stats.spearmanr(y_output, y_label)[0]
34 |     KRCC = stats.kendalltau(y_output, y_label)[0]
35 |     RMSE = np.sqrt(((y_output-y_label) ** 2).mean())
36 | 
37 |     return PLCC, SRCC, KRCC, RMSE
38 | 
39 | 
40 | 
41 | class L1RankLoss(torch.nn.Module):
42 |     """
43 |     L1 loss + Rank loss
44 |     """
45 | 
46 |     def __init__(self, **kwargs):
47 |         super(L1RankLoss, self).__init__()
48 |         self.l1_w = kwargs.get("l1_w", 1)
49 |         self.rank_w = kwargs.get("rank_w", 1)
50 |         self.hard_thred = kwargs.get("hard_thred", 1)
51 |         self.use_margin = kwargs.get("use_margin", False)
52 | 
53 |     def forward(self, preds, gts):
54 |         preds = preds.view(-1)
55 |         gts = gts.view(-1)
56 |         # l1 loss
57 |         l1_loss = F.l1_loss(preds, gts) * self.l1_w
58 | 
59 |         # simple rank
60 |         n = len(preds)
61 |         preds = preds.unsqueeze(0).repeat(n, 1)
62 |         preds_t = preds.t()
63 |         img_label = gts.unsqueeze(0).repeat(n, 1)
64 |         img_label_t = img_label.t()
65 |         masks = torch.sign(img_label - img_label_t)
66 |         masks_hard = (torch.abs(img_label - img_label_t) < self.hard_thred) & (torch.abs(img_label - img_label_t) > 0)
67 |         if self.use_margin:
68 |             rank_loss = masks_hard * torch.relu(torch.abs(img_label - img_label_t) - masks * (preds - preds_t))
69 |         else:
70 |             rank_loss = masks_hard * torch.relu(- masks 
* (preds - preds_t)) 71 | rank_loss = rank_loss.sum() / (masks_hard.sum() + 1e-08) 72 | loss_total = l1_loss + rank_loss * self.rank_w 73 | return loss_total 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | def plcc_loss(y_pred, y): 82 | sigma_hat, m_hat = torch.std_mean(y_pred, unbiased=False) 83 | y_pred = (y_pred - m_hat) / (sigma_hat + 1e-8) 84 | sigma, m = torch.std_mean(y, unbiased=False) 85 | y = (y - m) / (sigma + 1e-8) 86 | loss0 = torch.nn.functional.mse_loss(y_pred, y) / 4 87 | rho = torch.mean(y_pred * y) 88 | loss1 = torch.nn.functional.mse_loss(rho * y_pred, y) / 4 89 | return ((loss0 + loss1) / 2).float() 90 | 91 | def rank_loss(y_pred, y): 92 | ranking_loss = torch.nn.functional.relu( 93 | (y_pred - y_pred.t()) * torch.sign((y.t() - y)) 94 | ) 95 | scale = 1 + torch.max(ranking_loss) 96 | return ( 97 | torch.sum(ranking_loss) / y_pred.shape[0] / (y_pred.shape[0] - 1) / scale 98 | ).float() 99 | 100 | 101 | 102 | def plcc_rank_loss(y_output, y_label): 103 | plcc = plcc_loss(y_output, y_label) 104 | rank = rank_loss(y_output, y_label) 105 | return plcc + rank*0.3 -------------------------------------------------------------------------------- /models/stairIQA_mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision.models as models 4 | 5 | 6 | class mobilenet_v2(torch.nn.Module): 7 | def __init__(self): 8 | super(mobilenet_v2, self).__init__() 9 | mobilenet_features = nn.Sequential(*list(models.mobilenet_v2(weights='DEFAULT').children())[0]) 10 | 11 | self.feature_extraction_stem = torch.nn.Sequential() 12 | self.feature_extraction1 = torch.nn.Sequential() 13 | self.feature_extraction2 = torch.nn.Sequential() 14 | self.feature_extraction3 = torch.nn.Sequential() 15 | self.feature_extraction4 = torch.nn.Sequential() 16 | 17 | self.avg_pool = torch.nn.Sequential() 18 | 19 | for x in range(0,4): 20 | self.feature_extraction_stem.add_module(str(x), mobilenet_features[x]) 21 | 22 | for x in range(4,7): 23 | self.feature_extraction1.add_module(str(x), mobilenet_features[x]) 24 | 25 | for x in range(7,11): 26 | self.feature_extraction2.add_module(str(x), mobilenet_features[x]) 27 | 28 | for x in range(11,17): 29 | self.feature_extraction3.add_module(str(x), mobilenet_features[x]) 30 | 31 | 32 | for x in range(17,19): 33 | self.feature_extraction4.add_module(str(x), mobilenet_features[x]) 34 | 35 | 36 | self.hyper1_1 = self.hyper_structure1(24,32) 37 | self.hyper2_1 = self.hyper_structure1(32,64) 38 | self.hyper3_1 = self.hyper_structure1(64,160) 39 | self.hyper4_1 = self.hyper_structure2(160,1280) 40 | 41 | self.hyper2_2 = self.hyper_structure1(32,64) 42 | self.hyper3_2 = self.hyper_structure1(64,160) 43 | self.hyper4_2 = self.hyper_structure2(160,1280) 44 | 45 | self.hyper3_3 = self.hyper_structure1(64,160) 46 | self.hyper4_3 = self.hyper_structure2(160,1280) 47 | 48 | self.hyper4_4 = self.hyper_structure2(160,1280) 49 | 50 | self.quality = self.quality_regression(1280, 128, 1) 51 | 52 | def hyper_structure1(self,in_channels,out_channels): 53 | hyper_block = nn.Sequential( 54 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 55 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=2, padding=1,bias=False), 56 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 57 | ) 58 | 59 | return hyper_block 60 | 61 | def hyper_structure2(self,in_channels,out_channels): 62 | hyper_block = nn.Sequential( 63 | 
nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 64 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=1, padding=1,bias=False), 65 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 66 | ) 67 | 68 | return hyper_block 69 | 70 | def quality_regression(self,in_channels, middle_channels, out_channels): 71 | regression_block = nn.Sequential( 72 | nn.Linear(in_channels, middle_channels), 73 | nn.Linear(middle_channels, out_channels), 74 | ) 75 | 76 | return regression_block 77 | 78 | 79 | def forward(self, x): 80 | 81 | x = self.feature_extraction_stem(x) 82 | 83 | x_hyper1 = self.hyper1_1(x) 84 | x = self.feature_extraction1(x) 85 | 86 | 87 | x_hyper1 = self.hyper2_1(x_hyper1+x) 88 | x_hyper2 = self.hyper2_2(x) 89 | x = self.feature_extraction2(x) 90 | 91 | 92 | x_hyper1 = self.hyper3_1(x_hyper1+x) 93 | x_hyper2 = self.hyper3_2(x_hyper2+x) 94 | x_hyper3 = self.hyper3_3(x) 95 | x = self.feature_extraction3(x) 96 | 97 | 98 | x_hyper1 = self.hyper4_1(x_hyper1+x) 99 | x_hyper2 = self.hyper4_2(x_hyper2+x) 100 | x_hyper3 = self.hyper4_3(x_hyper3+x) 101 | x_hyper4 = self.hyper4_4(x) 102 | x = self.feature_extraction4(x) 103 | 104 | x = x+x_hyper1+x_hyper2+x_hyper3+x_hyper4 105 | 106 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) 107 | 108 | x = torch.flatten(x, 1) 109 | 110 | x = self.quality(x) 111 | 112 | 113 | return x 114 | 115 | if __name__ == '__main__': 116 | 117 | model = mobilenet_v2() 118 | 119 | print(model) -------------------------------------------------------------------------------- /test_staircase.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | 3 | import numpy as np 4 | 5 | import torch 6 | from torchvision import transforms 7 | 8 | import models.ResNet_staircase as ResNet_staircase 9 | from PIL import Image 10 | 11 | 12 | 13 | def parse_args(): 14 | """Parse input arguments. 
""" 15 | parser = argparse.ArgumentParser(description="Authentic Image Quality Assessment") 16 | parser.add_argument('--model_path', help='Path of model snapshot.', default='', type=str) 17 | parser.add_argument('--test_image_name', type=str) 18 | parser.add_argument('--trained_database', default='Koniq10k', type=str) 19 | parser.add_argument('--test_method', default='five', type=str, 20 | help='use the center crop or five crop to test the image (default: one)') 21 | parser.add_argument('--output_name', type=str) 22 | 23 | 24 | args = parser.parse_args() 25 | 26 | return args 27 | 28 | 29 | if __name__ == '__main__': 30 | args = parse_args() 31 | 32 | test_image_name = args.test_image_name 33 | model_path = args.model_path 34 | trained_database = args.trained_database 35 | 36 | output_name = args.output_name 37 | 38 | 39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | # load the network 42 | model = ResNet_staircase.resnet50(pretrained = False) 43 | model = torch.nn.DataParallel(model) 44 | model = model.to(device) 45 | model.load_state_dict(torch.load(model_path)) 46 | model.eval() 47 | 48 | 49 | if trained_database == 'FLIVE': 50 | if args.test_method == 'one': 51 | transformations_test = transforms.Compose([transforms.Resize(340),transforms.CenterCrop(320), \ 52 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 53 | elif args.test_method == 'five': 54 | transformations_test = transforms.Compose([transforms.Resize(340),transforms.FiveCrop(320), \ 55 | (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \ 56 | (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 57 | std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 58 | elif trained_database == 'FLIVE_patch': 59 | if args.test_method == 'one': 60 | transformations_test = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224), \ 61 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 62 | elif args.test_method == 'five': 63 | transformations_test = transforms.Compose([transforms.Resize(256),transforms.FiveCrop(224), \ 64 | (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \ 65 | (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 66 | std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 67 | else: 68 | if args.test_method == 'one': 69 | transformations_test = transforms.Compose([transforms.Resize(384),transforms.CenterCrop(320), \ 70 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 71 | elif args.test_method == 'five': 72 | transformations_test = transforms.Compose([transforms.Resize(384),transforms.FiveCrop(320), \ 73 | (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \ 74 | (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 75 | 76 | test_image = Image.open(test_image_name) 77 | test_image = test_image.convert('RGB') 78 | test_image = transformations_test(test_image) 79 | test_image = test_image.unsqueeze(0) 80 | 81 | 82 | 83 | with torch.no_grad(): 84 | if args.test_method == 'one': 85 | test_image = test_image.to(device) 86 | if trained_database == 'FLIVE': 87 | outputs,_,_,_,_,_ = model(test_image) 88 | elif trained_database == 'FLIVE_patch': 89 | _,outputs,_,_,_,_ = model(test_image) 90 | elif trained_database 
== 'LIVE_challenge':
91 |                 _,_,outputs,_,_,_ = model(test_image)
92 |             elif trained_database == 'Koniq10k':
93 |                 _,_,_,outputs,_,_ = model(test_image)
94 |             elif trained_database == 'SPAQ':
95 |                 _,_,_,_,outputs,_ = model(test_image)
96 |             elif trained_database == 'BID':
97 |                 _,_,_,_,_,outputs = model(test_image)
98 |             test_scores = outputs.item()
99 |             print(test_image_name)
100 |             print(test_scores)
101 | 
102 | 
103 |         elif args.test_method == 'five':
104 |             bs, ncrops, c, h, w = test_image.size()
105 |             test_image = test_image.to(device)
106 |             if trained_database == 'FLIVE':
107 |                 outputs,_,_,_,_,_ = model(test_image.view(-1, c, h, w))
108 |             elif trained_database == 'FLIVE_patch':
109 |                 _,outputs,_,_,_,_ = model(test_image.view(-1, c, h, w))
110 |             elif trained_database == 'LIVE_challenge':
111 |                 _,_,outputs,_,_,_ = model(test_image.view(-1, c, h, w))
112 |             elif trained_database == 'Koniq10k':
113 |                 _,_,_,outputs,_,_ = model(test_image.view(-1, c, h, w))
114 |             elif trained_database == 'SPAQ':
115 |                 _,_,_,_,outputs,_ = model(test_image.view(-1, c, h, w))
116 |             elif trained_database == 'BID':
117 |                 _,_,_,_,_,outputs = model(test_image.view(-1, c, h, w))
118 |             test_scores = outputs.view(bs, ncrops, -1).mean(1).item()
119 |             print(test_image_name)
120 |             print(test_scores)
121 | 
122 | 
123 | 
124 | 
125 | 
126 |     if not os.path.exists(output_name):
127 |         os.system(r"touch {}".format(output_name))
128 | 
129 |     f = open(output_name,'w')
130 |     f.write(test_image_name)
131 |     f.write(',')
132 |     f.write(str(test_scores))
133 |     f.write('\n')
134 | 
135 |     f.close()

--------------------------------------------------------------------------------
/test_staircase_ensemble.py:
--------------------------------------------------------------------------------
 1 | import os, argparse
 2 | 
 3 | import numpy as np
 4 | 
 5 | import torch
 6 | from torchvision import transforms
 7 | 
 8 | import models.ResNet_staircase as ResNet_staircase
 9 | from PIL import Image
10 | 
11 | def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
12 |     logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
13 |     yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
14 |     return yhat
15 | 
16 | 
17 | 
18 | 
19 | def fit_function(y_output, popt):
20 |     y_output_logistic = logistic_func(y_output, *popt)
21 | 
22 |     return y_output_logistic
23 | 
24 | 
25 | 
26 | def parse_args():
27 |     """Parse input arguments. 
""" 28 | parser = argparse.ArgumentParser(description="Authentic Image Quality Assessment") 29 | parser.add_argument('--model_path', help='Path of model snapshot.', default='', type=str) 30 | parser.add_argument('--test_image_name', type=str) 31 | parser.add_argument('--test_method', default='five', type=str, 32 | help='use the center crop or five crop to test the image (default: one)') 33 | parser.add_argument('--output_name', type=str) 34 | 35 | 36 | args = parser.parse_args() 37 | 38 | return args 39 | 40 | 41 | if __name__ == '__main__': 42 | args = parse_args() 43 | 44 | test_image_name = args.test_image_name 45 | database = ['Koniq10k', 'SPAQ', 'LIVE_challenge', 'BID', 'FLIVE', 'FLIVE_patch'] 46 | 47 | popt_all = [[120.41629963, -28.56005564, 46.48938183, 34.6190837],\ 48 | [85.18902596, 11.6431685, 53.20173936, 17.09686183],\ 49 | [87.71193443, 13.06699313, 52.12460518, 20.14566219],\ 50 | [93.65937742, 0.51803345, 49.84010415, 28.07474279],\ 51 | [94.71514948, 21.6468321, 39.27372875, 15.37350998],\ 52 | [81.43059766, 23.51273452, 47.11992594, 9.2093784 ]] 53 | # model file 54 | model_path_all = ['ResNet_staircase_50-EXP1-Koniq10k.pkl',\ 55 | 'ResNet_staircase_50-EXP1-SPAQ.pkl',\ 56 | 'ResNet_staircase_50-EXP1-LIVE_challenge.pkl',\ 57 | 'ResNet_staircase_50-EXP1-BID.pkl',\ 58 | 'ResNet_staircase_50-EXP1-FLIVE.pkl',\ 59 | 'ResNet_staircase_50-EXP1-FLIVE_patch.pkl'] 60 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 61 | 62 | test_scores_all = np.zeros([6]) 63 | 64 | for i in range(6): 65 | model_path = model_path_all[i] 66 | popt = popt_all[i] 67 | 68 | output_name = args.output_name 69 | 70 | trained_database = database[i] 71 | 72 | # load the network 73 | model = ResNet_staircase.resnet50(pretrained = False) 74 | model = torch.nn.DataParallel(model) 75 | model = model.to(device) 76 | model.load_state_dict(torch.load(model_path)) 77 | model.eval() 78 | 79 | 80 | if trained_database == 'FLIVE': 81 | if args.test_method == 'one': 82 | transformations_test = transforms.Compose([transforms.Resize(340),transforms.CenterCrop(320), \ 83 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 84 | elif args.test_method == 'five': 85 | transformations_test = transforms.Compose([transforms.Resize(340),transforms.FiveCrop(320), \ 86 | (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \ 87 | (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 88 | std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 89 | elif trained_database == 'FLIVE_patch': 90 | if args.test_method == 'one': 91 | transformations_test = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224), \ 92 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 93 | elif args.test_method == 'five': 94 | transformations_test = transforms.Compose([transforms.Resize(256),transforms.FiveCrop(224), \ 95 | (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \ 96 | (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 97 | std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 98 | else: 99 | if args.test_method == 'one': 100 | transformations_test = transforms.Compose([transforms.Resize(384),transforms.CenterCrop(320), \ 101 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 102 | elif args.test_method == 'five': 103 | transformations_test = 
transforms.Compose([transforms.Resize(384),transforms.FiveCrop(320), \
104 |                     (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \
105 |                     (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(crop) for crop in crops]))])
106 | 
107 |         test_image = Image.open(test_image_name)
108 |         test_image = test_image.convert('RGB')
109 |         test_image = transformations_test(test_image)
110 |         test_image = test_image.unsqueeze(0)
111 | 
112 | 
113 | 
114 |         with torch.no_grad():
115 |             if args.test_method == 'one':
116 |                 test_image = test_image.to(device)
117 |                 if trained_database == 'FLIVE':
118 |                     outputs,_,_,_,_,_ = model(test_image)
119 |                 elif trained_database == 'FLIVE_patch':
120 |                     _,outputs,_,_,_,_ = model(test_image)
121 |                 elif trained_database == 'LIVE_challenge':
122 |                     _,_,outputs,_,_,_ = model(test_image)
123 |                 elif trained_database == 'Koniq10k':
124 |                     _,_,_,outputs,_,_ = model(test_image)
125 |                 elif trained_database == 'SPAQ':
126 |                     _,_,_,_,outputs,_ = model(test_image)
127 |                 elif trained_database == 'BID':
128 |                     _,_,_,_,_,outputs = model(test_image)
129 |                 test_scores = outputs.item()
130 |                 test_scores = fit_function(test_scores, popt)
131 |                 test_scores_all[i] = test_scores
132 | 
133 | 
134 |             elif args.test_method == 'five':
135 |                 bs, ncrops, c, h, w = test_image.size()
136 |                 test_image = test_image.to(device)
137 |                 if trained_database == 'FLIVE':
138 |                     outputs,_,_,_,_,_ = model(test_image.view(-1, c, h, w))
139 |                 elif trained_database == 'FLIVE_patch':
140 |                     _,outputs,_,_,_,_ = model(test_image.view(-1, c, h, w))
141 |                 elif trained_database == 'LIVE_challenge':
142 |                     _,_,outputs,_,_,_ = model(test_image.view(-1, c, h, w))
143 |                 elif trained_database == 'Koniq10k':
144 |                     _,_,_,outputs,_,_ = model(test_image.view(-1, c, h, w))
145 |                 elif trained_database == 'SPAQ':
146 |                     _,_,_,_,outputs,_ = model(test_image.view(-1, c, h, w))
147 |                 elif trained_database == 'BID':
148 |                     _,_,_,_,_,outputs = model(test_image.view(-1, c, h, w))
149 |                 test_scores = outputs.view(bs, ncrops, -1).mean(1).item()
150 |                 test_scores = fit_function(test_scores, popt)
151 |                 test_scores_all[i] = test_scores
152 | 
153 | 
154 | 
155 |     test_scores = np.mean(test_scores_all)
156 |     print(test_image_name)
157 |     print(test_scores)
158 | 
159 |     if not os.path.exists(output_name):
160 |         os.system(r"touch {}".format(output_name))
161 | 
162 |     f = open(output_name,'w')
163 |     f.write(test_image_name)
164 |     f.write(',')
165 |     f.write(str(test_scores))
166 |     f.write('\n')
167 | 
168 |     f.close()
169 | 

--------------------------------------------------------------------------------
/models/stairIQA_resnet.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch import nn
 3 | import torchvision.models as models
 4 | 
 5 | 
 6 | class resnet50(torch.nn.Module):
 7 |     def __init__(self, pretrained = True):
 8 |         super(resnet50, self).__init__()
 9 |         if pretrained == True:
10 |             resnet50_features = nn.Sequential(*list(models.resnet50(weights='DEFAULT').children()))
11 |         else:
12 |             resnet50_features = nn.Sequential(*list(models.resnet50().children()))
13 | 
14 |         self.feature_extraction_stem = torch.nn.Sequential()
15 |         self.feature_extraction1 = torch.nn.Sequential()
16 |         self.feature_extraction2 = torch.nn.Sequential()
17 |         self.feature_extraction3 = torch.nn.Sequential()
18 |         self.feature_extraction4 = torch.nn.Sequential()
19 | 
20 |         self.avg_pool = torch.nn.Sequential()
21 | 
22 |         for x in range(0,4):
23 |             
self.feature_extraction_stem.add_module(str(x), resnet50_features[x]) 24 | 25 | for x in range(4,5): 26 | self.feature_extraction1.add_module(str(x), resnet50_features[x]) 27 | 28 | for x in range(5,6): 29 | self.feature_extraction2.add_module(str(x), resnet50_features[x]) 30 | 31 | for x in range(6,7): 32 | self.feature_extraction3.add_module(str(x), resnet50_features[x]) 33 | 34 | for x in range(7,8): 35 | self.feature_extraction4.add_module(str(x), resnet50_features[x]) 36 | 37 | 38 | self.hyper1_1 = self.hyper_structure1(64,256) 39 | self.hyper2_1 = self.hyper_structure2(256,512) 40 | self.hyper3_1 = self.hyper_structure2(512,1024) 41 | self.hyper4_1 = self.hyper_structure2(1024,2048) 42 | 43 | self.hyper2_2 = self.hyper_structure2(256,512) 44 | self.hyper3_2 = self.hyper_structure2(512,1024) 45 | self.hyper4_2 = self.hyper_structure2(1024,2048) 46 | 47 | self.hyper3_3 = self.hyper_structure2(512,1024) 48 | self.hyper4_3 = self.hyper_structure2(1024,2048) 49 | 50 | self.hyper4_4 = self.hyper_structure2(1024,2048) 51 | 52 | 53 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 54 | 55 | self.quality = self.quality_regression(2048, 128, 1) 56 | 57 | def hyper_structure1(self,in_channels,out_channels): 58 | 59 | hyper_block = nn.Sequential( 60 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 61 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=1, padding=1,bias=False), 62 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 63 | ) 64 | 65 | return hyper_block 66 | 67 | def hyper_structure2(self,in_channels,out_channels): 68 | hyper_block = nn.Sequential( 69 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 70 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=2, padding=1,bias=False), 71 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 72 | ) 73 | 74 | return hyper_block 75 | 76 | def quality_regression(self,in_channels, middle_channels, out_channels): 77 | regression_block = nn.Sequential( 78 | nn.Linear(in_channels, middle_channels), 79 | nn.Linear(middle_channels, out_channels), 80 | ) 81 | 82 | return regression_block 83 | 84 | 85 | def forward(self, x): 86 | 87 | x = self.feature_extraction_stem(x) 88 | 89 | x_hyper1 = self.hyper1_1(x) 90 | x = self.feature_extraction1(x) 91 | 92 | 93 | x_hyper1 = self.hyper2_1(x_hyper1+x) 94 | x_hyper2 = self.hyper2_2(x) 95 | x = self.feature_extraction2(x) 96 | 97 | 98 | x_hyper1 = self.hyper3_1(x_hyper1+x) 99 | x_hyper2 = self.hyper3_2(x_hyper2+x) 100 | x_hyper3 = self.hyper3_3(x) 101 | x = self.feature_extraction3(x) 102 | 103 | 104 | x_hyper1 = self.hyper4_1(x_hyper1+x) 105 | x_hyper2 = self.hyper4_2(x_hyper2+x) 106 | x_hyper3 = self.hyper4_3(x_hyper3+x) 107 | x_hyper4 = self.hyper4_4(x) 108 | x = self.feature_extraction4(x) 109 | 110 | x = x+x_hyper1+x_hyper2+x_hyper3+x_hyper4 111 | 112 | x = self.avgpool(x) 113 | 114 | x = torch.flatten(x, 1) 115 | 116 | x = self.quality(x) 117 | 118 | 119 | return x 120 | 121 | 122 | 123 | class resnet50_imdt(torch.nn.Module): 124 | def __init__(self, pretrained = True): 125 | super(resnet50_imdt, self).__init__() 126 | if pretrained == True: 127 | resnet50_features = nn.Sequential(*list(models.resnet50(weights='DEFAULT').children())) 128 | else: 129 | resnet50_features = nn.Sequential(*list(models.resnet50().children())) 130 | 131 | self.feature_extraction_stem = torch.nn.Sequential() 132 | self.feature_extraction1 = torch.nn.Sequential() 133 | 
self.feature_extraction2 = torch.nn.Sequential() 134 | self.feature_extraction3 = torch.nn.Sequential() 135 | self.feature_extraction4 = torch.nn.Sequential() 136 | 137 | self.avg_pool = torch.nn.Sequential() 138 | 139 | for x in range(0,4): 140 | self.feature_extraction_stem.add_module(str(x), resnet50_features[x]) 141 | 142 | for x in range(4,5): 143 | self.feature_extraction1.add_module(str(x), resnet50_features[x]) 144 | 145 | for x in range(5,6): 146 | self.feature_extraction2.add_module(str(x), resnet50_features[x]) 147 | 148 | for x in range(6,7): 149 | self.feature_extraction3.add_module(str(x), resnet50_features[x]) 150 | 151 | for x in range(7,8): 152 | self.feature_extraction4.add_module(str(x), resnet50_features[x]) 153 | 154 | 155 | self.hyper1_1 = self.hyper_structure1(64,256) 156 | self.hyper2_1 = self.hyper_structure2(256,512) 157 | self.hyper3_1 = self.hyper_structure2(512,1024) 158 | self.hyper4_1 = self.hyper_structure2(1024,2048) 159 | 160 | self.hyper2_2 = self.hyper_structure2(256,512) 161 | self.hyper3_2 = self.hyper_structure2(512,1024) 162 | self.hyper4_2 = self.hyper_structure2(1024,2048) 163 | 164 | self.hyper3_3 = self.hyper_structure2(512,1024) 165 | self.hyper4_3 = self.hyper_structure2(1024,2048) 166 | 167 | self.hyper4_4 = self.hyper_structure2(1024,2048) 168 | 169 | 170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 171 | 172 | self.quality1 = self.quality_regression(2048, 128, 1) 173 | self.quality2 = self.quality_regression(2048, 128, 1) 174 | self.quality3 = self.quality_regression(2048, 128, 1) 175 | self.quality4 = self.quality_regression(2048, 128, 1) 176 | self.quality5 = self.quality_regression(2048, 128, 1) 177 | self.quality6 = self.quality_regression(2048, 128, 1) 178 | 179 | def hyper_structure1(self,in_channels,out_channels): 180 | 181 | hyper_block = nn.Sequential( 182 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 183 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=1, padding=1,bias=False), 184 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 185 | ) 186 | 187 | return hyper_block 188 | 189 | def hyper_structure2(self,in_channels,out_channels): 190 | hyper_block = nn.Sequential( 191 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 192 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=2, padding=1,bias=False), 193 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 194 | ) 195 | 196 | return hyper_block 197 | 198 | def quality_regression(self,in_channels, middle_channels, out_channels): 199 | regression_block = nn.Sequential( 200 | nn.Linear(in_channels, middle_channels), 201 | nn.Linear(middle_channels, out_channels), 202 | ) 203 | 204 | return regression_block 205 | 206 | 207 | def forward(self, x): 208 | 209 | x = self.feature_extraction_stem(x) 210 | 211 | x_hyper1 = self.hyper1_1(x) 212 | x = self.feature_extraction1(x) 213 | 214 | 215 | x_hyper1 = self.hyper2_1(x_hyper1+x) 216 | x_hyper2 = self.hyper2_2(x) 217 | x = self.feature_extraction2(x) 218 | 219 | 220 | x_hyper1 = self.hyper3_1(x_hyper1+x) 221 | x_hyper2 = self.hyper3_2(x_hyper2+x) 222 | x_hyper3 = self.hyper3_3(x) 223 | x = self.feature_extraction3(x) 224 | 225 | 226 | x_hyper1 = self.hyper4_1(x_hyper1+x) 227 | x_hyper2 = self.hyper4_2(x_hyper2+x) 228 | x_hyper3 = self.hyper4_3(x_hyper3+x) 229 | x_hyper4 = self.hyper4_4(x) 230 | x = self.feature_extraction4(x) 231 | 232 | x = 
x+x_hyper1+x_hyper2+x_hyper3+x_hyper4 233 | 234 | x = self.avgpool(x) 235 | 236 | x = torch.flatten(x, 1) 237 | 238 | x1 = self.quality1(x) 239 | x2 = self.quality2(x) 240 | x3 = self.quality3(x) 241 | x4 = self.quality4(x) 242 | x5 = self.quality5(x) 243 | x6 = self.quality6(x) 244 | 245 | 246 | return x1, x2, x3, x4, x5, x6 247 | 248 | if __name__ == '__main__': 249 | 250 | model = resnet50() 251 | 252 | print(model) -------------------------------------------------------------------------------- /train_single_database.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | 10 | import torch.backends.cudnn as cudnn 11 | 12 | import IQADataset 13 | import models.stairIQA_resnet as stairIQA_resnet 14 | from utils import performance_fit 15 | 16 | 17 | 18 | 19 | 20 | def parse_args(): 21 | """Parse input arguments. """ 22 | parser = argparse.ArgumentParser(description="In the wild Image Quality Assessment") 23 | parser.add_argument('--gpu', help="GPU device id to use [0]", default=0, type=int) 24 | parser.add_argument('--num_epochs', help='Maximum number of training epochs.', default=30, type=int) 25 | parser.add_argument('--batch_size', help='Batch size.', default=40, type=int) 26 | parser.add_argument('--resize', help='resize.', type=int) 27 | parser.add_argument('--crop_size', help='crop_size.',type=int) 28 | parser.add_argument('--lr', type=float, default=0.00001) 29 | parser.add_argument('--decay_ratio', type=float, default=0.9) 30 | parser.add_argument('--decay_interval', type=float, default=10) 31 | parser.add_argument('--snapshot', help='Path of model snapshot.', default='', type=str) 32 | parser.add_argument('--results_path', type=str) 33 | parser.add_argument('--database_dir', type=str) 34 | parser.add_argument('--model', default='ResNet', type=str) 35 | parser.add_argument('--multi_gpu', type=bool, default=False) 36 | parser.add_argument('--print_samples', type=int, default = 50) 37 | parser.add_argument('--database', default='FLIVE', type=str) 38 | parser.add_argument('--test_method', default='five', type=str, 39 | help='use the center crop or five crop to test the image') 40 | 41 | 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | if __name__ == '__main__': 48 | args = parse_args() 49 | 50 | 51 | gpu = args.gpu 52 | cudnn.enabled = True 53 | num_epochs = args.num_epochs 54 | batch_size = args.batch_size 55 | lr = args.lr 56 | decay_interval = args.decay_interval 57 | decay_ratio = args.decay_ratio 58 | snapshot = args.snapshot 59 | database = args.database 60 | print_samples = args.print_samples 61 | results_path = args.results_path 62 | database_dir = args.database_dir 63 | resize = args.resize 64 | crop_size = args.crop_size 65 | 66 | 67 | best_all = np.zeros([10, 4]) 68 | for exp_id in range(10): 69 | 70 | print('The current exp_id is ' + str(exp_id)) 71 | if not os.path.exists(snapshot): 72 | os.makedirs(snapshot) 73 | trained_model_file = os.path.join(snapshot, 'train-ind-{}-{}-exp_id-{}.pkl'.format(database, args.model, exp_id)) 74 | 75 | print('The save model name is ' + trained_model_file) 76 | 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | if torch.cuda.is_available(): 79 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) 80 | 81 | if database == 'Koniq10k': 82 | train_filename_list = 'csvfiles/Koniq10k_train_'+str(exp_id)+'.csv' 83 
| test_filename_list = 'csvfiles/Koniq10k_test_'+str(exp_id)+'.csv' 84 | elif database == 'FLIVE': 85 | train_filename_list = 'csvfiles/FLIVE_train_'+str(exp_id)+'.csv' 86 | test_filename_list = 'csvfiles/FLIVE_test_'+str(exp_id)+'.csv' 87 | elif database == 'FLIVE_patch': 88 | train_filename_list = 'csvfiles/FLIVE_patch_train_'+str(exp_id)+'.csv' 89 | test_filename_list = 'csvfiles/FLIVE_patch_test_'+str(exp_id)+'.csv' 90 | elif database == 'LIVE_challenge': 91 | train_filename_list = 'csvfiles/LIVE_challenge_train_'+str(exp_id)+'.csv' 92 | test_filename_list = 'csvfiles/LIVE_challenge_test_'+str(exp_id)+'.csv' 93 | elif database == 'SPAQ': 94 | train_filename_list = 'csvfiles/SPQA_train_'+str(exp_id)+'.csv' 95 | test_filename_list = 'csvfiles/SPQA_test_'+str(exp_id)+'.csv' 96 | elif database == 'BID': 97 | train_filename_list = 'csvfiles/BID_train_'+str(exp_id)+'.csv' 98 | test_filename_list = 'csvfiles/BID_test_'+str(exp_id)+'.csv' 99 | 100 | 101 | print(train_filename_list) 102 | print(test_filename_list) 103 | 104 | # load the network 105 | if args.model == 'stairIQA_resnet': 106 | model = stairIQA_resnet.resnet50(pretrained = True) 107 | 108 | 109 | 110 | transformations_train = transforms.Compose([transforms.Resize(resize),transforms.RandomCrop(crop_size), \ 111 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 112 | if args.test_method == 'one': 113 | transformations_test = transforms.Compose([transforms.Resize(resize),transforms.CenterCrop(crop_size), \ 114 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 115 | elif args.test_method == 'five': 116 | transformations_test = transforms.Compose([transforms.Resize(resize),transforms.FiveCrop(crop_size), \ 117 | (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), \ 118 | (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 119 | std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 120 | 121 | 122 | 123 | train_dataset = IQADataset.IQA_dataloader(database_dir, train_filename_list, transformations_train, database) 124 | test_dataset = IQADataset.IQA_dataloader(database_dir, test_filename_list, transformations_test, database) 125 | 126 | 127 | 128 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8) 129 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8) 130 | 131 | 132 | if args.multi_gpu: 133 | model = torch.nn.DataParallel(model) 134 | model = model.to(device) 135 | else: 136 | model = model.to(device) 137 | 138 | criterion = nn.MSELoss().to(device) 139 | 140 | 141 | param_num = 0 142 | for param in model.parameters(): 143 | param_num += int(np.prod(param.shape)) 144 | print('Trainable params: %.2f million' % (param_num / 1e6)) 145 | 146 | 147 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0000001) 148 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_interval, gamma=decay_ratio) 149 | 150 | 151 | print("Ready to train network") 152 | 153 | best_test_criterion = -1 # SROCC min 154 | best = np.zeros(4) 155 | 156 | n_train = len(train_dataset) 157 | n_test = len(test_dataset) 158 | 159 | 160 | for epoch in range(num_epochs): 161 | # train 162 | model.train() 163 | 164 | batch_losses = [] 165 | batch_losses_each_disp = [] 166 | session_start_time = time.time() 167 | for i, (image, mos) in 
enumerate(train_loader): 168 | image = image.to(device) 169 | mos = mos[:,np.newaxis] 170 | mos = mos.to(device) 171 | 172 | mos_output = model(image) 173 | 174 | loss = criterion(mos_output, mos) 175 | batch_losses.append(loss.item()) 176 | batch_losses_each_disp.append(loss.item()) 177 | 178 | optimizer.zero_grad() # clear gradients for next train 179 | torch.autograd.backward(loss) 180 | optimizer.step() 181 | 182 | if (i+1) % print_samples == 0: 183 | session_end_time = time.time() 184 | avg_loss_epoch = sum(batch_losses_each_disp) / print_samples 185 | print('Epoch: %d/%d | Step: %d/%d | Training loss: %.4f' % \ 186 | (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, \ 187 | avg_loss_epoch)) 188 | batch_losses_each_disp = [] 189 | print('CostTime: {:.4f}'.format(session_end_time - session_start_time)) 190 | session_start_time = time.time() 191 | 192 | avg_loss = sum(batch_losses) / (len(train_dataset) // batch_size) 193 | print('Epoch %d averaged training loss: %.4f' % (epoch + 1, avg_loss)) 194 | 195 | scheduler.step() 196 | lr_current = scheduler.get_last_lr() 197 | print('The current learning rate is {:.06f}'.format(lr_current[0])) 198 | 199 | # Test 200 | model.eval() 201 | y_output = np.zeros(n_test) 202 | y_test = np.zeros(n_test) 203 | 204 | with torch.no_grad(): 205 | for i, (image, mos) in enumerate(test_loader): 206 | if args.test_method == 'one': 207 | image = image.to(device) 208 | y_test[i] = mos.item() 209 | mos = mos.to(device) 210 | outputs = model(image) 211 | y_output[i] = outputs.item() 212 | 213 | 214 | elif args.test_method == 'five': 215 | bs, ncrops, c, h, w = image.size() 216 | y_test[i] = mos.item() 217 | image = image.to(device) 218 | mos = mos.to(device) 219 | 220 | outputs = model(image.view(-1, c, h, w)) 221 | outputs_avg = outputs.view(bs, ncrops, -1).mean(1) 222 | y_output[i] = outputs_avg.item() 223 | 224 | test_PLCC, test_SRCC, test_KRCC, test_RMSE = performance_fit(y_test, y_output) 225 | print("Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE)) 226 | 227 | if test_SRCC > best_test_criterion: 228 | print("Update best model using best_val_criterion ") 229 | torch.save(model.state_dict(), trained_model_file) 230 | best[0:4] = [test_SRCC, test_KRCC, test_PLCC, test_RMSE] 231 | best_test_criterion = test_SRCC # update best val SROCC 232 | 233 | print("The best Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE)) 234 | 235 | print(database) 236 | best_all[exp_id, :] = best 237 | print("The best Val results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(best[0], best[1], best[2], best[3])) 238 | print('*************************************************************************************************************************') 239 | 240 | best_median = np.median(best_all, 0) 241 | best_mean = np.mean(best_all, 0) 242 | best_std = np.std(best_all, 0) 243 | print('*************************************************************************************************************************') 244 | print(best_all) 245 | print("The median val results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(best_median[0], best_median[1], best_median[2], best_median[3])) 246 | print("The mean val results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(best_mean[0], best_mean[1], best_mean[2], best_mean[3])) 247 | print("The std val results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, 
RMSE={:.4f}".format(best_std[0], best_std[1], best_std[2], best_std[3])) 248 | print('*************************************************************************************************************************') -------------------------------------------------------------------------------- /train_imdt.py: -------------------------------------------------------------------------------- 1 | import os, argparse, time 2 | 3 | import numpy as np 4 | import yaml 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from utils import performance_fit 10 | 11 | import IQADataset 12 | # import scipy.io as scio 13 | import models.stairIQA_resnet as stairIQA_resnet 14 | 15 | 16 | def train_and_test(model, optimizer, criterion, trained_model_file, args, data_name): 17 | 18 | n_epoch = args.config[data_name]['n_epochs'] 19 | train_loader = args.train_loader[data_name] 20 | test_loader = args.test_loader[data_name] 21 | print_samples = args.print_samples 22 | num_epoch = args.num_epochs 23 | n_train = args.n_train_sample[data_name] 24 | batch_size = args.batch_size 25 | n_test = args.n_test_sample[data_name] 26 | 27 | for i_epoch in range(n_epoch): 28 | print(data_name + ':') 29 | print("eval mode") 30 | # eval 31 | model.eval() 32 | batch_losses = [] 33 | batch_losses_each_disp = [] 34 | session_start_time = time.time() 35 | for i, (image, mos) in enumerate(train_loader): 36 | image = image.to(device) 37 | mos = mos[:,np.newaxis] 38 | mos = mos.to(device) 39 | 40 | if data_name == 'FLIVE': 41 | mos_output,_,_,_,_,_ = model(image) 42 | elif data_name == 'FLIVE_patch': 43 | _,mos_output,_,_,_,_ = model(image) 44 | elif data_name == 'LIVE_challenge': 45 | _,_,mos_output,_,_,_ = model(image) 46 | elif data_name == 'Koniq10k': 47 | _,_,_,mos_output,_,_ = model(image) 48 | elif data_name == 'SPAQ': 49 | _,_,_,_,mos_output,_ = model(image) 50 | elif data_name == 'BID': 51 | _,_,_,_,_,mos_output = model(image) 52 | 53 | # MSE loss 54 | loss = criterion(mos_output,mos) 55 | batch_losses.append(loss.item()) 56 | batch_losses_each_disp.append(loss.item()) 57 | 58 | optimizer.zero_grad() # clear gradients for next train 59 | torch.autograd.backward(loss) 60 | optimizer.step() 61 | 62 | if (i+1) % print_samples == 0: 63 | session_end_time = time.time() 64 | avg_loss_epoch = sum(batch_losses_each_disp) / print_samples 65 | print ('Epoch [%d/%d], Iter [%d/%d] Losses: %.4f CostTime: %.4f' % \ 66 | (epoch*n_epoch+i_epoch+1, num_epoch*n_epoch, i+1, n_train//batch_size, \ 67 | avg_loss_epoch, session_end_time-session_start_time)) 68 | batch_losses_each_disp = [] 69 | session_start_time = time.time() 70 | 71 | avg_loss = sum(batch_losses)/(i+1) 72 | print('Epoch [%d/%d], training loss is: %.4f' %(epoch*n_epoch+i_epoch+1, num_epoch*n_epoch, avg_loss)) 73 | 74 | # Test 75 | model.eval() 76 | y_output = np.zeros(n_test) 77 | y_test = np.zeros(n_test) 78 | 79 | with torch.no_grad(): 80 | for i, (image, mos) in enumerate(test_loader): 81 | if args.test_method == 'one': 82 | image = image.to(device) 83 | y_test[i] = mos.item() 84 | mos = mos.to(device) 85 | if data_name == 'FLIVE': 86 | outputs,_,_,_,_,_ = model(image) 87 | elif data_name == 'FLIVE_patch': 88 | _,outputs,_,_,_,_ = model(image) 89 | elif data_name == 'LIVE_challenge': 90 | _,_,outputs,_,_,_ = model(image) 91 | elif data_name == 'Koniq10k': 92 | _,_,_,outputs,_,_ = model(image) 93 | elif data_name == 'SPAQ': 94 | _,_,_,_,outputs,_ = model(image) 95 | elif data_name == 'BID': 96 | _,_,_,_,_,outputs = model(image) 97 | y_output[i] 
= outputs.item() 98 | 99 | elif args.test_method == 'five': 100 | bs, ncrops, c, h, w = image.size() 101 | y_test[i] = mos.item() 102 | image = image.to(device) 103 | mos = mos.to(device) 104 | 105 | if data_name == 'FLIVE': 106 | outputs,_,_,_,_,_ = model(image.view(-1, c, h, w)) 107 | elif data_name == 'FLIVE_patch': 108 | _,outputs,_,_,_,_ = model(image.view(-1, c, h, w)) 109 | elif data_name == 'LIVE_challenge': 110 | _,_,outputs,_,_,_ = model(image.view(-1, c, h, w)) 111 | elif data_name == 'Koniq10k': 112 | _,_,_,outputs,_,_ = model(image.view(-1, c, h, w)) 113 | elif data_name == 'SPAQ': 114 | _,_,_,_,outputs,_ = model(image.view(-1, c, h, w)) 115 | elif data_name == 'BID': 116 | _,_,_,_,_,outputs = model(image.view(-1, c, h, w)) 117 | 118 | outputs_avg = outputs.view(bs, ncrops, -1).mean(1) 119 | y_output[i] = outputs_avg.item() 120 | 121 | test_PLCC, test_SRCC, test_KRCC, test_RMSE = performance_fit(y_test, y_output) 122 | print("Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE)) 123 | 124 | if test_SRCC > args.best_test_criterion[data_name]: 125 | print("Update best model using best_val_criterion ") 126 | 127 | torch.save(model.state_dict(), trained_model_file + data_name + '.pkl') 128 | 129 | args.best_performance[data_name][0:4] = [test_SRCC, test_KRCC, test_PLCC, test_RMSE] 130 | args.best_test_criterion[data_name] = test_SRCC # update best val SROCC 131 | 132 | print("The best Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(test_SRCC, test_KRCC, test_PLCC, test_RMSE)) 133 | 134 | scheduler.step() 135 | lr_current = scheduler.get_last_lr() 136 | print('The current learning rate is {:.06f}'.format(lr_current[0])) 137 | 138 | return 139 | 140 | 141 | def parse_args(): 142 | """Parse input arguments. 
""" 143 | parser = argparse.ArgumentParser(description="In the wild Image Quality Assessment") 144 | parser.add_argument('--gpu', dest='gpu_id', help="GPU device id to use [0]", default=0, type=int) 145 | parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.', 146 | default=30, type=int) 147 | parser.add_argument('--batch_size', dest='batch_size', help='Batch size.', 148 | default=40, type=int) 149 | parser.add_argument('--lr', type=float, default=0.00001) 150 | parser.add_argument('--decay_ratio', type=float, default=0.9) 151 | parser.add_argument('--decay_interval', type=float, default=10) 152 | parser.add_argument('--snapshot', dest='snapshot', help='Path of model snapshot.', 153 | default='', type=str) 154 | parser.add_argument('--results_path', type=str) 155 | parser.add_argument('--model', default='stairIQA_resnet', type=str, 156 | help='model name (default: stairIQA_resnet)') 157 | parser.add_argument('--multi_gpu', type=bool, default=False) 158 | parser.add_argument('--print_samples', type=int, default = 50) 159 | parser.add_argument('--test_method', default='five', type=str, 160 | help='use the center crop or five crop to test the image (default: one)') 161 | parser.add_argument('--exp_id', default=0, type=int, 162 | help='exp id for train-test splits (default: 0)') 163 | 164 | 165 | args = parser.parse_args() 166 | 167 | return args 168 | 169 | if __name__ == '__main__': 170 | args = parse_args() 171 | 172 | with open('config.yaml') as f: 173 | config = yaml.load(f, Loader=yaml.FullLoader) 174 | 175 | args.config = config 176 | 177 | print('The current exp_id is ' + str(args.exp_id)) 178 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 179 | 180 | 181 | args.train_loader = {} 182 | args.test_loader = {} 183 | args.best_test_criterion = {} 184 | args.best_performance = {} 185 | args.n_train_sample = {} 186 | args.n_test_sample = {} 187 | 188 | for i_database in args.config: 189 | 190 | train_filename_list = os.path.join(args.config[i_database]['filename_dir'], \ 191 | args.config[i_database]['train_filename'] + '_' + str(args.exp_id)+'.csv') 192 | test_filename_list = os.path.join(args.config[i_database]['filename_dir'], \ 193 | args.config[i_database]['test_filename'] + '_' + str(args.exp_id)+'.csv') 194 | 195 | 196 | transformations_train = transforms.Compose([transforms.Resize(args.config[i_database]['resize']),\ 197 | transforms.RandomCrop(args.config[i_database]['crop_size']), \ 198 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 199 | std=[0.229, 0.224, 0.225])]) 200 | 201 | if args.test_method == 'one': 202 | transformations_test = transforms.Compose([transforms.Resize(args.config[i_database]['resize']),\ 203 | transforms.CenterCrop(args.config[i_database]['crop_size']), \ 204 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], \ 205 | std=[0.229, 0.224, 0.225])]) 206 | elif args.test_method == 'five': 207 | transformations_test = transforms.Compose([transforms.Resize(args.config[i_database]['resize']),\ 208 | transforms.FiveCrop(args.config[i_database]['crop_size']), (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), (lambda crops: torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(crop) for crop in crops]))]) 209 | 210 | 211 | train_dataset = IQADataset.IQA_dataloader(args.config[i_database]['database_dir'], train_filename_list, transformations_train, i_database) 212 | test_dataset = 
IQADataset.IQA_dataloader(args.config[i_database]['database_dir'], test_filename_list, transformations_test, i_database) 213 | 214 | args.train_loader[i_database] = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8) 215 | args.test_loader[i_database] = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8) 216 | 217 | args.best_test_criterion[i_database] = -1 218 | args.best_performance[i_database] = np.zeros(4) 219 | 220 | args.n_train_sample[i_database] = len(train_dataset) 221 | args.n_test_sample[i_database] = len(test_dataset) 222 | 223 | 224 | trained_model_file = os.path.join(args.snapshot, '{}-EXP{}-'.format(args.model, args.exp_id)) 225 | 226 | # load the network 227 | if args.model == 'stairIQA_resnet': 228 | model = stairIQA_resnet.resnet50_imdt(pretrained = True) 229 | 230 | 231 | if args.multi_gpu: 232 | model = torch.nn.DataParallel(model) 233 | model = model.to(device) 234 | else: 235 | model = model.to(device) 236 | 237 | criterion = nn.MSELoss().to(device) 238 | 239 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0000001) 240 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_interval, gamma=args.decay_ratio) 241 | 242 | 243 | print("Ready to train network") 244 | 245 | 246 | for epoch in range(args.num_epochs): 247 | model.eval() 248 | 249 | # iterative mixed database training: train/test one large database, then revisit LIVE_challenge and BID after each 250 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'FLIVE_patch') 251 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'LIVE_challenge') 252 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'BID') 253 | 254 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'FLIVE') 255 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'LIVE_challenge') 256 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'BID') 257 | 258 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'Koniq10k') 259 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'LIVE_challenge') 260 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'BID') 261 | 262 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'SPAQ') 263 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'LIVE_challenge') 264 | train_and_test(model, optimizer, criterion, trained_model_file, args, 'BID') 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | for i_database in args.config: 279 | print(i_database) 280 | print("The best Test results: SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}".format(args.best_performance[i_database][0], \ 281 | args.best_performance[i_database][1], args.best_performance[i_database][2], args.best_performance[i_database][3])) 282 | np.save(os.path.join(args.results_path, args.model + '_' + i_database + '_' + str(args.exp_id)), args.best_performance[i_database]) -------------------------------------------------------------------------------- /models/ResNet_staircase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from torch.hub import load_state_dict_from_url # used by _resnet below to fetch pretrained weights 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 |
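# Editorial overview (a summary of the code below, not original repo text):
# this file adapts torchvision's ResNet into the staircase IQA backbone.
# Features from each residual stage are projected by the hyper_structure
# blocks, accumulated stage by stage, and summed with the layer4 output
# before pooling (see _forward_impl). The pooled feature then feeds six
# parallel quality heads, which train_imdt.py assigns one per database:
# quality1..quality6 -> FLIVE, FLIVE_patch, LIVE_challenge, Koniq10k, SPAQ, BID.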
model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | __constants__ = ['downsample'] 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | __constants__ = ['downsample'] 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 81 | base_width=64, dilation=1, norm_layer=None): 82 | super(Bottleneck, self).__init__() 83 | if norm_layer is None: 84 | norm_layer = nn.BatchNorm2d 85 | width = int(planes * (base_width / 64.)) * groups 86 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 87 | self.conv1 = conv1x1(inplanes, width) 88 | self.bn1 = norm_layer(width) 89 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 90 | self.bn2 = norm_layer(width) 91 | self.conv3 = conv1x1(width, planes * self.expansion) 92 | self.bn3 = norm_layer(planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 
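        # Third 1x1 conv expands back to planes * expansion channels; the ReLU
        # is deferred until after the residual addition below.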
107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | identity = self.downsample(x) 113 | 114 | out += identity 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class ResNet(nn.Module): 121 | 122 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 123 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 124 | norm_layer=None): 125 | super(ResNet, self).__init__() 126 | if norm_layer is None: 127 | norm_layer = nn.BatchNorm2d 128 | self._norm_layer = norm_layer 129 | 130 | self.inplanes = 64 131 | self.dilation = 1 132 | if replace_stride_with_dilation is None: 133 | # each element in the tuple indicates if we should replace 134 | # the 2x2 stride with a dilated convolution instead 135 | replace_stride_with_dilation = [False, False, False] 136 | if len(replace_stride_with_dilation) != 3: 137 | raise ValueError("replace_stride_with_dilation should be None " 138 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 139 | self.groups = groups 140 | self.base_width = width_per_group 141 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 142 | bias=False) 143 | self.bn1 = norm_layer(self.inplanes) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 146 | self.layer1 = self._make_layer(block, 64, layers[0]) 147 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 148 | dilate=replace_stride_with_dilation[0]) 149 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 150 | dilate=replace_stride_with_dilation[1]) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 152 | dilate=replace_stride_with_dilation[2]) 153 | 154 | self.hyper1_1 = self.hyper_structure1(64,256) 155 | self.hyper2_1 = self.hyper_structure2(256,512) 156 | self.hyper3_1 = self.hyper_structure2(512,1024) 157 | self.hyper4_1 = self.hyper_structure2(1024,2048) 158 | 159 | self.hyper2_2 = self.hyper_structure2(256,512) 160 | self.hyper3_2 = self.hyper_structure2(512,1024) 161 | self.hyper4_2 = self.hyper_structure2(1024,2048) 162 | 163 | self.hyper3_3 = self.hyper_structure2(512,1024) 164 | self.hyper4_3 = self.hyper_structure2(1024,2048) 165 | 166 | self.hyper4_4 = self.hyper_structure2(1024,2048) 167 | 168 | 169 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 170 | 171 | self.quality1 = self.quality_regression(2048,1) 172 | self.quality2 = self.quality_regression(2048,1) 173 | self.quality3 = self.quality_regression(2048,1) 174 | self.quality4 = self.quality_regression(2048,1) 175 | self.quality5 = self.quality_regression(2048,1) 176 | self.quality6 = self.quality_regression(2048,1) 177 | 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 181 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 182 | nn.init.constant_(m.weight, 1) 183 | nn.init.constant_(m.bias, 0) 184 | 185 | # Zero-initialize the last BN in each residual branch, 186 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
187 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 188 | if zero_init_residual: 189 | for m in self.modules(): 190 | if isinstance(m, Bottleneck): 191 | nn.init.constant_(m.bn3.weight, 0) 192 | elif isinstance(m, BasicBlock): 193 | nn.init.constant_(m.bn2.weight, 0) 194 | 195 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 196 | norm_layer = self._norm_layer 197 | downsample = None 198 | previous_dilation = self.dilation 199 | if dilate: 200 | self.dilation *= stride 201 | stride = 1 202 | if stride != 1 or self.inplanes != planes * block.expansion: 203 | downsample = nn.Sequential( 204 | conv1x1(self.inplanes, planes * block.expansion, stride), 205 | norm_layer(planes * block.expansion), 206 | ) 207 | 208 | layers = [] 209 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 210 | self.base_width, previous_dilation, norm_layer)) 211 | self.inplanes = planes * block.expansion 212 | for _ in range(1, blocks): 213 | layers.append(block(self.inplanes, planes, groups=self.groups, 214 | base_width=self.base_width, dilation=self.dilation, 215 | norm_layer=norm_layer)) 216 | 217 | return nn.Sequential(*layers) 218 | 219 | def quality_regression(self,in_channels,out_channels): 220 | regression_block = nn.Sequential( 221 | nn.Linear(in_channels, 128), 222 | nn.Linear(128, out_channels), 223 | ) 224 | 225 | return regression_block 226 | 227 | def hyper_structure1(self,in_channels,out_channels): 228 | 229 | hyper_block = nn.Sequential( 230 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 231 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=1, padding=1,bias=False), 232 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 233 | ) 234 | 235 | return hyper_block 236 | 237 | def hyper_structure2(self,in_channels,out_channels): 238 | hyper_block = nn.Sequential( 239 | nn.Conv2d(in_channels,in_channels//4,kernel_size=1,stride=1, padding=0,bias=False), 240 | nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=2, padding=1,bias=False), 241 | nn.Conv2d(in_channels//4,out_channels,kernel_size=1,stride=1, padding=0,bias=False), 242 | ) 243 | 244 | return hyper_block 245 | 246 | 247 | def _forward_impl(self, x): 248 | # See note [TorchScript super()] 249 | x = self.conv1(x) 250 | x = self.bn1(x) 251 | x = self.relu(x) 252 | x = self.maxpool(x) 253 | 254 | x_hyper1 = self.hyper1_1(x) 255 | x = self.layer1(x) 256 | x_hyper1 = self.hyper2_1(x_hyper1+x) 257 | x_hyper2 = self.hyper2_2(x) 258 | x = self.layer2(x) 259 | x_hyper1 = self.hyper3_1(x_hyper1+x) 260 | x_hyper2 = self.hyper3_2(x_hyper2+x) 261 | x_hyper3 = self.hyper3_3(x) 262 | x = self.layer3(x) 263 | x_hyper1 = self.hyper4_1(x_hyper1+x) 264 | x_hyper2 = self.hyper4_2(x_hyper2+x) 265 | x_hyper3 = self.hyper4_3(x_hyper3+x) 266 | x_hyper4 = self.hyper4_4(x) 267 | 268 | x = self.layer4(x) 269 | 270 | x = x+x_hyper1+x_hyper2+x_hyper3+x_hyper4 271 | 272 | x = self.avgpool(x) 273 | x = torch.flatten(x, 1) 274 | 275 | x1 = self.quality1(x) 276 | x2 = self.quality2(x) 277 | x3 = self.quality3(x) 278 | x4 = self.quality4(x) 279 | x5 = self.quality5(x) 280 | x6 = self.quality6(x) 281 | 282 | return x1,x2,x3,x4,x5,x6 283 | 284 | def forward(self, x): 285 | return self._forward_impl(x) 286 | 287 | 288 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 289 | model = ResNet(block, layers, **kwargs) 290 | if pretrained: 291 | state_dict = load_state_dict_from_url(model_urls[arch], 292 |
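        # note: the strict load below would fail for this staircase ResNet
        # (it adds hyper/quality modules and has no fc head), which is why the
        # resnet34/50/101/152 constructors further down filter pretrained keys
        # instead of going through this path.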
progress=progress) 293 | model.load_state_dict(state_dict) 294 | return model 295 | 296 | 297 | def resnet18(pretrained=False, progress=True, **kwargs): 298 | r"""ResNet-18 model from 299 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 305 | **kwargs) 306 | 307 | 308 | def resnet34(pretrained=False, progress=True, **kwargs): 309 | r"""ResNet-34 model from 310 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | progress (bool): If True, displays a progress bar of the download to stderr 314 | """ 315 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 316 | if pretrained: 317 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 318 | model_dict = model.state_dict() 319 | pre_train_model = model_zoo.load_url(model_urls['resnet34']) 320 | pre_train_model = {k:v for k,v in pre_train_model.items() if k in model_dict} 321 | model_dict.update(pre_train_model) 322 | model.load_state_dict(model_dict) 323 | return model 324 | 325 | 326 | def resnet50(pretrained=False, progress=True, **kwargs): 327 | r"""ResNet-50 model from 328 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 329 | Args: 330 | pretrained (bool): If True, returns a model pre-trained on ImageNet 331 | progress (bool): If True, displays a progress bar of the download to stderr 332 | """ 333 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 334 | if pretrained: 335 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 336 | model_dict = model.state_dict() 337 | pre_train_model = model_zoo.load_url(model_urls['resnet50']) 338 | pre_train_model = {k:v for k,v in pre_train_model.items() if k in model_dict} 339 | model_dict.update(pre_train_model) 340 | model.load_state_dict(model_dict) 341 | return model 342 | 343 | 344 | def resnet101(pretrained=False, progress=True, **kwargs): 345 | r"""ResNet-101 model from 346 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 347 | Args: 348 | pretrained (bool): If True, returns a model pre-trained on ImageNet 349 | progress (bool): If True, displays a progress bar of the download to stderr 350 | """ 351 | # return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 352 | # **kwargs) 353 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 354 | if pretrained: 355 | model_dict = model.state_dict() 356 | pre_train_model = model_zoo.load_url(model_urls['resnet101']) 357 | pre_train_model = {k:v for k,v in pre_train_model.items() if k in model_dict} 358 | model_dict.update(pre_train_model) 359 | model.load_state_dict(model_dict) 360 | return model 361 | 362 | 363 | def resnet152(pretrained=False, progress=True, **kwargs): 364 | r"""ResNet-152 model from 365 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 366 | Args: 367 | pretrained (bool): If True, returns a model pre-trained on ImageNet 368 | progress (bool): If True, displays a progress bar of the download to stderr 369 | """ 370 | # return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 371 | # **kwargs) 372 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 373 | if pretrained: 374 | model_dict = model.state_dict() 375 | pre_train_model = model_zoo.load_url(model_urls['resnet152']) 376 | pre_train_model = {k:v
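                           # keep only ImageNet weights whose names exist in this model;
                           # the upstream fc head is dropped and new modules keep their init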
for k,v in pre_train_model.items() if k in model_dict} 377 | model_dict.update(pre_train_model) 378 | model.load_state_dict(model_dict) 379 | return model 380 | 381 | 382 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 383 | r"""ResNeXt-50 32x4d model from 384 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 385 | Args: 386 | pretrained (bool): If True, returns a model pre-trained on ImageNet 387 | progress (bool): If True, displays a progress bar of the download to stderr 388 | """ 389 | kwargs['groups'] = 32 390 | kwargs['width_per_group'] = 4 391 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 392 | pretrained, progress, **kwargs) 393 | 394 | 395 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 396 | r"""ResNeXt-101 32x8d model from 397 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 398 | Args: 399 | pretrained (bool): If True, returns a model pre-trained on ImageNet 400 | progress (bool): If True, displays a progress bar of the download to stderr 401 | """ 402 | kwargs['groups'] = 32 403 | kwargs['width_per_group'] = 8 404 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 405 | pretrained, progress, **kwargs) 406 | 407 | 408 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 409 | r"""Wide ResNet-50-2 model from 410 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 411 | The model is the same as ResNet except for the bottleneck number of channels 412 | which is twice larger in every block. The number of channels in outer 1x1 413 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 414 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 415 | Args: 416 | pretrained (bool): If True, returns a model pre-trained on ImageNet 417 | progress (bool): If True, displays a progress bar of the download to stderr 418 | """ 419 | kwargs['width_per_group'] = 64 * 2 420 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 421 | pretrained, progress, **kwargs) 422 | 423 | 424 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 425 | r"""Wide ResNet-101-2 model from 426 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 427 | The model is the same as ResNet except for the bottleneck number of channels 428 | which is twice larger in every block. The number of channels in outer 1x1 429 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 430 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 431 | Args: 432 | pretrained (bool): If True, returns a model pre-trained on ImageNet 433 | progress (bool): If True, displays a progress bar of the download to stderr 434 | """ 435 | kwargs['width_per_group'] = 64 * 2 436 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 437 | pretrained, progress, **kwargs) --------------------------------------------------------------------------------
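Editor's note: a minimal end-to-end sketch of five-crop inference with the staircase backbone defined above. It is illustrative, not repository code: the checkpoint and image paths are placeholders, it assumes the repo root is on PYTHONPATH so that `models.ResNet_staircase` imports as shown, and it assumes a checkpoint whose state-dict keys match this backbone (e.g. one saved by train_imdt.py for a matching architecture). BID is read from the sixth head, mirroring the test loop in train_imdt.py.

import torch
from PIL import Image
from torchvision import transforms

from models import ResNet_staircase

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the staircase ResNet-50; pretrained=False skips the ImageNet download.
model = ResNet_staircase.resnet50(pretrained=False).to(device)
# Placeholder checkpoint; must hold a state dict compatible with this backbone.
model.load_state_dict(torch.load('stairIQA_resnet-EXP0-BID.pkl', map_location=device))
model.eval()

# Same five-crop test pipeline as train_imdt.py (resize 384, crop 320).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
test_transform = transforms.Compose([
    transforms.Resize(384),
    transforms.FiveCrop(320),
    transforms.Lambda(lambda crops: torch.stack(
        [normalize(transforms.ToTensor()(c)) for c in crops])),
])

img = Image.open('test_image.jpg').convert('RGB')   # placeholder image path
crops = test_transform(img).to(device)              # shape: (5, 3, 320, 320)

with torch.no_grad():
    outputs = model(crops)                          # tuple of six (5, 1) predictions
score = outputs[5].mean().item()                    # sixth head = BID; average the five crops
print('Predicted quality score: {:.4f}'.format(score))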