├── ACMMM_2019_supplementary_44
│   ├── in-focus image pair
│   │   ├── better_in-focus.JPG
│   │   └── worse_in-focus.JPG
│   ├── out-of-focus image pair
│   │   ├── better_out-of-focus.JPG
│   │   └── worse_out-of-focus.JPG
│   └── video pair
│       ├── better-MAH_02693_8s.mp4
│       └── worse-MAH_02762_8s.mp4
├── CNNfeatures.py
├── Framework.jpg
├── License
├── Readme.md
├── VSFA.py
├── _config.yml
├── data
│   ├── CVD2014info.mat
│   ├── KoNViD-1kinfo.mat
│   ├── LIVE-Qualcomminfo.mat
│   └── data_info_maker.m
├── models
│   └── VSFA.pt
├── requirements.txt
├── test.mp4
└── test_demo.py
/ACMMM_2019_supplementary_44/in-focus image pair/better_in-focus.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/ACMMM_2019_supplementary_44/in-focus image pair/better_in-focus.JPG
--------------------------------------------------------------------------------
/ACMMM_2019_supplementary_44/in-focus image pair/worse_in-focus.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/ACMMM_2019_supplementary_44/in-focus image pair/worse_in-focus.JPG
--------------------------------------------------------------------------------
/ACMMM_2019_supplementary_44/out-of-focus image pair/better_out-of-focus.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/ACMMM_2019_supplementary_44/out-of-focus image pair/better_out-of-focus.JPG
--------------------------------------------------------------------------------
/ACMMM_2019_supplementary_44/out-of-focus image pair/worse_out-of-focus.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/ACMMM_2019_supplementary_44/out-of-focus image pair/worse_out-of-focus.JPG
--------------------------------------------------------------------------------
/ACMMM_2019_supplementary_44/video pair/better-MAH_02693_8s.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/ACMMM_2019_supplementary_44/video pair/better-MAH_02693_8s.mp4
--------------------------------------------------------------------------------
/ACMMM_2019_supplementary_44/video pair/worse-MAH_02762_8s.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/ACMMM_2019_supplementary_44/video pair/worse-MAH_02762_8s.mp4
--------------------------------------------------------------------------------
/CNNfeatures.py: -------------------------------------------------------------------------------- 1 | """Extracting Content-Aware Perceptual Features using Pre-Trained ResNet-50""" 2 | # Author: Dingquan Li 3 | # Email: dingquanli AT pku DOT edu DOT cn 4 | # Date: 2018/3/27 5 | # 6 | # CUDA_VISIBLE_DEVICES=0 python CNNfeatures.py --database=KoNViD-1k --frame_batch_size=64 7 | # CUDA_VISIBLE_DEVICES=1 python CNNfeatures.py --database=CVD2014 --frame_batch_size=32 8 | # CUDA_VISIBLE_DEVICES=0 python CNNfeatures.py --database=LIVE-Qualcomm --frame_batch_size=8 9 | 10 | import torch 11 | from torchvision import transforms, models 12 | import torch.nn as nn 13 | from torch.utils.data import
Dataset 14 | import skvideo.io 15 | from PIL import Image 16 | import os 17 | import h5py 18 | import numpy as np 19 | import random 20 | from argparse import ArgumentParser 21 | 22 | 23 | class VideoDataset(Dataset): 24 | """Read data from the original dataset for feature extraction""" 25 | def __init__(self, videos_dir, video_names, score, video_format='RGB', width=None, height=None): 26 | 27 | super(VideoDataset, self).__init__() 28 | self.videos_dir = videos_dir 29 | self.video_names = video_names 30 | self.score = score 31 | self.format = video_format 32 | self.width = width 33 | self.height = height 34 | 35 | def __len__(self): 36 | return len(self.video_names) 37 | 38 | def __getitem__(self, idx): 39 | video_name = self.video_names[idx] 40 | assert self.format == 'YUV420' or self.format == 'RGB' 41 | if self.format == 'YUV420': 42 | video_data = skvideo.io.vread(os.path.join(self.videos_dir, video_name), self.height, self.width, inputdict={'-pix_fmt':'yuvj420p'}) 43 | else: 44 | video_data = skvideo.io.vread(os.path.join(self.videos_dir, video_name)) 45 | video_score = self.score[idx] 46 | 47 | transform = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | ]) 51 | 52 | video_length = video_data.shape[0] 53 | video_channel = video_data.shape[3] 54 | video_height = video_data.shape[1] 55 | video_width = video_data.shape[2] 56 | transformed_video = torch.zeros([video_length, video_channel, video_height, video_width]) 57 | for frame_idx in range(video_length): 58 | frame = video_data[frame_idx] 59 | frame = Image.fromarray(frame) 60 | frame = transform(frame) 61 | transformed_video[frame_idx] = frame 62 | 63 | sample = {'video': transformed_video, 64 | 'score': video_score} 65 | 66 | return sample 67 | 68 | 69 | class ResNet50(torch.nn.Module): 70 | """Modified ResNet50 for feature extraction""" 71 | def __init__(self): 72 | super(ResNet50, self).__init__() 73 | self.features = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2]) 74 | for p in self.features.parameters(): 75 | p.requires_grad = False 76 | 77 | def forward(self, x): 78 | # features@: 7->res5c 79 | for ii, model in enumerate(self.features): 80 | x = model(x) 81 | if ii == 7: 82 | features_mean = nn.functional.adaptive_avg_pool2d(x, 1) 83 | features_std = global_std_pool2d(x) 84 | return features_mean, features_std 85 | 86 | 87 | def global_std_pool2d(x): 88 | """2D global standard variation pooling""" 89 | return torch.std(x.view(x.size()[0], x.size()[1], -1, 1), 90 | dim=2, keepdim=True) 91 | 92 | 93 | def get_features(video_data, frame_batch_size=64, device='cuda'): 94 | """feature extraction""" 95 | extractor = ResNet50().to(device) 96 | video_length = video_data.shape[0] 97 | frame_start = 0 98 | frame_end = frame_start + frame_batch_size 99 | output1 = torch.Tensor().to(device) 100 | output2 = torch.Tensor().to(device) 101 | extractor.eval() 102 | with torch.no_grad(): 103 | while frame_end < video_length: 104 | batch = video_data[frame_start:frame_end].to(device) 105 | features_mean, features_std = extractor(batch) 106 | output1 = torch.cat((output1, features_mean), 0) 107 | output2 = torch.cat((output2, features_std), 0) 108 | frame_end += frame_batch_size 109 | frame_start += frame_batch_size 110 | 111 | last_batch = video_data[frame_start:video_length].to(device) 112 | features_mean, features_std = extractor(last_batch) 113 | output1 = torch.cat((output1, features_mean), 0) 114 | output2 = torch.cat((output2, 
features_std), 0) 115 | output = torch.cat((output1, output2), 1).squeeze() 116 | 117 | return output 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = ArgumentParser(description='"Extracting Content-Aware Perceptual Features using Pre-Trained ResNet-50') 122 | parser.add_argument("--seed", type=int, default=19920517) 123 | parser.add_argument('--database', default='KoNViD-1k', type=str, 124 | help='database name (default: KoNViD-1k)') 125 | parser.add_argument('--frame_batch_size', type=int, default=64, 126 | help='frame batch size for feature extraction (default: 64)') 127 | 128 | parser.add_argument('--disable_gpu', action='store_true', 129 | help='flag whether to disable GPU') 130 | args = parser.parse_args() 131 | 132 | torch.manual_seed(args.seed) # 133 | torch.backends.cudnn.deterministic = True 134 | torch.backends.cudnn.benchmark = False 135 | np.random.seed(args.seed) 136 | random.seed(args.seed) 137 | 138 | torch.utils.backcompat.broadcast_warning.enabled = True 139 | 140 | if args.database == 'KoNViD-1k': 141 | videos_dir = '/home/ldq/Downloads/KoNViD-1k/' # videos dir 142 | features_dir = 'CNN_features_KoNViD-1k/' # features dir 143 | datainfo = 'data/KoNViD-1kinfo.mat' # database info: video_names, scores; video format, width, height, index, ref_ids, max_len, etc. 144 | if args.database == 'CVD2014': 145 | videos_dir = '/media/ldq/Research/Data/CVD2014/' 146 | features_dir = 'CNN_features_CVD2014/' 147 | datainfo = 'data/CVD2014info.mat' 148 | if args.database == 'LIVE-Qualcomm': 149 | videos_dir = '/media/ldq/Others/Data/12.LIVE-Qualcomm Mobile In-Capture Video Quality Database/' 150 | features_dir = 'CNN_features_LIVE-Qualcomm/' 151 | datainfo = 'data/LIVE-Qualcomminfo.mat' 152 | 153 | if not os.path.exists(features_dir): 154 | os.makedirs(features_dir) 155 | 156 | device = torch.device("cuda" if not args.disable_gpu and torch.cuda.is_available() else "cpu") 157 | 158 | Info = h5py.File(datainfo, 'r') 159 | video_names = [Info[Info['video_names'][0, :][i]][()].tobytes()[::2].decode() for i in range(len(Info['video_names'][0, :]))] 160 | scores = Info['scores'][0, :] 161 | video_format = Info['video_format'][()].tobytes()[::2].decode() 162 | width = int(Info['width'][0]) 163 | height = int(Info['height'][0]) 164 | dataset = VideoDataset(videos_dir, video_names, scores, video_format, width, height) 165 | 166 | for i in range(len(dataset)): 167 | current_data = dataset[i] 168 | current_video = current_data['video'] 169 | current_score = current_data['score'] 170 | print('Video {}: length {}'.format(i, current_video.shape[0])) 171 | features = get_features(current_video, args.frame_batch_size, device) 172 | np.save(features_dir + str(i) + '_resnet-50_res5c', features.to('cpu').numpy()) 173 | np.save(features_dir + str(i) + '_score', current_score) 174 | -------------------------------------------------------------------------------- /Framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/Framework.jpg -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright © 2019 Dingquan Li 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the “Software”), to deal 6 | in the Software without restriction, including 
without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Quality Assessment of In-the-Wild Videos 2 | [![License](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](License) 3 | 4 | ## Description 5 | VSFA code for the following paper: 6 | 7 | - Dingquan Li, Tingting Jiang, and Ming Jiang. [Quality Assessment of In-the-Wild Videos](https://dl.acm.org/citation.cfm?doid=3343031.3351028). In Proceedings of the 27th ACM International Conference on Multimedia (MM ’19), October 21-25, 2019, Nice, France. [[arxiv version]](https://arxiv.org/abs/1908.00375) 8 | ![Framework](Framework.jpg) 9 | 10 | ### Intra-Database Experiments (Training and Evaluating) 11 | #### Feature extraction 12 | 13 | ``` 14 | CUDA_VISIBLE_DEVICES=0 python CNNfeatures.py --database=KoNViD-1k --frame_batch_size=64 15 | ``` 16 | 17 | You need to specify the `database` and change the corresponding `videos_dir`. 18 | 19 | #### Quality prediction 20 | 21 | ``` 22 | CUDA_VISIBLE_DEVICES=0 python VSFA.py --database=KoNViD-1k --exp_id=0 23 | ``` 24 | 25 | You need to specify the `database` and `exp_id`. 26 | 27 | #### Visualization 28 | ```bash 29 | tensorboard --logdir=logs --port=6006 # on the server (host:port) 30 | ssh -p port -L 6006:localhost:6006 user@host # on your PC, to view the visualization in your local browser 31 | ``` 32 | 33 | #### Reproduced results 34 | We set seeds for the random generators and re-ran the experiments on the same ten splits, i.e., the first 10 splits (`exp_id=0~9`). The results may still not be exactly the same across different versions of PyTorch. See [randomness@PyTorch Docs](https://pytorch.org/docs/stable/notes/randomness.html). 35 | 36 | The reproduced overall results are better than those published in the paper. 37 | We added learning rate scheduling in the updated code. 38 | Better hyper-parameters may be found by inspecting the training loss curve and the validation curves. 39 | 40 | The mean (std) values over the first ten index splits (60%:20%:20% train:val:test) are: 41 | 42 | | | KoNViD-1k | CVD2014 | LIVE-Qualcomm | 43 | | ---- | ---- | ---- | ---- | 44 | | SROCC | 0.7728 (0.0189) | 0.8698 (0.0368) | 0.7726 (0.0611) | 45 | | KROCC | 0.5784 (0.0194) | 0.6950 (0.0465) | 0.5871 (0.0620) | 46 | | PLCC | 0.7754 (0.0192) | 0.8678 (0.0315) | 0.7954 (0.0553) | 47 | | RMSE | 0.4205 (0.0211) | 10.8572 (1.3518) | 7.5495 (0.7017) | 48 | 49 | ### Test Demo 50 | 51 | The model weights provided in `models/VSFA.pt` are the saved weights obtained when running the 9th split of KoNViD-1k.
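Run the demo with the command below. For programmatic use, here is a minimal sketch of what `test_demo.py` does for an RGB input video (for YUV420 inputs, pass the width, height, and pixel format to `skvideo.io.vread` as in `test_demo.py`); the paths are illustrative and `map_location` is a small addition so the sketch also runs on CPU-only machines:

```python
import torch
from torchvision import transforms
import skvideo.io
from PIL import Image

from CNNfeatures import get_features  # ResNet-50 feature extractor from this repo
from VSFA import VSFA                  # model definition from this repo

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the video and normalize each frame with ImageNet statistics
video_data = skvideo.io.vread('test.mp4')  # [T, H, W, C], RGB
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
frames = torch.stack([transform(Image.fromarray(frame)) for frame in video_data])

# Extract content-aware features and predict the quality score
features = get_features(frames, frame_batch_size=32, device=device).unsqueeze(0)
model = VSFA()
model.load_state_dict(torch.load('models/VSFA.pt', map_location=device))  # map_location added for CPU-only use
model.to(device).eval()
with torch.no_grad():
    input_length = features.shape[1] * torch.ones(1, 1)
    score = model(features, input_length)
print('Predicted quality: {}'.format(score[0][0].item()))
```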
52 | ``` 53 | python test_demo.py --video_path=test.mp4 54 | ``` 55 | 56 | ### Requirement 57 | ```bash 58 | conda create -n reproducibleresearch pip python=3.6 59 | source activate reproducibleresearch 60 | pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 61 | source deactivate 62 | ``` 63 | - PyTorch 1.1.0 64 | - TensorboardX 1.2, TensorFlow-TensorBoard 65 | 66 | Note: The code can also be run directly on PyTorch 1.3. 67 | 68 | ### Contact 69 | Dingquan Li, dingquanli AT pku DOT edu DOT cn. 70 | -------------------------------------------------------------------------------- /VSFA.py: -------------------------------------------------------------------------------- 1 | """Quality Assessment of In-the-Wild Videos, ACM MM 2019""" 2 | # 3 | # Author: Dingquan Li 4 | # Email: dingquanli AT pku DOT edu DOT cn 5 | # Date: 2019/11/8 6 | # 7 | # tensorboard --logdir=logs --port=6006 8 | # CUDA_VISIBLE_DEVICES=1 python VSFA.py --database=KoNViD-1k --exp_id=0 9 | 10 | from argparse import ArgumentParser 11 | import os 12 | import h5py 13 | import torch 14 | from torch.optim import Adam, lr_scheduler 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.utils.data import Dataset 18 | import numpy as np 19 | import random 20 | from scipy import stats 21 | from tensorboardX import SummaryWriter 22 | import datetime 23 | 24 | 25 | class VQADataset(Dataset): 26 | def __init__(self, features_dir='CNN_features_KoNViD-1k/', index=None, max_len=240, feat_dim=4096, scale=1): 27 | super(VQADataset, self).__init__() 28 | self.features = np.zeros((len(index), max_len, feat_dim)) 29 | self.length = np.zeros((len(index), 1)) 30 | self.mos = np.zeros((len(index), 1)) 31 | for i in range(len(index)): 32 | features = np.load(features_dir + str(index[i]) + '_resnet-50_res5c.npy') 33 | self.length[i] = features.shape[0] 34 | self.features[i, :features.shape[0], :] = features 35 | self.mos[i] = np.load(features_dir + str(index[i]) + '_score.npy') # 36 | self.scale = scale # 37 | self.label = self.mos / self.scale # label normalization 38 | 39 | def __len__(self): 40 | return len(self.mos) 41 | 42 | def __getitem__(self, idx): 43 | sample = self.features[idx], self.length[idx], self.label[idx] 44 | return sample 45 | 46 | 47 | class ANN(nn.Module): 48 | def __init__(self, input_size=4096, reduced_size=128, n_ANNlayers=1, dropout_p=0.5): 49 | super(ANN, self).__init__() 50 | self.n_ANNlayers = n_ANNlayers 51 | self.fc0 = nn.Linear(input_size, reduced_size) # 52 | self.dropout = nn.Dropout(p=dropout_p) # 53 | self.fc = nn.Linear(reduced_size, reduced_size) # 54 | 55 | def forward(self, input): 56 | input = self.fc0(input) # linear 57 | for i in range(self.n_ANNlayers-1): # nonlinear 58 | input = self.fc(self.dropout(F.relu(input))) 59 | return input 60 | 61 | 62 | def TP(q, tau=12, beta=0.5): 63 | """subjectively-inspired temporal pooling""" 64 | q = torch.unsqueeze(torch.t(q), 0) 65 | qm = -float('inf')*torch.ones((1, 1, tau-1)).to(q.device) 66 | qp = 10000.0 * torch.ones((1, 1, tau - 1)).to(q.device) # large padding value, so exp(-qp) is ~0 in the softmin average 67 | l = -F.max_pool1d(torch.cat((qm, -q), 2), tau, stride=1) # memory element: min of quality over the previous tau frames 68 | m = F.avg_pool1d(torch.cat((q * torch.exp(-q), qp * torch.exp(-qp)), 2), tau, stride=1) 69 | n = F.avg_pool1d(torch.cat((torch.exp(-q), torch.exp(-qp)), 2), tau, stride=1) 70 | m = m / n # current element: softmin-weighted average of quality over the subsequent tau frames 71 | return beta * m + (1 - beta) * l 72 | 73 | 74 | class VSFA(nn.Module): 75 | def __init__(self, input_size=4096, reduced_size=128, hidden_size=32): 76 | 77 | super(VSFA, self).__init__() 78 | self.hidden_size =
hidden_size 79 | self.ann = ANN(input_size, reduced_size, 1) 80 | self.rnn = nn.GRU(reduced_size, hidden_size, batch_first=True) 81 | self.q = nn.Linear(hidden_size, 1) 82 | 83 | def forward(self, input, input_length): 84 | input = self.ann(input) # dimension reduction 85 | outputs, _ = self.rnn(input, self._get_initial_state(input.size(0), input.device)) 86 | q = self.q(outputs) # frame quality 87 | score = torch.zeros_like(input_length, device=q.device) # 88 | for i in range(input_length.shape[0]): # 89 | qi = q[i, :np.int(input_length[i].numpy())] 90 | qi = TP(qi) 91 | score[i] = torch.mean(qi) # video overall quality 92 | return score 93 | 94 | def _get_initial_state(self, batch_size, device): 95 | h0 = torch.zeros(1, batch_size, self.hidden_size, device=device) 96 | return h0 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = ArgumentParser(description='"VSFA: Quality Assessment of In-the-Wild Videos') 101 | parser.add_argument("--seed", type=int, default=19920517) 102 | parser.add_argument('--lr', type=float, default=0.00001, 103 | help='learning rate (default: 0.00001)') 104 | parser.add_argument('--batch_size', type=int, default=16, 105 | help='input batch size for training (default: 16)') 106 | parser.add_argument('--epochs', type=int, default=2000, 107 | help='number of epochs to train (default: 2000)') 108 | 109 | parser.add_argument('--database', default='CVD2014', type=str, 110 | help='database name (default: CVD2014)') 111 | parser.add_argument('--model', default='VSFA', type=str, 112 | help='model name (default: VSFA)') 113 | parser.add_argument('--exp_id', default=0, type=int, 114 | help='exp id for train-val-test splits (default: 0)') 115 | parser.add_argument('--test_ratio', type=float, default=0.2, 116 | help='test ratio (default: 0.2)') 117 | parser.add_argument('--val_ratio', type=float, default=0.2, 118 | help='val ratio (default: 0.2)') 119 | 120 | parser.add_argument('--weight_decay', type=float, default=0.0, 121 | help='weight decay (default: 0.0)') 122 | 123 | parser.add_argument("--notest_during_training", action='store_true', 124 | help='flag whether to test during training') 125 | parser.add_argument("--disable_visualization", action='store_true', 126 | help='flag whether to enable TensorBoard visualization') 127 | parser.add_argument("--log_dir", type=str, default="logs", 128 | help="log directory for Tensorboard log output") 129 | parser.add_argument('--disable_gpu', action='store_true', 130 | help='flag whether to disable GPU') 131 | args = parser.parse_args() 132 | 133 | args.decay_interval = int(args.epochs/10) 134 | args.decay_ratio = 0.8 135 | 136 | torch.manual_seed(args.seed) # 137 | torch.backends.cudnn.deterministic = True 138 | torch.backends.cudnn.benchmark = False 139 | np.random.seed(args.seed) 140 | random.seed(args.seed) 141 | 142 | torch.utils.backcompat.broadcast_warning.enabled = True 143 | 144 | if args.database == 'KoNViD-1k': 145 | features_dir = 'CNN_features_KoNViD-1k/' # features dir 146 | datainfo = 'data/KoNViD-1kinfo.mat' # database info: video_names, scores; video format, width, height, index, ref_ids, max_len, etc. 
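# NOTE: the data/*info.mat files are produced by data/data_info_maker.m and must be saved in MATLAB v7.3 (HDF5) format so that h5py can read them (see h5py.File(datainfo, 'r') below).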
147 | if args.database == 'CVD2014': 148 | features_dir = 'CNN_features_CVD2014/' 149 | datainfo = 'data/CVD2014info.mat' 150 | if args.database == 'LIVE-Qualcomm': 151 | features_dir = 'CNN_features_LIVE-Qualcomm/' 152 | datainfo = 'data/LIVE-Qualcomminfo.mat' 153 | 154 | print('EXP ID: {}'.format(args.exp_id)) 155 | print(args.database) 156 | print(args.model) 157 | 158 | device = torch.device("cuda" if not args.disable_gpu and torch.cuda.is_available() else "cpu") 159 | 160 | Info = h5py.File(datainfo, 'r') # index, ref_ids 161 | index = Info['index'] 162 | index = index[:, args.exp_id % index.shape[1]] # np.random.permutation(N) 163 | ref_ids = Info['ref_ids'][0, :] # 164 | max_len = int(Info['max_len'][0]) 165 | trainindex = index[0:int(np.ceil((1 - args.test_ratio - args.val_ratio) * len(index)))] 166 | testindex = index[int(np.ceil((1 - args.test_ratio) * len(index))):len(index)] 167 | train_index, val_index, test_index = [], [], [] 168 | for i in range(len(ref_ids)): 169 | train_index.append(i) if (ref_ids[i] in trainindex) else \ 170 | test_index.append(i) if (ref_ids[i] in testindex) else \ 171 | val_index.append(i) 172 | 173 | scale = Info['scores'][0, :].max() # label normalization factor 174 | train_dataset = VQADataset(features_dir, train_index, max_len, scale=scale) 175 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True) 176 | val_dataset = VQADataset(features_dir, val_index, max_len, scale=scale) 177 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset) 178 | if args.test_ratio > 0: 179 | test_dataset = VQADataset(features_dir, test_index, max_len, scale=scale) 180 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset) 181 | 182 | model = VSFA().to(device) # 183 | 184 | if not os.path.exists('models'): 185 | os.makedirs('models') 186 | trained_model_file = 'models/{}-{}-EXP{}'.format(args.model, args.database, args.exp_id) 187 | if not os.path.exists('results'): 188 | os.makedirs('results') 189 | save_result_file = 'results/{}-{}-EXP{}'.format(args.model, args.database, args.exp_id) 190 | 191 | if not args.disable_visualization: # Tensorboard Visualization 192 | writer = SummaryWriter(log_dir='{}/EXP{}-{}-{}-{}-{}-{}-{}' 193 | .format(args.log_dir, args.exp_id, args.database, args.model, 194 | args.lr, args.batch_size, args.epochs, 195 | datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))) 196 | 197 | criterion = nn.L1Loss() # L1 loss 198 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 199 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_interval, gamma=args.decay_ratio) 200 | best_val_criterion = -1 # SROCC min 201 | for epoch in range(args.epochs): 202 | # Train 203 | model.train() 204 | L = 0 205 | for i, (features, length, label) in enumerate(train_loader): 206 | features = features.to(device).float() 207 | label = label.to(device).float() 208 | optimizer.zero_grad() # 209 | outputs = model(features, length.float()) 210 | loss = criterion(outputs, label) 211 | loss.backward() 212 | optimizer.step() 213 | L = L + loss.item() 214 | train_loss = L / (i + 1) 215 | 216 | model.eval() 217 | # Val 218 | y_pred = np.zeros(len(val_index)) 219 | y_val = np.zeros(len(val_index)) 220 | L = 0 221 | with torch.no_grad(): 222 | for i, (features, length, label) in enumerate(val_loader): 223 | y_val[i] = scale * label.item() # 224 | features = features.to(device).float() 225 | label = label.to(device).float() 226 | outputs = model(features, 
length.float()) 227 | y_pred[i] = scale * outputs.item() 228 | loss = criterion(outputs, label) 229 | L = L + loss.item() 230 | val_loss = L / (i + 1) 231 | val_PLCC = stats.pearsonr(y_pred, y_val)[0] 232 | val_SROCC = stats.spearmanr(y_pred, y_val)[0] 233 | val_RMSE = np.sqrt(((y_pred-y_val) ** 2).mean()) 234 | val_KROCC = stats.stats.kendalltau(y_pred, y_val)[0] 235 | 236 | # Test 237 | if args.test_ratio > 0 and not args.notest_during_training: 238 | y_pred = np.zeros(len(test_index)) 239 | y_test = np.zeros(len(test_index)) 240 | L = 0 241 | with torch.no_grad(): 242 | for i, (features, length, label) in enumerate(test_loader): 243 | y_test[i] = scale * label.item() # 244 | features = features.to(device).float() 245 | label = label.to(device).float() 246 | outputs = model(features, length.float()) 247 | y_pred[i] = scale * outputs.item() 248 | loss = criterion(outputs, label) 249 | L = L + loss.item() 250 | test_loss = L / (i + 1) 251 | PLCC = stats.pearsonr(y_pred, y_test)[0] 252 | SROCC = stats.spearmanr(y_pred, y_test)[0] 253 | RMSE = np.sqrt(((y_pred-y_test) ** 2).mean()) 254 | KROCC = stats.stats.kendalltau(y_pred, y_test)[0] 255 | 256 | if not args.disable_visualization: # record training curves 257 | writer.add_scalar("loss/train", train_loss, epoch) # 258 | writer.add_scalar("loss/val", val_loss, epoch) # 259 | writer.add_scalar("SROCC/val", val_SROCC, epoch) # 260 | writer.add_scalar("KROCC/val", val_KROCC, epoch) # 261 | writer.add_scalar("PLCC/val", val_PLCC, epoch) # 262 | writer.add_scalar("RMSE/val", val_RMSE, epoch) # 263 | if args.test_ratio > 0 and not args.notest_during_training: 264 | writer.add_scalar("loss/test", test_loss, epoch) # 265 | writer.add_scalar("SROCC/test", SROCC, epoch) # 266 | writer.add_scalar("KROCC/test", KROCC, epoch) # 267 | writer.add_scalar("PLCC/test", PLCC, epoch) # 268 | writer.add_scalar("RMSE/test", RMSE, epoch) # 269 | 270 | # Update the model with the best val_SROCC 271 | if val_SROCC > best_val_criterion: 272 | print("EXP ID={}: Update best model using best_val_criterion in epoch {}".format(args.exp_id, epoch)) 273 | print("Val results: val loss={:.4f}, SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}" 274 | .format(val_loss, val_SROCC, val_KROCC, val_PLCC, val_RMSE)) 275 | if args.test_ratio > 0 and not args.notest_during_training: 276 | print("Test results: test loss={:.4f}, SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}" 277 | .format(test_loss, SROCC, KROCC, PLCC, RMSE)) 278 | np.save(save_result_file, (y_pred, y_test, test_loss, SROCC, KROCC, PLCC, RMSE, test_index)) 279 | torch.save(model.state_dict(), trained_model_file) 280 | best_val_criterion = val_SROCC # update best val SROCC 281 | 282 | # Test 283 | if args.test_ratio > 0: 284 | model.load_state_dict(torch.load(trained_model_file)) # 285 | model.eval() 286 | with torch.no_grad(): 287 | y_pred = np.zeros(len(test_index)) 288 | y_test = np.zeros(len(test_index)) 289 | L = 0 290 | for i, (features, length, label) in enumerate(test_loader): 291 | y_test[i] = scale * label.item() # 292 | features = features.to(device).float() 293 | label = label.to(device).float() 294 | outputs = model(features, length.float()) 295 | y_pred[i] = scale * outputs.item() 296 | loss = criterion(outputs, label) 297 | L = L + loss.item() 298 | test_loss = L / (i + 1) 299 | PLCC = stats.pearsonr(y_pred, y_test)[0] 300 | SROCC = stats.spearmanr(y_pred, y_test)[0] 301 | RMSE = np.sqrt(((y_pred-y_test) ** 2).mean()) 302 | KROCC = stats.stats.kendalltau(y_pred, y_test)[0] 303 | print("Test 
results: test loss={:.4f}, SROCC={:.4f}, KROCC={:.4f}, PLCC={:.4f}, RMSE={:.4f}" 304 | .format(test_loss, SROCC, KROCC, PLCC, RMSE)) 305 | np.save(save_result_file, (y_pred, y_test, test_loss, SROCC, KROCC, PLCC, RMSE, test_index)) 306 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /data/CVD2014info.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/data/CVD2014info.mat -------------------------------------------------------------------------------- /data/KoNViD-1kinfo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/data/KoNViD-1kinfo.mat -------------------------------------------------------------------------------- /data/LIVE-Qualcomminfo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/data/LIVE-Qualcomminfo.mat -------------------------------------------------------------------------------- /data/data_info_maker.m: -------------------------------------------------------------------------------- 1 | % https://www.mathworks.com/help/matlab/ref/save.html 2 | % save your mat file with v7.3 3 | % To view or set the default version for MAT-files, go to the Home tab and in the Environment section, click Preferences. 4 | % Select MATLAB > General > MAT-Files and then choose a MAT-file save format option. 5 | 6 | clear,clc; 7 | 8 | %% KoNViD-1k 9 | data_path = '/media/ldq/Research/Data/KoNViD-1k/KoNViD_1k_attributes.csv'; 10 | data = readtable(data_path); 11 | video_names = data.file_name; % video names 12 | scores = data.MOS; % subjective scores 13 | clear data_path data 14 | 15 | height = 540; % video height 16 | width = 960; % video width 17 | max_len = 240; % maximum video length in the dataset 18 | video_format = 'RGB'; % video format 19 | ref_ids = [1:length(scores)]'; % video content ids 20 | % `random` train-val-test split index, 1000 runs 21 | index = cell2mat(arrayfun(@(i)randperm(length(scores)), ... 22 | 1:1000,'UniformOutput', false)'); 23 | save('KoNViD-1kinfo','-v7.3') 24 | 25 | %% CVD2014 26 | data_path = '/media/ldq/Research/Data/CVD2014/CVD2014_ratings/Realignment_MOS.csv'; 27 | data = readtable(data_path); 28 | video_names = arrayfun(@(i) ['Test' data.File_name{i}(6) '/' ... 29 | data.Content{i} '/' data.File_name{i} '.avi'], 1:234, ... 30 | 'UniformOutput', false)'; % video names, remove '', add dir 31 | scores = arrayfun(@(i) str2double(data.RealignmentMOS{i})/100, 1:234)'; % subjective scores 32 | clear data_path data 33 | 34 | height = [720 480]; 35 | width = [1280 640]; 36 | max_len = 830; 37 | video_format = 'RGB'; 38 | ref_ids = [1:length(scores)]'; 39 | % `random` train-val-test split index, 1000 runs 40 | index = cell2mat(arrayfun(@(i)randperm(length(scores)), ... 
41 | 1:1000,'UniformOutput', false)'); 42 | save('CVD2014info','-v7.3') 43 | % LIVE-Qualcomm 44 | data_path = '/media/ldq/Others/Data/12.LIVE-Qualcomm Mobile In-Capture Video Quality Database/qualcommSubjectiveData.mat'; 45 | data = load(data_path); 46 | scores = data.qualcommSubjectiveData.unBiasedMOS; % subjective scores 47 | video_names = data.qualcommVideoData; 48 | video_names = arrayfun(@(i) [video_names.distortionNames{video_names.distortionType(i)} ... 49 | '/' video_names.vidNames{i}], 1:length(scores), ... 50 | 'UniformOutput', false)'; % video names 51 | clear data_path data 52 | 53 | height = 1080; 54 | width = 1920; 55 | max_len = 526; 56 | video_format = 'YUV420'; 57 | ref_ids = [1:length(scores)]'; 58 | % `random` train-val-test split index, 1000 runs 59 | index = cell2mat(arrayfun(@(i)randperm(length(scores)), ... 60 | 1:1000,'UniformOutput', false)'); 61 | save('LIVE-Qualcomminfo','-v7.3') 62 | -------------------------------------------------------------------------------- /models/VSFA.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/models/VSFA.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse==1.4.0 2 | h5py==2.10.0 3 | PyYAML==5.1.2 4 | Pillow==6.2.1 5 | scikit-video==1.1.11 6 | numpy==1.17.3 7 | scipy==1.0.1 8 | torch==1.1.0 9 | torchvision==0.3.0 10 | pytorch-ignite==0.2.1 11 | tensorflow-gpu==1.0.0 12 | tensorboardX==1.2 13 | -------------------------------------------------------------------------------- /test.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lidq92/VSFA/b78751b58b05321aea0f13ee913b4bee1531a233/test.mp4 -------------------------------------------------------------------------------- /test_demo.py: -------------------------------------------------------------------------------- 1 | """Test Demo for Quality Assessment of In-the-Wild Videos, ACM MM 2019""" 2 | # 3 | # Author: Dingquan Li 4 | # Email: dingquanli AT pku DOT edu DOT cn 5 | # Date: 2018/3/27 6 | # 7 | import torch 8 | from torchvision import transforms 9 | import skvideo.io 10 | from PIL import Image 11 | import numpy as np 12 | from VSFA import VSFA 13 | from CNNfeatures import get_features 14 | from argparse import ArgumentParser 15 | import time 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = ArgumentParser(description='"Test Demo of VSFA') 20 | parser.add_argument('--model_path', default='models/VSFA.pt', type=str, 21 | help='model path (default: models/VSFA.pt)') 22 | parser.add_argument('--video_path', default='./test.mp4', type=str, 23 | help='video path (default: ./test.mp4)') 24 | parser.add_argument('--video_format', default='RGB', type=str, 25 | help='video format: RGB or YUV420 (default: RGB)') 26 | parser.add_argument('--video_width', type=int, default=None, 27 | help='video width') 28 | parser.add_argument('--video_height', type=int, default=None, 29 | help='video height') 30 | 31 | parser.add_argument('--frame_batch_size', type=int, default=32, 32 | help='frame batch size for feature extraction (default: 32)') 33 | args = parser.parse_args() 34 | 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | 37 | start = time.time() 38 | 39 | # data preparation 40 | assert args.video_format == 'YUV420' or 
args.video_format == 'RGB' 41 | if args.video_format == 'YUV420': 42 | video_data = skvideo.io.vread(args.video_path, args.video_height, args.video_width, inputdict={'-pix_fmt': 'yuvj420p'}) 43 | else: 44 | video_data = skvideo.io.vread(args.video_path) 45 | 46 | video_length = video_data.shape[0] 47 | video_channel = video_data.shape[3] 48 | video_height = video_data.shape[1] 49 | video_width = video_data.shape[2] 50 | transformed_video = torch.zeros([video_length, video_channel, video_height, video_width]) 51 | transform = transforms.Compose([ 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 54 | ]) 55 | 56 | for frame_idx in range(video_length): 57 | frame = video_data[frame_idx] 58 | frame = Image.fromarray(frame) 59 | frame = transform(frame) 60 | transformed_video[frame_idx] = frame 61 | 62 | print('Video length: {}'.format(transformed_video.shape[0])) 63 | 64 | # feature extraction 65 | features = get_features(transformed_video, frame_batch_size=args.frame_batch_size, device=device) 66 | features = torch.unsqueeze(features, 0) # batch size 1 67 | 68 | # quality prediction using VSFA 69 | model = VSFA() 70 | model.load_state_dict(torch.load(args.model_path)) # 71 | model.to(device) 72 | model.eval() 73 | with torch.no_grad(): 74 | input_length = features.shape[1] * torch.ones(1, 1) 75 | outputs = model(features, input_length) 76 | y_pred = outputs[0][0].to('cpu').numpy() 77 | print("Predicted quality: {}".format(y_pred)) 78 | 79 | end = time.time() 80 | 81 | print('Time: {} s'.format(end-start)) 82 | --------------------------------------------------------------------------------