├── .gitattributes ├── CNNfeatures.py ├── Framework.png ├── LICENSE ├── README.md ├── VQAdataset.py ├── VQAloss.py ├── VQAmodel.py ├── VQAperformance.py ├── cross_dataset_evaluation.py ├── cross_job.sh ├── data ├── CVD2014info.mat ├── KoNViD-1kinfo.mat ├── LIVE-Qualcomminfo.mat ├── LIVE-VQCinfo.mat ├── data_info_maker.m └── test.mp4 ├── job.sh ├── main.py ├── models └── MDTVSFA.pt ├── requirements.txt └── test_demo.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /CNNfeatures.py: -------------------------------------------------------------------------------- 1 | """Extracting Content-Aware Perceptual Features using Pre-Trained Image Classification Models (e.g., ResNet-50)""" 2 | # Author: Dingquan Li 3 | # Email: dingquanli AT pku DOT edu DOT cn 4 | # Date: 2019/11/8 5 | 6 | from argparse import ArgumentParser 7 | import torch 8 | from torchvision import transforms, models 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | import skvideo.io 12 | from PIL import Image 13 | import os 14 | import h5py 15 | import numpy as np 16 | import random 17 | import time 18 | 19 | 20 | class VideoDataset(Dataset): 21 | """Read data from the original dataset for feature extraction""" 22 | def __init__(self, videos_dir, video_names, score, video_format='RGB', width=None, height=None): 23 | 24 | super(VideoDataset, self).__init__() 25 | self.videos_dir = videos_dir 26 | self.video_names = video_names 27 | self.score = score 28 | self.format = video_format 29 | self.width = width 30 | self.height = height 31 | 32 | def __len__(self): 33 | return len(self.video_names) 34 | 35 | def __getitem__(self, idx): 36 | video_name = self.video_names[idx] 37 | assert self.format == 'YUV420' or self.format == 'RGB' 38 | if self.format == 'YUV420': 39 | video_data = skvideo.io.vread(os.path.join(self.videos_dir, video_name), self.height, self.width, inputdict={'-pix_fmt':'yuvj420p'}) 40 | else: 41 | video_data = skvideo.io.vread(os.path.join(self.videos_dir, video_name)) 42 | video_score = self.score[idx] 43 | 44 | transform = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 47 | ]) 48 | 49 | video_length = video_data.shape[0] 50 | video_channel = video_data.shape[3] 51 | video_height = video_data.shape[1] 52 | video_width = video_data.shape[2] 53 | print('video_width: {} video_height: {}'.format(video_width, video_height)) 54 | transformed_video = torch.zeros([video_length, video_channel, video_height, video_width]) 55 | for frame_idx in range(video_length): 56 | frame = video_data[frame_idx] 57 | frame = Image.fromarray(frame) 58 | # frame.show() 59 | frame = transform(frame) 60 | transformed_video[frame_idx] = frame 61 | 62 | sample = {'video': transformed_video, 63 | 'score': video_score} 64 | 65 | return sample 66 | 67 | 68 | class CNNModel(torch.nn.Module): 69 | """Modified CNN models for feature extraction""" 70 | def __init__(self, model='ResNet-50'): 71 | super(CNNModel, self).__init__() 72 | if model == 'AlexNet': 73 | print("use AlexNet") 74 | self.features = nn.Sequential(*list(models.alexnet(pretrained=True).children())[:-2]) 75 | elif model == 'ResNet-152': 76 | print("use ResNet-152") 77 | self.features = nn.Sequential(*list(models.resnet152(pretrained=True).children())[:-2]) 78 | elif model == 
'ResNeXt-101-32x8d': 79 | print("use ResNetXt-101-32x8d") 80 | self.features = nn.Sequential(*list(models.resnext101_32x8d(pretrained=True).children())[:-2]) 81 | elif model == 'Wide ResNet-101-2': 82 | print("use Wide ResNet-101-2") 83 | self.features = nn.Sequential(*list(models.wide_resnet101_2(pretrained=True).children())[:-2]) 84 | else: 85 | print("use default ResNet-50") 86 | self.features = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2]) 87 | 88 | def forward(self, x): 89 | x = self.features(x) 90 | features_mean = nn.functional.adaptive_avg_pool2d(x, 1) 91 | features_std = global_std_pool2d(x) 92 | return features_mean, features_std 93 | # # features@: 7->res5c 94 | # for ii, model in enumerate(self.features): 95 | # x = model(x) 96 | # if ii == 7: 97 | # features_mean = nn.functional.adaptive_avg_pool2d(x, 1) 98 | # features_std = global_std_pool2d(x) 99 | # return features_mean, features_std 100 | 101 | 102 | def global_std_pool2d(x): 103 | """2D global standard variation pooling""" 104 | return torch.std(x.view(x.size()[0], x.size()[1], -1, 1), 105 | dim=2, keepdim=True) 106 | 107 | 108 | def get_features(video_data, frame_batch_size=64, model='ResNet-50', device='cuda'): 109 | """feature extraction""" 110 | extractor = CNNModel(model=model).to(device) 111 | video_length = video_data.shape[0] 112 | frame_start = 0 113 | frame_end = frame_start + frame_batch_size 114 | output1 = torch.Tensor().to(device) 115 | output2 = torch.Tensor().to(device) 116 | extractor.eval() 117 | with torch.no_grad(): 118 | while frame_end < video_length: 119 | batch = video_data[frame_start:frame_end].to(device) 120 | features_mean, features_std = extractor(batch) 121 | output1 = torch.cat((output1, features_mean), 0) 122 | output2 = torch.cat((output2, features_std), 0) 123 | frame_end += frame_batch_size 124 | frame_start += frame_batch_size 125 | 126 | last_batch = video_data[frame_start:video_length].to(device) 127 | features_mean, features_std = extractor(last_batch) 128 | output1 = torch.cat((output1, features_mean), 0) 129 | output2 = torch.cat((output2, features_std), 0) 130 | output = torch.cat((output1, output2), 1).squeeze() 131 | 132 | return output 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = ArgumentParser(description='"Extracting Content-Aware Perceptual Features using pre-trained models') 137 | parser.add_argument("--seed", type=int, default=19920517) 138 | parser.add_argument('--database', default='CVD2014', type=str, 139 | help='database name (default: CVD2014)') 140 | parser.add_argument('--model', default='ResNet-50', type=str, 141 | help='which pre-trained model used (default: ResNet-50)') 142 | parser.add_argument('--frame_batch_size', type=int, default=8, 143 | help='frame batch size for feature extraction (default: 8)') 144 | 145 | parser.add_argument('--disable_gpu', action='store_true', help='flag whether to disable GPU') 146 | 147 | parser.add_argument("--ith", type=int, default=0, help='start frame id') 148 | args = parser.parse_args() 149 | 150 | torch.manual_seed(args.seed) # 151 | torch.backends.cudnn.deterministic = True 152 | torch.backends.cudnn.benchmark = False 153 | np.random.seed(args.seed) 154 | random.seed(args.seed) 155 | 156 | torch.utils.backcompat.broadcast_warning.enabled = True 157 | 158 | if args.database == 'KoNViD-1k': 159 | videos_dir = 'KoNViD-1k/' # videos dir, e.g., ln -s /home/ldq/Downloads/KoNViD-1k/ KoNViD-1k 160 | features_dir = 'CNN_features_KoNViD-1k/' # features dir 161 | datainfo = 
'data/KoNViD-1kinfo.mat' # database info: video_names, scores; video format, width, height, index, ref_ids, max_len, etc. 162 | if args.database == 'CVD2014': 163 | videos_dir = 'CVD2014/' # ln -s /media/ldq/Research/Data/CVD2014/ CVD2014 164 | features_dir = 'CNN_features_CVD2014/' 165 | datainfo = 'data/CVD2014info.mat' 166 | if args.database == 'LIVE-Qualcomm': 167 | videos_dir = 'LIVE-Qualcomm/' # ln -s /media/ldq/Others/Data/12.LIVE-Qualcomm\ Mobile\ In-Capture\ Video\ Quality\ Database/ LIVE-Qualcomm 168 | features_dir = 'CNN_features_LIVE-Qualcomm/' 169 | datainfo = 'data/LIVE-Qualcomminfo.mat' 170 | if args.database == 'LIVE-VQC': 171 | videos_dir = 'LIVE-VQC/' # /media/ldq/Others/Data/LIVE\ Video\ Quality\ Challenge\ \(VQC\)\ Database/Video LIVE-VQC 172 | features_dir = 'CNN_features_LIVE-VQC/' 173 | datainfo = 'data/LIVE-VQCinfo.mat' 174 | 175 | if not os.path.exists(features_dir): 176 | os.makedirs(features_dir) 177 | 178 | device = torch.device("cuda" if not args.disable_gpu and torch.cuda.is_available() else "cpu") 179 | 180 | Info = h5py.File(datainfo, 'r') 181 | video_names = [Info[Info['video_names'][0, :][i]][()].tobytes()[::2].decode() for i in range(len(Info['video_names'][0, :]))] 182 | scores = Info['scores'][0, :] 183 | video_format = Info['video_format'][()].tobytes()[::2].decode() 184 | width = int(Info['width'][0]) 185 | height = int(Info['height'][0]) 186 | dataset = VideoDataset(videos_dir, video_names, scores, video_format, width, height) 187 | 188 | max_len = 0 189 | # extract feature on LIVE-Qualcomm using AlexNet will cause the error of "cannot allocate memory" 190 | # One way to solve the problem is to move the for loop to bash. 191 | """ 192 | for ((i=0; i<208; i++)); do 193 | CUDA_VISIBLE_DEVICES=0 python CNNfeatures.py --ith=$i --model=AlexNet --database=LIVE-Qualcomm 194 | done 195 | """ 196 | for i in range(args.ith, len(dataset)): # range(args.ith, args.ith+1): # 197 | start = time.time() 198 | current_data = dataset[i] 199 | print('Video {}: length {}'.format(i, current_data['video'].shape[0])) 200 | if max_len < current_data['video'].shape[0]: 201 | max_len = current_data['video'].shape[0] 202 | features = get_features(current_data['video'], args.frame_batch_size, args.model, device) 203 | np.save(features_dir + str(i) + '_' + args.model +'_last_conv', features.to('cpu').numpy()) 204 | np.save(features_dir + str(i) + '_score', current_data['score']) 205 | end = time.time() 206 | print('{} seconds'.format(end-start)) 207 | print(max_len) 208 | -------------------------------------------------------------------------------- /Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/Framework.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Dingquan Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unified Quality Assessment of In-the-Wild Videos with Mixed Datasets Training 2 | [![License](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](License) 3 | 4 | ## Description 5 | MDTVSFA code for the following paper: 6 | 7 | - Dingquan Li, Tingting Jiang, and Ming Jiang. [Unified Quality Assessment of In-the-Wild Videos with Mixed Datasets Training](https://link.springer.com/article/10.1007%2Fs11263-020-01408-w). International Journal of Computer Vision (IJCV) Special Issue on Computer Vision in the Wild, 2021. [[arxiv version]](https://arxiv.org/abs/2011.04263) 8 | ![Framework](Framework.png) 9 | 10 | ## How to? 11 | ### Install Requirements 12 | ```bash 13 | conda create -n reproducibleresearch pip python=3.6 14 | source activate reproducibleresearch 15 | pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 16 | # source deactive 17 | ``` 18 | Note: Make sure that the CUDA version is consistent. If you have any installation problems, please find the details of error information in `*.log` file. 19 | 20 | ### Download Datasets 21 | Download the [KoNViD-1k](http://database.mmsp-kn.de/konvid-1k-database.html), [CVD2014](https://www.mv.helsinki.fi/home/msjnuuti/CVD2014/) ([alternative link](https://zenodo.org/record/2646315#.X6OmVC-1H3Q)), [LIVE-Qualcomm](http://live.ece.utexas.edu/research/incaptureDatabase/index.html), and [LIVE-VQC](http://live.ece.utexas.edu/research/LIVEVQC/index.html) datasets. Then, run the following `ln` commands in the root of this project. 
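Note: the symlink names should match the dataset directories expected by `CNNfeatures.py`, i.e., `KoNViD-1k/`, `CVD2014/`, `LIVE-Qualcomm/`, and `LIVE-VQC/`.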
22 | 23 | ```bash 24 | ln -s KoNViD-1k_path KoNViD-1k # KoNViD-1k_path is your path to the KoNViD-1k dataset 25 | ln -s CVD2014_path CVD2014 # CVD2014_path is your path to the CVD2014 dataset 26 | ln -s LIVE-Qualcomm_path LIVE-Qualcomm # LIVE-Qualcomm_path is your path to the LIVE-Qualcomm dataset 27 | ln -s LIVE-VQC_path LIVE-VQC # LIVE-VQC_path is your path to the LIVE-VQC dataset 28 | ``` 29 | 30 | ### Training and Evaluating on Multiple Datasets 31 | 32 | ```bash 33 | # Feature extraction 34 | CUDA_VISIBLE_DEVICES=0 python CNNfeatures.py --database=KoNViD-1k --frame_batch_size=64 35 | CUDA_VISIBLE_DEVICES=1 python CNNfeatures.py --database=CVD2014 --frame_batch_size=32 36 | CUDA_VISIBLE_DEVICES=0 python CNNfeatures.py --database=LIVE-Qualcomm --frame_batch_size=8 37 | CUDA_VISIBLE_DEVICES=1 python CNNfeatures.py --database=LIVE-VQC --frame_batch_size=8 38 | # Training, intra-dataset evaluation, for example 39 | chmod 777 job.sh 40 | ./job.sh -g 0 -d K -d C -d L > KCL-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 41 | # Cross-dataset evaluation (after training), for example 42 | chmod 777 cross_job.sh 43 | ./cross_job.sh -g 1 -d K -d C -d L -c N -l mixed > KCLtrained-crossN-mixed-exp-0-10.log 2>&1 & 44 | ``` 45 | 46 | ### Test Demo 47 | 48 | The model weights provided in `models/MDTVSFA.pt` are the saved weights when running the 9-th split of KoNViD-1k, CVD2014, and LIVE-Qualcomm. 49 | ```bash 50 | python test_demo.py --model_path=models/MDTVSFA.pt --video_path=data/test.mp4 51 | ``` 52 | 53 | ### Contact 54 | Dingquan Li, dingquanli AT pku DOT edu DOT cn. 55 | -------------------------------------------------------------------------------- /VQAdataset.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | 6 | 7 | class VQADataset(Dataset): 8 | def __init__(self, args, datasets, status='train'): 9 | self.status = status 10 | self.datasets = datasets 11 | self.crop_length = args.crop_length 12 | 13 | max_len = dict() 14 | self.M = dict() 15 | self.m = dict() 16 | self.scale = dict() 17 | self.index = dict() 18 | 19 | for dataset in datasets: 20 | Info = h5py.File(args.data_info[dataset], 'r') 21 | max_len[dataset] = int(Info['max_len'][0]) 22 | 23 | self.M[dataset] = Info['scores'][0, :].max() 24 | self.m[dataset] = Info['scores'][0, :].min() 25 | self.scale[dataset] = self.M[dataset] - self.m[dataset] 26 | 27 | index = Info['index'] 28 | index = index[:, args.exp_id % index.shape[1]] 29 | ref_ids = Info['ref_ids'][0, :] 30 | if status == 'train': 31 | index = index[0:int(args.train_proportion * args.train_ratio * len(index))] 32 | elif status == 'val': 33 | index = index[int(args.train_ratio * len(index)):int((0.5 + args.train_ratio / 2) * len(index))] 34 | elif status == 'test': 35 | index = index[int((0.5 + args.train_ratio / 2) * len(index)):len(index)] 36 | self.index[dataset] = [] 37 | for i in range(len(ref_ids)): 38 | if ref_ids[i] in index: 39 | self.index[dataset].append(i) 40 | print("# {} images from {}: {}".format(status, dataset, len(self.index[dataset]))) 41 | print("Ref Index: ") 42 | print(index.astype(int)) 43 | 44 | max_len_all = max(max_len.values()) 45 | self.features, self.length, self.label, self.KCL, self.N = dict(), dict(), dict(), dict(), dict() 46 | for dataset in datasets: 47 | N = len(self.index[dataset]) 48 | self.N[dataset] = N 49 | self.features[dataset] = np.zeros((N, max_len_all, args.feat_dim), dtype=np.float32) 50 | 
self.length[dataset] = np.zeros(N, dtype=np.int) 51 | self.label[dataset] = np.zeros((N, 1), dtype=np.float32) 52 | self.KCL[dataset] = [] 53 | for i in range(N): 54 | features = np.load(args.features_dir[dataset] + str(self.index[dataset][i]) + '_' + args.feature_extractor +'_last_conv.npy') 55 | self.length[dataset][i] = features.shape[0] 56 | self.features[dataset][i, :features.shape[0], :] = features 57 | mos = np.load(args.features_dir[dataset] + str(self.index[dataset][i]) + '_score.npy') # 58 | self.label[dataset][i] = mos 59 | self.KCL[dataset].append(dataset) 60 | 61 | def __len__(self): 62 | return max(self.N.values()) 63 | 64 | def __getitem__(self, idx): 65 | data = [(self.features[dataset][idx % self.N[dataset]], 66 | self.length[dataset][idx % self.N[dataset]], 67 | self.KCL[dataset][idx % self.N[dataset]]) for dataset in self.datasets] 68 | label = [self.label[dataset][idx % self.N[dataset]] for dataset in self.datasets] 69 | return data, label 70 | 71 | 72 | def get_data_loaders(args): 73 | """ Prepare the train-val-test data 74 | :param args: related arguments 75 | :return: train_loader, val_loader, test_loader 76 | """ 77 | train_dataset = VQADataset(args, args.datasets['train'], 'train') 78 | train_loader = torch.utils.data.DataLoader(train_dataset, 79 | batch_size=args.batch_size, 80 | shuffle=True, 81 | num_workers=2, 82 | drop_last=True) # 83 | 84 | scale = train_dataset.scale 85 | m = train_dataset.m 86 | 87 | val_loader, test_loader = dict(), dict() 88 | for dataset in args.datasets['val']: 89 | val_dataset = VQADataset(args, [dataset], 'val') 90 | val_loader[dataset] = torch.utils.data.DataLoader(val_dataset) 91 | 92 | for dataset in args.datasets['test']: 93 | test_dataset = VQADataset(args, [dataset], 'test') 94 | if dataset not in args.datasets['train']: 95 | scale[dataset] = test_dataset.scale[dataset] 96 | m[dataset] = test_dataset.m[dataset] 97 | test_loader[dataset] = torch.utils.data.DataLoader(test_dataset) 98 | 99 | return train_loader, val_loader, test_loader, scale, m 100 | -------------------------------------------------------------------------------- /VQAloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VQALoss(nn.Module): 7 | def __init__(self, scale, loss_type='mixed', m=None): 8 | super(VQALoss, self).__init__() 9 | self.loss_type = loss_type 10 | self.scale = scale 11 | self.m = m # 12 | 13 | def forward(self, y_pred, y): 14 | relative_score, mapped_score, aligned_score = y_pred 15 | if self.loss_type == 'mixed': 16 | loss = [loss_a(mapped_score[d], y[d]) + loss_m(relative_score[d], y[d]) + 17 | F.l1_loss(aligned_score[d], y[d]) / self.scale[d] for d in range(len(y))] 18 | elif self.loss_type == 'correlation' or self.loss_type == 'rank+plcc': 19 | loss = [loss_a(mapped_score[d], y[d]) + loss_m(relative_score[d], y[d]) for d in range(len(y))] 20 | elif self.loss_type == 'rank': 21 | loss = [loss_m(relative_score[d], y[d]) for d in range(len(y))] 22 | elif self.loss_type == 'plcc': 23 | loss = [loss_a(mapped_score[d], y[d]) for d in range(len(y))] 24 | elif self.loss_type == 'rank+l1': 25 | loss = [loss_m(relative_score[d], y[d]) + F.l1_loss(aligned_score[d], y[d]) / self.scale[d] for d in range(len(y)) for d in range(len(y))] 26 | elif self.loss_type == 'plcc+l1': 27 | loss = [loss_a(relative_score[d], y[d]) + F.l1_loss(aligned_score[d], y[d]) / self.scale[d] for d in range(len(y)) for d in range(len(y))] 28 | 
elif 'naive' in self.loss_type: 29 | aligned_scores = torch.cat([(aligned_score[d]-self.m[d])/self.scale[d] for d in range(len(y))]) 30 | ys = torch.cat([(y[d]-self.m[d])/self.scale[d] for d in range(len(y))]) 31 | if self.loss_type == 'naive0': 32 | return F.l1_loss(aligned_scores, ys) # 33 | return loss_a(aligned_scores, ys) + loss_m(aligned_scores, ys) + F.l1_loss(aligned_scores, ys) 34 | else: # default l1 35 | loss = [F.l1_loss(aligned_score[d], y[d]) / self.scale[d] for d in range(len(y))] 36 | # print(loss) 37 | # sum_loss = sum([lossi for lossi in loss]) / len(loss) 38 | # sum_loss = len(loss) / sum([1 / lossi for lossi in loss]) 39 | sum_loss = sum([torch.exp(lossi) * lossi for lossi in loss]) / sum([torch.exp(lossi) for lossi in loss]) 40 | return sum_loss 41 | 42 | 43 | def loss_m(y_pred, y): 44 | """prediction monotonicity related loss""" 45 | assert y_pred.size(0) > 1 # 46 | return torch.sum(F.relu((y_pred-y_pred.t()) * torch.sign((y.t()-y)))) / y_pred.size(0) / (y_pred.size(0)-1) 47 | 48 | 49 | def loss_a(y_pred, y): 50 | """prediction accuracy related loss""" 51 | assert y_pred.size(0) > 1 # 52 | return (1 - torch.cosine_similarity(y_pred.t() - torch.mean(y_pred), y.t() - torch.mean(y))[0]) / 2 53 | 54 | -------------------------------------------------------------------------------- /VQAmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | 8 | class VQAModel(nn.Module): 9 | def __init__(self, scale={'K': 1, 'C': 1, 'L': 1, 'N': 1}, m={'K': 0, 'C': 0, 'L': 0, 'N': 0}, 10 | simple_linear_scale=False, input_size=4096, reduced_size=128, hidden_size=32): 11 | super(VQAModel, self).__init__() 12 | self.hidden_size = hidden_size 13 | mapping_datasets = scale.keys() 14 | 15 | self.dimemsion_reduction = nn.Linear(input_size, reduced_size) 16 | self.feature_aggregation = nn.GRU(reduced_size, hidden_size, batch_first=True) 17 | self.regression = nn.Linear(hidden_size, 1) 18 | self.bound = nn.Sigmoid() 19 | self.nlm = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid(), nn.Linear(1, 1)) # 4 parameters 20 | # self.nlm = nn.Sequential(nn.Sequential(nn.Linear(1, 1), nn.Sigmoid(), nn.Linear(1, 1, bias=False)), 21 | # nn.Linear(1, 1)) # 5 parameters 22 | self.lm = nn.Sequential(OrderedDict([(dataset, nn.Linear(1, 1)) for dataset in mapping_datasets])) 23 | 24 | torch.nn.init.constant_(self.nlm[0].weight, 2*np.sqrt(3)) 25 | torch.nn.init.constant_(self.nlm[0].bias, -np.sqrt(3)) 26 | torch.nn.init.constant_(self.nlm[2].weight, 1) 27 | torch.nn.init.constant_(self.nlm[2].bias, 0) 28 | for p in self.nlm[2].parameters(): 29 | p.requires_grad = False 30 | for d, dataset in enumerate(mapping_datasets): 31 | torch.nn.init.constant_(self.lm._modules[dataset].weight, scale[dataset]) 32 | torch.nn.init.constant_(self.lm._modules[dataset].bias, m[dataset]) 33 | 34 | 35 | # torch.nn.init.constant_(self.nlm[0][0].weight, 2*np.sqrt(3)) 36 | # torch.nn.init.constant_(self.nlm[0][0].bias, -np.sqrt(3)) 37 | # torch.nn.init.constant_(self.nlm[0][2].weight, 0) 38 | 39 | # torch.nn.init.constant_(self.nlm[1].weight, 1) 40 | # torch.nn.init.constant_(self.nlm[1].bias, 0) 41 | # for d, dataset in enumerate(mapping_datasets): 42 | # torch.nn.init.constant_(self.lm._modules[dataset].weight, scale[dataset]) 43 | # torch.nn.init.constant_(self.lm._modules[dataset].bias, m[dataset]) 44 | 45 | # for d, dataset in enumerate(mapping_datasets): 
46 | # if d == 0: 47 | # dataset0 = dataset 48 | # torch.nn.init.constant_(self.nlm[1].weight, scale[dataset0]) 49 | # torch.nn.init.constant_(self.nlm[1].bias, m[dataset0]) 50 | # torch.nn.init.constant_(self.lm._modules[dataset0].weight, 1) 51 | # torch.nn.init.constant_(self.lm._modules[dataset0].bias, 0) 52 | # for p in self.lm._modules[dataset0].parameters(): 53 | # p.requires_grad = False 54 | # else: 55 | # torch.nn.init.constant_(self.lm._modules[dataset].weight, scale[dataset] / scale[dataset0]) 56 | # torch.nn.init.constant_(self.lm._modules[dataset].bias, 57 | # m[dataset] - m[dataset0] * scale[dataset] / scale[dataset0]) 58 | 59 | if simple_linear_scale: 60 | for p in self.lm.parameters(): 61 | p.requires_grad = False 62 | 63 | def forward(self, input): 64 | relative_score, mapped_score, aligned_score = [], [], [] 65 | for d, (x, x_len, KCL) in enumerate(input): 66 | x = self.dimemsion_reduction(x) # dimension reduction 67 | x, _ = self.feature_aggregation(x, self._get_initial_state(x.size(0), x.device)) 68 | q = self.regression(x) # frame quality 69 | relative_score.append(torch.zeros_like(q[:, 0])) # 70 | mapped_score.append(torch.zeros_like(q[:, 0])) # 71 | aligned_score.append(torch.zeros_like(q[:, 0])) # 72 | for i in range(q.shape[0]): # 73 | relative_score[d][i] = self._sitp(q[i, :x_len[i].item()]) # video overall quality 74 | relative_score[d] = self.bound(relative_score[d]) 75 | # mapped_score[d] = relative_score[d] # The nonlinear mapping module is embedded into the RQA. 76 | mapped_score[d] = self.nlm(relative_score[d]) # 4 parameters 77 | # mapped_score[d] = self.nlm[0](relative_score[d]) + self.nlm[1](relative_score[d]) # 5 parameters 78 | for i in range(q.shape[0]): 79 | aligned_score[d][i] = self.lm._modules[KCL[i]](mapped_score[d][i]) 80 | 81 | return relative_score, mapped_score, aligned_score 82 | 83 | def _sitp(self, q, tau=12, beta=0.5): 84 | """subjectively-inspired temporal pooling""" 85 | q = torch.unsqueeze(torch.t(q), 0) 86 | qm = -float('inf')*torch.ones((1, 1, tau-1)).to(q.device) 87 | qp = 10000.0 * torch.ones((1, 1, tau - 1)).to(q.device) # 88 | l = -F.max_pool1d(torch.cat((qm, -q), 2), tau, stride=1) 89 | m = F.avg_pool1d(torch.cat((q * torch.exp(-q), qp * torch.exp(-qp)), 2), tau, stride=1) 90 | n = F.avg_pool1d(torch.cat((torch.exp(-q), torch.exp(-qp)), 2), tau, stride=1) 91 | m = m / n 92 | q_hat = beta * m + (1 - beta) * l 93 | return torch.mean(q_hat) 94 | 95 | def _get_initial_state(self, batch_size, device): 96 | h0 = torch.zeros(1, batch_size, self.hidden_size, device=device) 97 | return h0 98 | -------------------------------------------------------------------------------- /VQAperformance.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.metric import Metric 2 | import numpy as np 3 | from scipy import stats 4 | 5 | 6 | class VQAPerformance(Metric): 7 | """ 8 | Evaluation of VQA methods using SROCC, KROCC, PLCC, RMSE. 9 | 10 | `update` must receive output of the form (y_pred, y). 
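    Here `y_pred` is the tuple (relative_score, mapped_score, aligned_score) returned by
    VQAModel: SROCC/KROCC are computed on the relative score, PLCC on the mapped score,
    and RMSE on the aligned score.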
11 | """ 12 | def reset(self): 13 | self._rq = [] 14 | self._mq = [] 15 | self._aq = [] 16 | self._y = [] 17 | 18 | def update(self, output): 19 | y_pred, y = output 20 | self._y.append(y[0].item()) 21 | self._rq.append(y_pred[0][0].item()) 22 | self._mq.append(y_pred[1][0].item()) 23 | self._aq.append(y_pred[2][0].item()) 24 | 25 | def compute(self): 26 | sq = np.reshape(np.asarray(self._y), (-1,)) 27 | rq = np.reshape(np.asarray(self._rq), (-1,)) 28 | mq = np.reshape(np.asarray(self._mq), (-1,)) 29 | aq = np.reshape(np.asarray(self._aq), (-1,)) 30 | 31 | SROCC = stats.spearmanr(sq, rq)[0] 32 | KROCC = stats.stats.kendalltau(sq, rq)[0] 33 | PLCC = stats.pearsonr(sq, mq)[0] 34 | RMSE = np.sqrt(np.power(sq-aq, 2).mean()) 35 | return {'SROCC': SROCC, 36 | 'KROCC': KROCC, 37 | 'PLCC': PLCC, 38 | 'RMSE': RMSE, 39 | 'sq': sq, 40 | 'rq': rq, 41 | 'mq': mq, 42 | 'aq': aq} 43 | -------------------------------------------------------------------------------- /cross_dataset_evaluation.py: -------------------------------------------------------------------------------- 1 | # Author: Dingquan Li 2 | # Email: dingquanli AT pku DOT edu DOT cn 3 | # Date: 2019/11/8 4 | # 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from ignite.engine import create_supervised_evaluator 9 | from VQAmodel import VQAModel 10 | from VQAloss import VQALoss 11 | from VQAperformance import VQAPerformance 12 | import datetime 13 | import os 14 | import numpy as np 15 | import random 16 | from argparse import ArgumentParser 17 | import h5py 18 | 19 | 20 | class VQADataset(Dataset): 21 | def __init__(self, args, datasets): 22 | self.datasets = datasets 23 | 24 | self.index = dict() 25 | max_len = dict() 26 | 27 | for dataset in datasets: 28 | Info = h5py.File(args.data_info[dataset], 'r') 29 | max_len[dataset] = int(Info['max_len'][0]) 30 | index = Info['index'] 31 | index = index[:, args.exp_id % index.shape[1]] 32 | ref_ids = Info['ref_ids'][0, :] 33 | self.index[dataset] = [] 34 | for i in range(len(ref_ids)): 35 | if ref_ids[i] in index: 36 | self.index[dataset].append(i) 37 | 38 | max_len_all = max(max_len.values()) 39 | self.features, self.length, self.label, self.KCL, self.N = dict(), dict(), dict(), dict(), dict() 40 | for dataset in datasets: 41 | N = len(self.index[dataset]) 42 | self.N[dataset] = N 43 | self.features[dataset] = np.zeros((N, max_len_all, args.feat_dim), dtype=np.float32) 44 | self.length[dataset] = np.zeros(N, dtype=np.int) 45 | self.label[dataset] = np.zeros((N, 1), dtype=np.float32) 46 | self.KCL[dataset] = [] 47 | for i in range(N): 48 | features = np.load(args.features_dir[dataset] + str(self.index[dataset][i]) + '_' + args.feature_extractor +'_last_conv.npy') 49 | self.length[dataset][i] = features.shape[0] 50 | self.features[dataset][i, :features.shape[0], :] = features 51 | mos = np.load(args.features_dir[dataset] + str(self.index[dataset][i]) + '_score.npy') # 52 | self.label[dataset][i] = mos 53 | self.KCL[dataset].append(dataset) 54 | 55 | def __len__(self): 56 | return max(self.N.values()) 57 | 58 | def __getitem__(self, idx): 59 | data = [(self.features[dataset][idx % self.N[dataset]], 60 | self.length[dataset][idx % self.N[dataset]], 61 | self.KCL[dataset][idx % self.N[dataset]]) for dataset in self.datasets] 62 | label = [self.label[dataset][idx % self.N[dataset]] for dataset in self.datasets] 63 | return data, label 64 | 65 | 66 | def run(args): 67 | device = torch.device("cuda" if not args.disable_gpu and torch.cuda.is_available() else "cpu") 68 | test_loader = dict() 
69 | for dataset in args.cross_datasets: 70 | test_dataset = VQADataset(args, [dataset]) 71 | test_loader[dataset] = torch.utils.data.DataLoader(test_dataset) 72 | 73 | model = VQAModel(simple_linear_scale=args.simple_linear_scale).to(device) # 74 | model.load_state_dict(torch.load(args.trained_model_file)) 75 | 76 | evaluator = create_supervised_evaluator(model, metrics={'VQA_performance': VQAPerformance()}, device=device) 77 | 78 | performance = dict() 79 | for dataset in args.cross_datasets: 80 | evaluator.run(test_loader[dataset]) 81 | performance[dataset] = evaluator.state.metrics['VQA_performance'] 82 | print('{}, SROCC: {}'.format(dataset, performance[dataset]['SROCC'])) 83 | np.save(args.save_result_file, performance) 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = ArgumentParser(description='MDTVSFA Cross-dataset evaluation') 88 | parser.add_argument("--seed", type=int, default=19920517) 89 | parser.add_argument('--lr', type=float, default=1e-4, 90 | help='learning rate (default: 1e-4)') 91 | parser.add_argument('--batch_size', type=int, default=32, 92 | help='input batch size for training (default: 32)') 93 | parser.add_argument('--epochs', type=int, default=40, 94 | help='number of epochs to train (default: 40)') 95 | parser.add_argument('--weight_decay', type=float, default=0.0, 96 | help='weight decay (default: 0.0)') 97 | 98 | parser.add_argument('--model', default='MDTVSFA', type=str, 99 | help='model name (default: MDTVSFA)') 100 | parser.add_argument('--loss', default='mixed', type=str, 101 | help='loss type (default: mixed)') 102 | parser.add_argument('--feature_extractor', default='ResNet-50', type=str, 103 | help='feature_extractor backbone (default: ResNet-50)') 104 | # parser.add_argument('--feat_dim', type=int, default=4096, 105 | # help='feature dimension (default: 4096)') 106 | 107 | parser.add_argument('--trained_datasets', nargs='+', type=str, default=['K'], 108 | help="trained datasets (default: ['K'])") 109 | 110 | parser.add_argument('--cross_datasets', nargs='+', type=str, default=['C', 'L', 'N'], 111 | help="cross datasets (default: ['C', 'L', 'N'])") 112 | 113 | parser.add_argument('--exp_id', default=0, type=int, 114 | help='exp id for train-val-test splits (default: 0)') 115 | parser.add_argument('--train_proportion', type=float, default=6, 116 | help='the number of proportions (#total 6) used in the training set (default: 6)') 117 | 118 | parser.add_argument('--disable_gpu', action='store_true', 119 | help='flag whether to disable GPU') 120 | args = parser.parse_args() 121 | args.train_proportion /= 6 122 | if args.feature_extractor == 'AlexNet': 123 | args.feat_dim = 256 * 2 124 | else: 125 | args.feat_dim = 2048 * 2 126 | 127 | 128 | args.simple_linear_scale = False # 129 | if 'naive' in args.loss: 130 | args.simple_linear_scale = True # 131 | 132 | args.decay_interval = int(args.epochs / 20) 133 | args.decay_ratio = 0.8 134 | 135 | args.datasets = {'train': args.trained_datasets, 136 | 'val': args.trained_datasets, 137 | 'test': ['K', 'C', 'L', 'N']} 138 | args.features_dir = {'K': 'CNN_features_KoNViD-1k/', 139 | 'C': 'CNN_features_CVD2014/', 140 | 'L': 'CNN_features_LIVE-Qualcomm/', 141 | 'N': 'CNN_features_LIVE-VQC/'} 142 | args.data_info = {'K': 'data/KoNViD-1kinfo.mat', 143 | 'C': 'data/CVD2014info.mat', 144 | 'L': 'data/LIVE-Qualcomminfo.mat', 145 | 'N': 'data/LIVE-VQCinfo.mat'} 146 | 147 | torch.manual_seed(args.seed) # 148 | torch.backends.cudnn.deterministic = True 149 | torch.backends.cudnn.benchmark = False 150 | 
np.random.seed(args.seed) 151 | random.seed(args.seed) 152 | 153 | torch.utils.backcompat.broadcast_warning.enabled = True 154 | 155 | args.trained_model_file = 'checkpoints/{}-{}-{}-{}-{}-{}-{}-{}-EXP{}'.format(args.model, args.feature_extractor, args.loss, args.train_proportion, args.trained_datasets, args.lr, args.batch_size, args.epochs, args.exp_id) 156 | if not os.path.exists('results'): 157 | os.makedirs('results') 158 | args.save_result_file = 'results/cross-dataset-{}-{}-{}-{}-{}-{}-{}-{}-EXP{}'.format(args.model, args.feature_extractor, args.loss, args.train_proportion, args.trained_datasets, args.lr, args.batch_size, args.epochs, args.exp_id) 159 | print(args) 160 | run(args) 161 | -------------------------------------------------------------------------------- /cross_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Using K80 GPU 4 | 5 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l mixed > Ktrained-crossCLN-mixed-exp-0-10.log 2>&1 & 6 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l mixed > Ctrained-crossKLN-mixed-exp-0-10.log 2>&1 & 7 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l mixed > Ltrained-crossKCN-mixed-exp-0-10.log 2>&1 & 8 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l mixed > Ntrained-crossKCL-mixed-exp-0-10.log 2>&1 & 9 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l mixed > KCtrained-crossLN-mixed-exp-0-10.log 2>&1 & 10 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l mixed > KLtrained-crossCN-mixed-exp-0-10.log 2>&1 & 11 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l mixed > KNtrained-crossCL-mixed-exp-0-10.log 2>&1 & 12 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l mixed > CLtrained-crossKN-mixed-exp-0-10.log 2>&1 & 13 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l mixed > Ntrained-crossKL-mixed-exp-0-10.log 2>&1 & 14 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l mixed > Ltrained-crossKL-mixed-exp-0-10.log 2>&1 & 15 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l mixed > KCLtrained-crossN-mixed-exp-0-10.log 2>&1 & 16 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l mixed > KCNtrained-crossL-mixed-exp-0-10.log 2>&1 & 17 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l mixed > CLNtrained-crossK-mixed-exp-0-10.log 2>&1 & 18 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l mixed > KLNtrained-crossC-mixed-exp-0-10.log 2>&1 & 19 | 20 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l naive > Ktrained-crossCLN-naive-exp-0-10.log 2>&1 & 21 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l naive > Ctrained-crossKLN-naive-exp-0-10.log 2>&1 & 22 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l naive > Ltrained-crossKCN-naive-exp-0-10.log 2>&1 & 23 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l naive > Ntrained-crossKCL-naive-exp-0-10.log 2>&1 & 24 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l naive > KCtrained-crossLN-naive-exp-0-10.log 2>&1 & 25 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l naive > KLtrained-crossCN-naive-exp-0-10.log 2>&1 & 26 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l naive > KNtrained-crossCL-naive-exp-0-10.log 2>&1 & 27 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l naive > CLtrained-crossKN-naive-exp-0-10.log 2>&1 & 28 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l naive > Ntrained-crossKL-naive-exp-0-10.log 2>&1 & 29 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l naive > Ltrained-crossKL-naive-exp-0-10.log 2>&1 & 30 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l naive > KCLtrained-crossN-naive-exp-0-10.log 2>&1 & 31 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l naive > 
KCNtrained-crossL-naive-exp-0-10.log 2>&1 & 32 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l naive > CLNtrained-crossK-naive-exp-0-10.log 2>&1 & 33 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l naive > KLNtrained-crossC-naive-exp-0-10.log 2>&1 & 34 | 35 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l naive0 > Ktrained-crossCLN-naive0-exp-0-10.log 2>&1 & 36 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l naive0 > Ctrained-crossKLN-naive0-exp-0-10.log 2>&1 & 37 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l naive0 > Ltrained-crossKCN-naive0-exp-0-10.log 2>&1 & 38 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l naive0 > Ntrained-crossKCL-naive0-exp-0-10.log 2>&1 & 39 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l naive0 > KCtrained-crossLN-naive0-exp-0-10.log 2>&1 & 40 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l naive0 > KLtrained-crossCN-naive0-exp-0-10.log 2>&1 & 41 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l naive0 > KNtrained-crossCL-naive0-exp-0-10.log 2>&1 & 42 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l naive0 > CLtrained-crossKN-naive0-exp-0-10.log 2>&1 & 43 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l naive0 > Ntrained-crossKL-naive0-exp-0-10.log 2>&1 & 44 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l naive0 > Ltrained-crossKL-naive0-exp-0-10.log 2>&1 & 45 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l naive0 > KCLtrained-crossN-naive0-exp-0-10.log 2>&1 & 46 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l naive0 > KCNtrained-crossL-naive0-exp-0-10.log 2>&1 & 47 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l naive0 > CLNtrained-crossK-naive0-exp-0-10.log 2>&1 & 48 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l naive0 > KLNtrained-crossC-naive0-exp-0-10.log 2>&1 & 49 | 50 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l l1 > Ktrained-crossCLN-l1-exp-0-10.log 2>&1 & 51 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l l1 > Ctrained-crossKLN-l1-exp-0-10.log 2>&1 & 52 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l l1 > Ltrained-crossKCN-l1-exp-0-10.log 2>&1 & 53 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l l1 > Ntrained-crossKCL-l1-exp-0-10.log 2>&1 & 54 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l l1 > KCtrained-crossLN-l1-exp-0-10.log 2>&1 & 55 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l l1 > KLtrained-crossCN-l1-exp-0-10.log 2>&1 & 56 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l l1 > KNtrained-crossCL-l1-exp-0-10.log 2>&1 & 57 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l l1 > CLtrained-crossKN-l1-exp-0-10.log 2>&1 & 58 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l l1 > Ntrained-crossKL-l1-exp-0-10.log 2>&1 & 59 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l l1 > Ltrained-crossKL-l1-exp-0-10.log 2>&1 & 60 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l l1 > KCLtrained-crossN-l1-exp-0-10.log 2>&1 & 61 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l l1 > KCNtrained-crossL-l1-exp-0-10.log 2>&1 & 62 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l l1 > CLNtrained-crossK-l1-exp-0-10.log 2>&1 & 63 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l l1 > KLNtrained-crossC-l1-exp-0-10.log 2>&1 & 64 | 65 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l rank > Ktrained-crossCLN-rank-exp-0-10.log 2>&1 & 66 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l rank > Ctrained-crossKLN-rank-exp-0-10.log 2>&1 & 67 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l rank > Ltrained-crossKCN-rank-exp-0-10.log 2>&1 & 68 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l rank > Ntrained-crossKCL-rank-exp-0-10.log 2>&1 & 69 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l rank > 
KCtrained-crossLN-rank-exp-0-10.log 2>&1 & 70 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l rank > KLtrained-crossCN-rank-exp-0-10.log 2>&1 & 71 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l rank > KNtrained-crossCL-rank-exp-0-10.log 2>&1 & 72 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l rank > CLtrained-crossKN-rank-exp-0-10.log 2>&1 & 73 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l rank > Ntrained-crossKL-rank-exp-0-10.log 2>&1 & 74 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l rank > Ltrained-crossKL-rank-exp-0-10.log 2>&1 & 75 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l rank > KCLtrained-crossN-rank-exp-0-10.log 2>&1 & 76 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l rank > KCNtrained-crossL-rank-exp-0-10.log 2>&1 & 77 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l rank > CLNtrained-crossK-rank-exp-0-10.log 2>&1 & 78 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l rank > KLNtrained-crossC-rank-exp-0-10.log 2>&1 & 79 | 80 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l plcc > Ktrained-crossCLN-plcc-exp-0-10.log 2>&1 & 81 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l plcc > Ctrained-crossKLN-plcc-exp-0-10.log 2>&1 & 82 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l plcc > Ltrained-crossKCN-plcc-exp-0-10.log 2>&1 & 83 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l plcc > Ntrained-crossKCL-plcc-exp-0-10.log 2>&1 & 84 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l plcc > KCtrained-crossLN-plcc-exp-0-10.log 2>&1 & 85 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l plcc > KLtrained-crossCN-plcc-exp-0-10.log 2>&1 & 86 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l plcc > KNtrained-crossCL-plcc-exp-0-10.log 2>&1 & 87 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l plcc > CLtrained-crossKN-plcc-exp-0-10.log 2>&1 & 88 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l plcc > Ntrained-crossKL-plcc-exp-0-10.log 2>&1 & 89 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l plcc > Ltrained-crossKL-plcc-exp-0-10.log 2>&1 & 90 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l plcc > KCLtrained-crossN-plcc-exp-0-10.log 2>&1 & 91 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l plcc > KCNtrained-crossL-plcc-exp-0-10.log 2>&1 & 92 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l plcc > CLNtrained-crossK-plcc-exp-0-10.log 2>&1 & 93 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l plcc > KLNtrained-crossC-plcc-exp-0-10.log 2>&1 & 94 | 95 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l rank+l1 > Ktrained-crossCLN-rank+l1-exp-0-10.log 2>&1 & 96 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l rank+l1 > Ctrained-crossKLN-rank+l1-exp-0-10.log 2>&1 & 97 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l rank+l1 > Ltrained-crossKCN-rank+l1-exp-0-10.log 2>&1 & 98 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l rank+l1 > Ntrained-crossKCL-rank+l1-exp-0-10.log 2>&1 & 99 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l rank+l1 > KCtrained-crossLN-rank+l1-exp-0-10.log 2>&1 & 100 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l rank+l1 > KLtrained-crossCN-rank+l1-exp-0-10.log 2>&1 & 101 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l rank+l1 > KNtrained-crossCL-rank+l1-exp-0-10.log 2>&1 & 102 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l rank+l1 > CLtrained-crossKN-rank+l1-exp-0-10.log 2>&1 & 103 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l rank+l1 > Ntrained-crossKL-rank+l1-exp-0-10.log 2>&1 & 104 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l rank+l1 > Ltrained-crossKL-rank+l1-exp-0-10.log 2>&1 & 105 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l rank+l1 > KCLtrained-crossN-rank+l1-exp-0-10.log 2>&1 & 106 | # ./cross_job.sh -g 
11 -d K -d C -c L -d N -l rank+l1 > KCNtrained-crossL-rank+l1-exp-0-10.log 2>&1 & 107 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l rank+l1 > CLNtrained-crossK-rank+l1-exp-0-10.log 2>&1 & 108 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l rank+l1 > KLNtrained-crossC-rank+l1-exp-0-10.log 2>&1 & 109 | 110 | # ./cross_job.sh -g 15 -d K -c C -c L -c N -l plcc+l1 > Ktrained-crossCLN-plcc+l1-exp-0-10.log 2>&1 & 111 | # ./cross_job.sh -g 14 -c K -d C -c L -c N -l plcc+l1 > Ctrained-crossKLN-plcc+l1-exp-0-10.log 2>&1 & 112 | # ./cross_job.sh -g 13 -c K -c C -d L -c N -l plcc+l1 > Ltrained-crossKCN-plcc+l1-exp-0-10.log 2>&1 & 113 | # ./cross_job.sh -g 12 -c K -c C -c L -d N -l plcc+l1 > Ntrained-crossKCL-plcc+l1-exp-0-10.log 2>&1 & 114 | # ./cross_job.sh -g 11 -d K -d C -c L -c N -l plcc+l1 > KCtrained-crossLN-plcc+l1-exp-0-10.log 2>&1 & 115 | # ./cross_job.sh -g 10 -d K -c C -d L -c N -l plcc+l1 > KLtrained-crossCN-plcc+l1-exp-0-10.log 2>&1 & 116 | # ./cross_job.sh -g 9 -d K -c C -c L -d N -l plcc+l1 > KNtrained-crossCL-plcc+l1-exp-0-10.log 2>&1 & 117 | # ./cross_job.sh -g 8 -c K -d C -d L -c N -l plcc+l1 > CLtrained-crossKN-plcc+l1-exp-0-10.log 2>&1 & 118 | # ./cross_job.sh -g 7 -c K -d C -c L -d N -l plcc+l1 > Ntrained-crossKL-plcc+l1-exp-0-10.log 2>&1 & 119 | # ./cross_job.sh -g 6 -c K -c C -d L -d N -l plcc+l1 > Ltrained-crossKL-plcc+l1-exp-0-10.log 2>&1 & 120 | # ./cross_job.sh -g 5 -d K -d C -d L -c N -l plcc+l1 > KCLtrained-crossN-plcc+l1-exp-0-10.log 2>&1 & 121 | # ./cross_job.sh -g 4 -d K -d C -c L -d N -l plcc+l1 > KCNtrained-crossL-plcc+l1-exp-0-10.log 2>&1 & 122 | # ./cross_job.sh -g 3 -c K -d C -d L -d N -l plcc+l1 > CLNtrained-crossK-plcc+l1-exp-0-10.log 2>&1 & 123 | # ./cross_job.sh -g 2 -d K -c C -d L -d N -l plcc+l1 > KLNtrained-crossC-plcc+l1-exp-0-10.log 2>&1 & 124 | 125 | # ./cross_job.sh -g 0 -d K -c C -c L -c N -l correlation > Ktrained-crossCLN-correlation-exp-0-10.log 2>&1 & 126 | # ./cross_job.sh -g 1 -c K -d C -c L -c N -l correlation > Ctrained-crossKLN-correlation-exp-0-10.log 2>&1 & 127 | # ./cross_job.sh -g 2 -c K -c C -d L -c N -l correlation > Ltrained-crossKCN-correlation-exp-0-10.log 2>&1 & 128 | # ./cross_job.sh -g 3 -c K -c C -c L -d N -l correlation > Ntrained-crossKCL-correlation-exp-0-10.log 2>&1 & 129 | # ./cross_job.sh -g 4 -d K -d C -c L -c N -l correlation > KCtrained-crossLN-correlation-exp-0-10.log 2>&1 & 130 | # ./cross_job.sh -g 5 -d K -c C -d L -c N -l correlation > KLtrained-crossCN-correlation-exp-0-10.log 2>&1 & 131 | # ./cross_job.sh -g 6 -d K -c C -c L -d N -l correlation > KNtrained-crossCL-correlation-exp-0-10.log 2>&1 & 132 | # ./cross_job.sh -g 7 -c K -d C -d L -c N -l correlation > CLtrained-crossKN-correlation-exp-0-10.log 2>&1 & 133 | # ./cross_job.sh -g 8 -c K -d C -c L -d N -l correlation > Ntrained-crossKL-correlation-exp-0-10.log 2>&1 & 134 | # ./cross_job.sh -g 9 -c K -c C -d L -d N -l correlation > Ltrained-crossKL-correlation-exp-0-10.log 2>&1 & 135 | # ./cross_job.sh -g 10 -d K -d C -d L -c N -l correlation > KCLtrained-crossN-correlation-exp-0-10.log 2>&1 & 136 | # ./cross_job.sh -g 11 -d K -d C -c L -d N -l correlation > KCNtrained-crossL-correlation-exp-0-10.log 2>&1 & 137 | # ./cross_job.sh -g 12 -c K -d C -d L -d N -l correlation > CLNtrained-crossK-correlation-exp-0-10.log 2>&1 & 138 | # ./cross_job.sh -g 13 -d K -c C -d L -d N -l correlation > KLNtrained-crossC-correlation-exp-0-10.log 2>&1 & 139 | 140 | loss=mixed 141 | start_id=0 142 | end_id=10 143 | while getopts "g:d:l:s:e:c:" opt; do 144 | case $opt in 145 | g) 
gpu_id=("$OPTARG");; # gpu_id 146 | d) datasets+=("$OPTARG");; # trained datasets 147 | l) loss=("$OPTARG");; # loss 148 | s) start_id=("$OPTARG");; 149 | e) end_id=("$OPTARG");; 150 | c) cross_datasets+=("$OPTARG");; # trained datasets 151 | esac 152 | done 153 | shift $((OPTIND -1)) 154 | 155 | source activate reproducibleresearch 156 | for ((i=$start_id; i<$end_id; i++)); do 157 | CUDA_VISIBLE_DEVICES=$gpu_id python cross_dataset_evaluation.py --exp_id=$i --loss=$loss --trained_datasets ${datasets[@]} --cross_datasets ${cross_datasets[@]} 158 | done 159 | source deactivate -------------------------------------------------------------------------------- /data/CVD2014info.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/data/CVD2014info.mat -------------------------------------------------------------------------------- /data/KoNViD-1kinfo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/data/KoNViD-1kinfo.mat -------------------------------------------------------------------------------- /data/LIVE-Qualcomminfo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/data/LIVE-Qualcomminfo.mat -------------------------------------------------------------------------------- /data/LIVE-VQCinfo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/data/LIVE-VQCinfo.mat -------------------------------------------------------------------------------- /data/data_info_maker.m: -------------------------------------------------------------------------------- 1 | clear,clc; 2 | 3 | %% KoNViD-1k 4 | data_path = '/media/ldq/Research/Data/KoNViD-1k/KoNViD_1k_attributes.csv'; 5 | data = readtable(data_path); 6 | video_names = data.file_name; % video names 7 | scores = data.MOS; % subjective scores 8 | clear data_path data 9 | 10 | height = 540; % video height 11 | width = 960; % video width 12 | max_len = 240; % maximum video length in the dataset 13 | video_format = 'RGB'; % video format 14 | ref_ids = [1:length(scores)]'; % video content ids 15 | % `random` train-val-test split index, 1000 runs 16 | index = cell2mat(arrayfun(@(i)randperm(length(scores)), ... 17 | 1:1000,'UniformOutput', false)'); 18 | save('KoNViD-1kinfo') 19 | 20 | %% CVD2014 21 | data_path = '/media/ldq/Research/Data/CVD2014/CVD2014_ratings/Realignment_MOS.csv'; 22 | data = readtable(data_path); 23 | video_names = arrayfun(@(i) ['Test' data.File_name{i}(6) '/' ... 24 | data.Content{i} '/' data.File_name{i} '.avi'], 1:234, ... 25 | 'UniformOutput', false)'; % video names, remove '', add dir 26 | scores = arrayfun(@(i) str2double(data.RealignmentMOS{i})/100, 1:234)'; % subjective scores 27 | clear data_path data 28 | 29 | height = [720 480]; 30 | width = [1280 640]; 31 | max_len = 830; 32 | video_format = 'RGB'; 33 | ref_ids = [1:length(scores)]'; 34 | % `random` train-val-test split index, 1000 runs 35 | index = cell2mat(arrayfun(@(i)randperm(length(scores)), ... 
36 | 1:1000,'UniformOutput', false)'); 37 | save('CVD2014info') 38 | % LIVE-Qualcomm 39 | data_path = '/media/ldq/Others/Data/12.LIVE-Qualcomm Mobile In-Capture Video Quality Database/qualcommSubjectiveData.mat'; 40 | data = load(data_path); 41 | scores = data.qualcommSubjectiveData.unBiasedMOS; % subjective scores 42 | video_names = data.qualcommVideoData; 43 | video_names = arrayfun(@(i) [video_names.distortionNames{video_names.distortionType(i)} ... 44 | '/' video_names.vidNames{i}], 1:length(scores), ... 45 | 'UniformOutput', false)'; % video names 46 | clear data_path data 47 | 48 | height = 1080; 49 | width = 1920; 50 | max_len = 526; 51 | video_format = 'YUV420'; 52 | ref_ids = [1:length(scores)]'; 53 | % `random` train-val-test split index, 1000 runs 54 | index = cell2mat(arrayfun(@(i)randperm(length(scores)), ... 55 | 1:1000,'UniformOutput', false)'); 56 | save('LIVE-Qualcomminfo') -------------------------------------------------------------------------------- /data/test.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/data/test.mp4 -------------------------------------------------------------------------------- /job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Using K80 GPU 4 | 5 | # ./job.sh -g 0 -d K > K-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 6 | # ./job.sh -g 1 -d K -d C > KC-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 7 | # ./job.sh -g 2 -d K -d L > KL-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 8 | # ./job.sh -g 3 -d K -d N > KN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 9 | # ./job.sh -g 4 -d K -d C -d L > KCL-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 10 | # ./job.sh -g 5 -d K -d C -d N > KCN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 11 | # ./job.sh -g 6 -d K -d L -d N > KLN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 12 | # ./job.sh -g 7 -d K -d C -d L -d N > KCLN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 13 | # ./job.sh -g 8 -d C > C-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 14 | # ./job.sh -g 9 -d L > L-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 15 | # ./job.sh -g 10 -d N > N-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 16 | # ./job.sh -g 11 -d C -d L > CL-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 17 | # ./job.sh -g 12 -d C -d N > CN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 18 | # ./job.sh -g 13 -d L -d N > LN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 19 | # ./job.sh -g 14 -d C -d L -d N > CLN-mixed-exp-0-10-1e-4-32-40.log 2>&1 & 20 | 21 | # ./job.sh -g 15 -d K -d C -d L -d N -l naive > KCLN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 22 | # ./job.sh -g 14 -d C -l naive > C-naive-exp-0-10-1e-4-32-40.log 2>&1 & 23 | # ./job.sh -g 13 -d L -l naive > L-naive-exp-0-10-1e-4-32-40.log 2>&1 & 24 | # ./job.sh -g 12 -d N -l naive > N-naive-exp-0-10-1e-4-32-40.log 2>&1 & 25 | # ./job.sh -g 11 -d C -d L -l naive > CL-naive-exp-0-10-1e-4-32-40.log 2>&1 & 26 | # ./job.sh -g 10 -d C -d N -l naive > CN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 27 | # ./job.sh -g 9 -d L -d N -l naive > LN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 28 | # ./job.sh -g 8 -d C -d L -d N -l naive > CLN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 29 | # ./job.sh -g 7 -d K -l naive > K-naive-exp-0-10-1e-4-32-40.log 2>&1 & 30 | # ./job.sh -g 6 -d K -d C -l naive > KC-naive-exp-0-10-1e-4-32-40.log 2>&1 & 31 | # ./job.sh -g 5 -d K -d L -l naive > KL-naive-exp-0-10-1e-4-32-40.log 2>&1 & 32 | # ./job.sh -g 4 -d K -d N -l naive > KN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 33 | # ./job.sh -g 3 -d K -d C -d L -l naive > 
KCL-naive-exp-0-10-1e-4-32-40.log 2>&1 & 34 | # ./job.sh -g 2 -d K -d C -d N -l naive > KCN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 35 | # ./job.sh -g 1 -d K -d L -d N -l naive > KLN-naive-exp-0-10-1e-4-32-40.log 2>&1 & 36 | 37 | # ./job.sh -g 15 -d K -d C -d L -d N -l naive0 > KCLN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 38 | # ./job.sh -g 14 -d C -l naive0 > C-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 39 | # ./job.sh -g 13 -d L -l naive0 > L-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 40 | # ./job.sh -g 12 -d N -l naive0 > N-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 41 | # ./job.sh -g 11 -d C -d L -l naive0 > CL-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 42 | # ./job.sh -g 10 -d C -d N -l naive0 > CN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 43 | # ./job.sh -g 9 -d L -d N -l naive0 > LN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 44 | # ./job.sh -g 8 -d C -d L -d N -l naive0 > CLN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 45 | # ./job.sh -g 7 -d K -l naive0 > K-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 46 | # ./job.sh -g 6 -d K -d C -l naive0 > KC-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 47 | # ./job.sh -g 5 -d K -d L -l naive0 > KL-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 48 | # ./job.sh -g 4 -d K -d N -l naive0 > KN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 49 | # ./job.sh -g 3 -d K -d C -d L -l naive0 > KCL-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 50 | # ./job.sh -g 2 -d K -d C -d N -l naive0 > KCN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 51 | # ./job.sh -g 1 -d K -d L -d N -l naive0 > KLN-naive0-exp-0-10-1e-4-32-40.log 2>&1 & 52 | 53 | # ./job.sh -g 0 -d K -d C -d L -d N -l plcc > KCLN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 54 | # ./job.sh -g 1 -d C -l plcc > C-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 55 | # ./job.sh -g 2 -d L -l plcc > L-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 56 | # ./job.sh -g 3 -d N -l plcc > N-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 57 | # ./job.sh -g 4 -d C -d L -l plcc > CL-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 58 | # ./job.sh -g 5 -d C -d N -l plcc > CN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 59 | # ./job.sh -g 6 -d L -d N -l plcc > LN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 60 | # ./job.sh -g 7 -d C -d L -d N -l plcc > CLN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 61 | # ./job.sh -g 8 -d K -l plcc > K-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 62 | # ./job.sh -g 9 -d K -d C -l plcc > KC-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 63 | # ./job.sh -g 10 -d K -d L -l plcc > KL-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 64 | # ./job.sh -g 11 -d K -d N -l plcc > KN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 65 | # ./job.sh -g 12 -d K -d C -d L -l plcc > KCL-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 66 | # ./job.sh -g 13 -d K -d C -d N -l plcc > KCN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 67 | # ./job.sh -g 14 -d K -d L -d N -l plcc > KLN-plcc-exp-0-10-1e-4-32-40.log 2>&1 & 68 | 69 | # ./job.sh -g 15 -d K -d C -d L -d N -l rank > KCLN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 70 | # ./job.sh -g 14 -d C -l rank > C-rank-exp-0-10-1e-4-32-40.log 2>&1 & 71 | # ./job.sh -g 13 -d L -l rank > L-rank-exp-0-10-1e-4-32-40.log 2>&1 & 72 | # ./job.sh -g 12 -d N -l rank > N-rank-exp-0-10-1e-4-32-40.log 2>&1 & 73 | # ./job.sh -g 11 -d C -d L -l rank > CL-rank-exp-0-10-1e-4-32-40.log 2>&1 & 74 | # ./job.sh -g 10 -d C -d N -l rank > CN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 75 | # ./job.sh -g 9 -d L -d N -l rank > LN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 76 | # ./job.sh -g 8 -d C -d L -d N -l rank > CLN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 77 | # ./job.sh -g 7 -d K -l rank > K-rank-exp-0-10-1e-4-32-40.log 2>&1 & 78 | # ./job.sh -g 6 -d K -d C -l rank > KC-rank-exp-0-10-1e-4-32-40.log 2>&1 & 79 | # ./job.sh -g 5 -d K 
-d L -l rank > KL-rank-exp-0-10-1e-4-32-40.log 2>&1 & 80 | # ./job.sh -g 4 -d K -d N -l rank > KN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 81 | # ./job.sh -g 3 -d K -d C -d L -l rank > KCL-rank-exp-0-10-1e-4-32-40.log 2>&1 & 82 | # ./job.sh -g 2 -d K -d C -d N -l rank > KCN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 83 | # ./job.sh -g 1 -d K -d L -d N -l rank > KLN-rank-exp-0-10-1e-4-32-40.log 2>&1 & 84 | 85 | # ./job.sh -g 0 -d K -d C -d L -d N -l l1 > KCLN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 86 | # ./job.sh -g 1 -d C -l l1 > C-l1-exp-0-10-1e-4-32-40.log 2>&1 & 87 | # ./job.sh -g 2 -d L -l l1 > L-l1-exp-0-10-1e-4-32-40.log 2>&1 & 88 | # ./job.sh -g 3 -d N -l l1 > N-l1-exp-0-10-1e-4-32-40.log 2>&1 & 89 | # ./job.sh -g 4 -d C -d L -l l1 > CL-l1-exp-0-10-1e-4-32-40.log 2>&1 & 90 | # ./job.sh -g 5 -d C -d N -l l1 > CN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 91 | # ./job.sh -g 6 -d L -d N -l l1 > LN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 92 | # ./job.sh -g 7 -d C -d L -d N -l l1 > CLN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 93 | # ./job.sh -g 8 -d K -l l1 > K-l1-exp-0-10-1e-4-32-40.log 2>&1 & 94 | # ./job.sh -g 9 -d K -d C -l l1 > KC-l1-exp-0-10-1e-4-32-40.log 2>&1 & 95 | # ./job.sh -g 10 -d K -d L -l l1 > KL-l1-exp-0-10-1e-4-32-40.log 2>&1 & 96 | # ./job.sh -g 11 -d K -d N -l l1 > KN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 97 | # ./job.sh -g 12 -d K -d C -d L -l l1 > KCL-l1-exp-0-10-1e-4-32-40.log 2>&1 & 98 | # ./job.sh -g 13 -d K -d C -d N -l l1 > KCN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 99 | # ./job.sh -g 14 -d K -d L -d N -l l1 > KLN-l1-exp-0-10-1e-4-32-40.log 2>&1 & 100 | 101 | # ./job.sh -g 15 -d K -d C -d L -d N -l correlation > KCLN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 102 | # ./job.sh -g 14 -d C -l correlation > C-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 103 | # ./job.sh -g 13 -d L -l correlation > L-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 104 | # ./job.sh -g 12 -d N -l correlation > N-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 105 | # ./job.sh -g 11 -d C -d L -l correlation > CL-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 106 | # ./job.sh -g 10 -d C -d N -l correlation > CN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 107 | # ./job.sh -g 9 -d L -d N -l correlation > LN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 108 | # ./job.sh -g 8 -d C -d L -d N -l correlation > CLN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 109 | # ./job.sh -g 7 -d K -l correlation > K-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 110 | # ./job.sh -g 6 -d K -d C -l correlation > KC-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 111 | # ./job.sh -g 5 -d K -d L -l correlation > KL-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 112 | # ./job.sh -g 4 -d K -d N -l correlation > KN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 113 | # ./job.sh -g 3 -d K -d C -d L -l correlation > KCL-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 114 | # ./job.sh -g 2 -d K -d C -d N -l correlation > KCN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 115 | # ./job.sh -g 1 -d K -d L -d N -l correlation > KLN-correlation-exp-0-10-1e-4-32-40.log 2>&1 & 116 | 117 | # ./job.sh -g 0 -d K -d C -d L -d N -l rank+l1 > KCLN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 118 | # ./job.sh -g 1 -d C -l rank+l1 > C-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 119 | # ./job.sh -g 2 -d L -l rank+l1 > L-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 120 | # ./job.sh -g 3 -d N -l rank+l1 > N-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 121 | # ./job.sh -g 4 -d C -d L -l rank+l1 > CL-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 122 | # ./job.sh -g 5 -d C -d N -l rank+l1 > CN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 123 | # ./job.sh -g 6 -d L -d 
N -l rank+l1 > LN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 124 | # ./job.sh -g 7 -d C -d L -d N -l rank+l1 > CLN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 125 | # ./job.sh -g 8 -d K -l rank+l1 > K-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 126 | # ./job.sh -g 9 -d K -d C -l rank+l1 > KC-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 127 | # ./job.sh -g 10 -d K -d L -l rank+l1 > KL-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 128 | # ./job.sh -g 11 -d K -d N -l rank+l1 > KN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 129 | # ./job.sh -g 12 -d K -d C -d L -l rank+l1 > KCL-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 130 | # ./job.sh -g 13 -d K -d C -d N -l rank+l1 > KCN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 131 | # ./job.sh -g 14 -d K -d L -d N -l rank+l1 > KLN-rank+l1-exp-0-10-1e-4-32-40.log 2>&1 & 132 | 133 | # ./job.sh -g 15 -d K -d C -d L -d N -l plcc+l1 > KCLN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 134 | # ./job.sh -g 14 -d C -l plcc+l1 > C-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 135 | # ./job.sh -g 13 -d L -l plcc+l1 > L-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 136 | # ./job.sh -g 12 -d N -l plcc+l1 > N-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 137 | # ./job.sh -g 11 -d C -d L -l plcc+l1 > CL-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 138 | # ./job.sh -g 10 -d C -d N -l plcc+l1 > CN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 139 | # ./job.sh -g 9 -d L -d N -l plcc+l1 > LN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 140 | # ./job.sh -g 8 -d C -d L -d N -l plcc+l1 > CLN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 141 | # ./job.sh -g 7 -d K -l plcc+l1 > K-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 142 | # ./job.sh -g 6 -d K -d C -l plcc+l1 > KC-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 143 | # ./job.sh -g 5 -d K -d L -l plcc+l1 > KL-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 144 | # ./job.sh -g 4 -d K -d N -l plcc+l1 > KN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 145 | # ./job.sh -g 3 -d K -d C -d L -l plcc+l1 > KCL-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 146 | # ./job.sh -g 2 -d K -d C -d N -l plcc+l1 > KCN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 147 | # ./job.sh -g 1 -d K -d L -d N -l plcc+l1 > KLN-plcc+l1-exp-0-10-1e-4-32-40.log 2>&1 & 148 | 149 | # ./job.sh -g 1 -d K -d C -d L -d N -p 1 > KCLN-mixed-train_proportion=1-exp-0-10-1e-4-32-40.log 2>&1 & 150 | # ./job.sh -g 2 -d K -d C -d L -d N -p 2 > KCLN-mixed-train_proportion=2-exp-0-10-1e-4-32-40.log 2>&1 & 151 | # ./job.sh -g 3 -d K -d C -d L -d N -p 3 > KCLN-mixed-train_proportion=3-exp-0-10-1e-4-32-40.log 2>&1 & 152 | # ./job.sh -g 4 -d K -d C -d L -d N -p 4 > KCLN-mixed-train_proportion=4-exp-0-10-1e-4-32-40.log 2>&1 & 153 | # ./job.sh -g 0 -d K -d C -d L -d N -p 5 > KCLN-mixed-train_proportion=5-exp-0-10-1e-4-32-40.log 2>&1 & 154 | 155 | loss=mixed 156 | start_id=0 157 | end_id=10 158 | train_proportion=6 159 | while getopts "g:d:l:s:e:p:" opt; do 160 | case $opt in 161 | g) gpu_id=("$OPTARG");; # gpu_id 162 | d) datasets+=("$OPTARG");; # trained datasets 163 | l) loss=("$OPTARG");; # loss 164 | s) start_id=("$OPTARG");; 165 | e) end_id=("$OPTARG");; 166 | p) train_proportion=("$OPTARG");; 167 | esac 168 | done 169 | shift $((OPTIND -1)) 170 | # if [ ! 
$loss ]; then 171 | # loss=mixed 172 | # fi 173 | # echo $loss 174 | source activate reproducibleresearch 175 | for ((i=$start_id; i<$end_id; i++)); do 176 | CUDA_VISIBLE_DEVICES=$gpu_id python main.py --exp_id=$i --train_proportion $train_proportion --loss=$loss --lr=1e-4 --batch_size=32 --epochs=40 --trained_datasets ${datasets[@]} 177 | done 178 | source deactivate -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Author: Dingquan Li 2 | # Email: dingquanli AT pku DOT edu DOT cn 3 | # Date: 2019/11/8 4 | # 5 | # source activate reproducibleresearch 6 | # tensorboard --logdir=runs --port=6006 7 | 8 | import torch 9 | from torch.optim import Adam, lr_scheduler 10 | from torch.utils.data import Dataset 11 | from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events 12 | from VQAdataset import get_data_loaders 13 | from VQAmodel import VQAModel 14 | from VQAloss import VQALoss 15 | from VQAperformance import VQAPerformance 16 | from tensorboardX import SummaryWriter 17 | import datetime 18 | import os 19 | import numpy as np 20 | import random 21 | from argparse import ArgumentParser 22 | 23 | 24 | def writer_add_scalar(writer, status, dataset, scalars, iter): 25 | writer.add_scalar("{}/{}/SROCC".format(status, dataset), scalars['SROCC'], iter) 26 | writer.add_scalar("{}/{}/KROCC".format(status, dataset), scalars['KROCC'], iter) 27 | writer.add_scalar("{}/{}/PLCC".format(status, dataset), scalars['PLCC'], iter) 28 | writer.add_scalar("{}/{}/RMSE".format(status, dataset), scalars['RMSE'], iter) 29 | 30 | 31 | def run(args): 32 | device = torch.device("cuda" if not args.disable_gpu and torch.cuda.is_available() else "cpu") 33 | train_loader, val_loader, test_loader, scale, m = get_data_loaders(args) 34 | model = VQAModel(scale, m, args.simple_linear_scale).to(device) # 35 | print(model) 36 | 37 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 38 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_interval, gamma=args.decay_ratio) 39 | loss_func = VQALoss([scale[dataset] for dataset in args.datasets['train']], args.loss, [m[dataset] for dataset in args.datasets['train']]) 40 | trainer = create_supervised_trainer(model, optimizer, loss_func, device=device) 41 | evaluator = create_supervised_evaluator(model, metrics={'VQA_performance': VQAPerformance()}, device=device) 42 | 43 | if args.inference: 44 | model.load_state_dict(torch.load(args.trained_model_file)) 45 | performance = dict() 46 | for dataset in args.datasets['test']: 47 | evaluator.run(test_loader[dataset]) 48 | performance[dataset] = evaluator.state.metrics['VQA_performance'] 49 | print('{}, SROCC: {}'.format(dataset, performance[dataset]['SROCC'])) 50 | np.save(args.save_result_file, performance) 51 | return 52 | 53 | writer = SummaryWriter(log_dir='{}/EXP{}-{}-{}-{}-{}-{}-{}-{}-{}-{}' 54 | .format(args.log_dir, args.exp_id, args.model, args.feature_extractor, args.loss, args.train_proportion, args.datasets['train'], 55 | args.lr, args.batch_size, args.epochs, 56 | datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))) 57 | 58 | global best_val_criterion, best_epoch 59 | best_val_criterion, best_epoch = -100, -1 # larger, better, e.g., SROCC/KROCC/PLCC 60 | 61 | @trainer.on(Events.ITERATION_COMPLETED) 62 | def iter_event_function(engine): 63 | writer.add_scalar("train/loss", engine.state.output, engine.state.iteration) 64 | 65 
| @trainer.on(Events.EPOCH_COMPLETED) 66 | def epoch_event_function(engine): 67 | val_criterion = 0 68 | for dataset in args.datasets['val']: 69 | evaluator.run(val_loader[dataset]) 70 | performance = evaluator.state.metrics['VQA_performance'] 71 | writer_add_scalar(writer, 'val', dataset, performance, engine.state.epoch) 72 | if dataset in args.datasets['train']: 73 | val_criterion += performance['SROCC'] 74 | 75 | for dataset in args.datasets['test']: 76 | evaluator.run(test_loader[dataset]) 77 | performance = evaluator.state.metrics['VQA_performance'] 78 | writer_add_scalar(writer, 'test', dataset, performance, engine.state.epoch) 79 | 80 | global best_val_criterion, best_epoch 81 | if val_criterion > best_val_criterion: 82 | torch.save(model.state_dict(), args.trained_model_file) 83 | best_val_criterion = val_criterion 84 | best_epoch = engine.state.epoch 85 | print('Save current best model @best_val_criterion: {} @epoch: {}'.format(best_val_criterion, best_epoch)) 86 | 87 | scheduler.step(engine.state.epoch) 88 | 89 | @trainer.on(Events.COMPLETED) 90 | def final_testing_results(engine): 91 | print('best epoch: {}'.format(best_epoch)) 92 | model.load_state_dict(torch.load(args.trained_model_file)) 93 | performance = dict() 94 | for dataset in args.datasets['test']: 95 | evaluator.run(test_loader[dataset]) 96 | performance[dataset] = evaluator.state.metrics['VQA_performance'] 97 | print('{}, SROCC: {}'.format(dataset, performance[dataset]['SROCC'])) 98 | np.save(args.save_result_file, performance) 99 | 100 | trainer.run(train_loader, max_epochs=args.epochs) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = ArgumentParser(description='Mixed Dataset Training for Quality Assessment of In-the-Wild Videos') 105 | parser.add_argument("--seed", type=int, default=19920517) 106 | parser.add_argument('--lr', type=float, default=1e-4, 107 | help='learning rate (default: 1e-4)') 108 | parser.add_argument('--batch_size', type=int, default=32, 109 | help='input batch size for training (default: 32)') 110 | parser.add_argument('--epochs', type=int, default=40, 111 | help='number of epochs to train (default: 40)') 112 | parser.add_argument('--weight_decay', type=float, default=0.0, 113 | help='weight decay (default: 0.0)') 114 | 115 | parser.add_argument('--model', default='MDTVSFA', type=str, 116 | help='model name (default: MDTVSFA)') 117 | parser.add_argument('--loss', default='mixed', type=str, 118 | help='loss type (default: mixed)') 119 | parser.add_argument('--feature_extractor', default='ResNet-50', type=str, 120 | help='feature_extractor backbone (default: ResNet-50)') 121 | # parser.add_argument('--feat_dim', type=int, default=4096, 122 | # help='feature dimension (default: 4096)') 123 | 124 | parser.add_argument('--trained_datasets', nargs='+', type=str, default=['K', 'C', 'L', 'N'], 125 | help="trained datasets (default: ['K', 'C', 'L', 'N'])") 126 | 127 | parser.add_argument('--exp_id', default=0, type=int, 128 | help='exp id for train-val-test splits (default: 0)') 129 | parser.add_argument('--crop_length', type=int, default=180, 130 | help='Crop video length (<=max_len=1202, default: 180)') 131 | parser.add_argument('--train_ratio', type=float, default=0.6, 132 | help='train ratio (default: 0.6)') 133 | parser.add_argument('--train_proportion', type=float, default=6, 134 | help='the number of proportions (#total 6) used in the training set (default: 6)') 135 | 136 | parser.add_argument("--log_dir", type=str, default="runs", 137 | help="log directory for Tensorboard log output") 
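    # Note (added, descriptive only -- not in the original file): the flags below
    # control run mode. --disable_gpu forces CPU execution even when CUDA is
    # available, and --inference skips training so that run() only loads
    # args.trained_model_file and reports test-set performance.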
138 | parser.add_argument('--disable_gpu', action='store_true', 139 | help='flag whether to disable GPU') 140 | parser.add_argument('--inference', action='store_true', 141 | help='Inference?') 142 | args = parser.parse_args() 143 | args.train_proportion /= 6 144 | if args.feature_extractor == 'AlexNet': 145 | args.feat_dim = 256 * 2 146 | else: 147 | args.feat_dim = 2048 * 2 148 | 149 | 150 | args.simple_linear_scale = False # 151 | if 'naive' in args.loss: 152 | args.simple_linear_scale = True # 153 | 154 | args.decay_interval = int(args.epochs / 20) 155 | args.decay_ratio = 0.8 156 | 157 | args.datasets = {'train': args.trained_datasets, 158 | 'val': args.trained_datasets, 159 | 'test': ['K', 'C', 'L', 'N']} 160 | args.features_dir = {'K': 'CNN_features_KoNViD-1k/', 161 | 'C': 'CNN_features_CVD2014/', 162 | 'L': 'CNN_features_LIVE-Qualcomm/', 163 | 'N': 'CNN_features_LIVE-VQC/'} 164 | args.data_info = {'K': 'data/KoNViD-1kinfo.mat', 165 | 'C': 'data/CVD2014info.mat', 166 | 'L': 'data/LIVE-Qualcomminfo.mat', 167 | 'N': 'data/LIVE-VQCinfo.mat'} 168 | 169 | torch.manual_seed(args.seed) # 170 | torch.backends.cudnn.deterministic = True 171 | torch.backends.cudnn.benchmark = False 172 | np.random.seed(args.seed) 173 | random.seed(args.seed) 174 | 175 | torch.utils.backcompat.broadcast_warning.enabled = True 176 | 177 | if not os.path.exists('checkpoints'): 178 | os.makedirs('checkpoints') 179 | args.trained_model_file = 'checkpoints/{}-{}-{}-{}-{}-{}-{}-{}-EXP{}'.format(args.model, args.feature_extractor, args.loss, args.train_proportion, args.datasets['train'], args.lr, args.batch_size, args.epochs, args.exp_id) 180 | if not os.path.exists('results'): 181 | os.makedirs('results') 182 | args.save_result_file = 'results/{}-{}-{}-{}-{}-{}-{}-{}-EXP{}'.format(args.model, args.feature_extractor, args.loss, args.train_proportion, args.datasets['train'], args.lr, args.batch_size, args.epochs, args.exp_id) 183 | print(args) 184 | run(args) 185 | -------------------------------------------------------------------------------- /models/MDTVSFA.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/MDTVSFA/1460fb21d8e8cf1493331edee5a0082dd6bbf2ff/models/MDTVSFA.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse==1.4.0 2 | h5py==2.10.0 3 | PyYAML==5.1.2 4 | Pillow==7.1.2 5 | scikit-video==1.1.11 6 | numpy==1.17.3 7 | scipy==1.0.1 8 | torch==1.3.0 9 | torchvision==0.3.0 10 | pytorch-ignite==0.4.1 11 | tensorflow-gpu==2.0.0 12 | tensorboardX==1.9 13 | -------------------------------------------------------------------------------- /test_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import skvideo.io 4 | from PIL import Image 5 | import numpy as np 6 | from CNNfeatures import get_features 7 | from VQAmodel import VQAModel 8 | from argparse import ArgumentParser 9 | import time 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser(description='"Test Demo of MDTVSFA') 14 | parser.add_argument('--model_path', default='models/MDTVSFA.pt', type=str, 15 | help='model path (default: models/MDTVSFA.pt)') 16 | parser.add_argument('--video_path', default='./test.mp4', type=str, 17 | help='video path (default: ./test.mp4)') 18 | parser.add_argument('--video_format', default='RGB', 
type=str, 19 | help='video format: RGB or YUV420 (default: RGB)') 20 | parser.add_argument('--video_width', type=int, default=None, 21 | help='video width') 22 | parser.add_argument('--video_height', type=int, default=None, 23 | help='video height') 24 | 25 | parser.add_argument('--frame_batch_size', type=int, default=32, 26 | help='frame batch size for feature extraction (default: 32)') 27 | args = parser.parse_args() 28 | 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | start = time.time() 32 | 33 | # data preparation 34 | assert args.video_format == 'YUV420' or args.video_format == 'RGB' 35 | if args.video_format == 'YUV420': 36 | video_data = skvideo.io.vread(args.video_path, args.video_height, args.video_width, inputdict={'-pix_fmt': 'yuvj420p'}) 37 | else: 38 | video_data = skvideo.io.vread(args.video_path) 39 | 40 | video_length = video_data.shape[0] 41 | video_channel = video_data.shape[3] 42 | video_height = video_data.shape[1] 43 | video_width = video_data.shape[2] 44 | transformed_video = torch.zeros([video_length, video_channel, video_height, video_width]) 45 | transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 48 | ]) 49 | 50 | for frame_idx in range(video_length): 51 | frame = video_data[frame_idx] 52 | frame = Image.fromarray(frame) 53 | frame = transform(frame) 54 | transformed_video[frame_idx] = frame 55 | 56 | print('Video length: {}'.format(transformed_video.shape[0])) 57 | 58 | # feature extraction 59 | features = get_features(transformed_video, frame_batch_size=args.frame_batch_size, device=device) 60 | features = torch.unsqueeze(features, 0) # batch size 1 61 | 62 | # quality prediction 63 | model = VQAModel().to(device) 64 | model.load_state_dict(torch.load(args.model_path)) # 65 | 66 | model.eval() 67 | with torch.no_grad(): 68 | input_length = features.shape[1] * torch.ones(1, 1, dtype=torch.long) 69 | relative_score, mapped_score, aligned_score = model([(features, input_length, ['K'])]) 70 | y_pred = mapped_score[0][0].to('cpu').numpy() 71 | print("Predicted perceptual quality: {}".format(y_pred)) 72 | 73 | end = time.time() 74 | 75 | print('Time: {} s'.format(end-start)) --------------------------------------------------------------------------------
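
As a closing illustration, the demo above can be condensed into a single reusable function. The following is a minimal sketch and not part of the repository: it assumes an RGB-readable input video, the repository's own CNNfeatures.get_features and VQAmodel.VQAModel exactly as shown above, and the pretrained weights at models/MDTVSFA.pt; the file name predict_quality.py and the function name predict_quality are illustrative only.

# predict_quality.py -- minimal sketch wrapping test_demo.py's steps into one
# callable (assumption-laden example, not original repository code).
import torch
from torchvision import transforms
import skvideo.io
from PIL import Image

from CNNfeatures import get_features
from VQAmodel import VQAModel


def predict_quality(video_path, model_path='models/MDTVSFA.pt',
                    frame_batch_size=32, device=None):
    """Return the mapped perceptual quality score of an RGB video."""
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Read frames and apply the same ImageNet normalization used above.
    video_data = skvideo.io.vread(video_path)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])
    frames = torch.stack([transform(Image.fromarray(f)) for f in video_data])

    # Content-aware features: mean- and std-pooled ResNet-50 activations.
    features = get_features(frames, frame_batch_size=frame_batch_size, device=device)
    features = features.unsqueeze(0)  # add batch dimension

    # Load the trained multi-dataset model and predict the mapped score.
    model = VQAModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    with torch.no_grad():
        length = features.shape[1] * torch.ones(1, 1, dtype=torch.long)
        _, mapped_score, _ = model([(features, length, ['K'])])
    return float(mapped_score[0][0].to('cpu').numpy())


if __name__ == '__main__':
    print(predict_quality('data/test.mp4'))

Since this sketch shares the preprocessing, feature extraction, and model call of test_demo.py, calling predict_quality('data/test.mp4') should reproduce the mapped score that the demo prints for the bundled test video.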