├── models ├── __init__.py ├── VQ_net.py └── VP_net.py ├── utils ├── __init__.py ├── yuv.py ├── sphere.py └── proposal.py ├── dataset ├── __init__.py └── dataset_VQA_ODV.py ├── .gitmodules ├── LICENSE ├── .gitignore ├── README.md └── test.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "flownet2"] 2 | path = flownet2 3 | url = https://github.com/NVIDIA/flownet2-pytorch.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Archer タツ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # V-CNN 2 | Viewport-based CNN for visual quality assessment on 360° video. 3 | 4 | Note that this is an updated version of the approach in our [CVPR2019 paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Li_Viewport_Proposal_CNN_for_360deg_Video_Quality_Assessment_CVPR_2019_paper.html), and thus the results are further improved. 5 | There are several differences between the CVPR2019 paper and this code. 6 | 7 | The dataloader and the corresponding files for our [VQA-ODV](https://github.com/Archer-Tatsu/VQA-ODV) dataset are also provided. 8 | 9 | At least one GPU is required by FlowNet2. 10 | 11 | 12 | ## Dependencies 13 | 14 | * python3 15 | * PyTorch == 1.0.1 (CUDA 9.0 is required for compilation of FlowNet2) 16 | * s2cnn: https://github.com/jonas-koehler/s2cnn 17 | * FlowNet2: https://github.com/NVIDIA/flownet2-pytorch 18 | * numpy 19 | * scipy 20 | * scikit-image 21 | * tqdm 22 | 23 | ## Binaries 24 | 25 | The binaries, including the pre-trained models as well as the list files of VQA-ODV used in inference, can be obtained [HERE](https://www.dropbox.com/sh/zblm9bnmc3dksti/AAC2zJB45WtAh4s9psVjKDIRa?dl=0). 26 | 27 | Please put all these files under the log directory. 28 | 29 | ## Usage 30 | 31 | ``` 32 | python test.py --log_dir /path/to/log/directory --flownet_ckpt /path/to/flownet2/pre-trained/model [--batch_size 1] [--num_workers 4] [--test_start_frame 21] [--test_interval 45] 33 | ``` 34 | Note that this released version only supports a `batch_size` of 1 for inference. Set `num_workers` according to the capability of your machine. 35 | 36 | Testing on all frames of every sequence may take a long time. Therefore, frames can be skipped via `test_start_frame` and `test_interval`. 37 | The default settings test every 45th frame of each sequence, beginning with the 22nd frame (frame index 21).
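For reference, the exact frame indices that get evaluated follow directly from `dataset/dataset_VQA_ODV.py`; the snippet below is a minimal illustration (not part of the code), where `300` is just an example sequence length:

```
# 0-based frame indices tested for one sequence:
# test_start_frame, test_start_frame + test_interval, test_start_frame + 2 * test_interval, ...
def tested_frames(frame_num, test_start_frame=21, test_interval=45):
    return list(range(test_start_frame, frame_num, test_interval))

print(tested_frames(300))  # [21, 66, 111, 156, 201, 246, 291]
```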
38 | 39 | ## Reference 40 | If you find this code useful in your work, please acknowledge it appropriately and cite the paper: 41 | ``` 42 | @inproceedings{Li_2019_CVPR, 43 | author = {Li, Chen and Xu, Mai and Jiang, Lai and Zhang, Shanyi and Tao, Xiaoming}, 44 | title = {Viewport Proposal CNN for 360deg Video Quality Assessment}, 45 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 46 | pages = {10177--10186}, 47 | month = {June}, 48 | year = {2019} 49 | } 50 | ``` -------------------------------------------------------------------------------- /utils/yuv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def yuv_import(file_path, dims, num_frames=1, start_frame=0, frames=None, yuv444=False): 5 | """ 6 | Import frame images from a YUV file. 7 | :param file_path: Path of the file. 8 | :param dims: (height, width) of the frames. 9 | :param num_frames: Number of the consecutive frames to be imported. 10 | :param start_frame: Index of the frame to be started. The first frame is indexed as 0. 11 | :param frames: Indexes of the frames to be imported. Inconsecutive frames are supported. 12 | :param yuv444: Whether the YUV file is in YUV444 mode. 13 | :return: Y, U, V, all as the numpy ndarray. 14 | """ 15 | 16 | fp = open(file_path, 'rb') 17 | ratio = 3 if yuv444 else 1.5 18 | blk_size = int(np.prod(dims) * ratio) 19 | if frames is None: 20 | assert num_frames > 0 21 | fp.seek(blk_size * start_frame, 0) 22 | 23 | height, width = dims 24 | Y = [] 25 | U = [] 26 | V = [] 27 | if yuv444: 28 | height_half = height 29 | width_half = width 30 | else: 31 | height_half = height // 2 32 | width_half = width // 2 33 | 34 | if frames is not None: 35 | previous_frame = -1 36 | for frame in frames: 37 | fp.seek(blk_size * (frame - previous_frame - 1), 1) 38 | Yt = np.fromfile(fp, dtype=np.uint8, count=width * height).reshape((height, width)) 39 | Ut = np.fromfile(fp, dtype=np.uint8, count=width_half * height_half).reshape((height_half, width_half)) 40 | Vt = np.fromfile(fp, dtype=np.uint8, count=width_half * height_half).reshape((height_half, width_half)) 41 | previous_frame = frame 42 | Y = Y + [Yt] 43 | U = U + [Ut] 44 | V = V + [Vt] 45 | 46 | else: 47 | for i in range(num_frames): 48 | Yt = np.fromfile(fp, dtype=np.uint8, count=width * height).reshape((height, width)) 49 | Ut = np.fromfile(fp, dtype=np.uint8, count=width_half * height_half).reshape((height_half, width_half)) 50 | Vt = np.fromfile(fp, dtype=np.uint8, count=width_half * height_half).reshape((height_half, width_half)) 51 | Y = Y + [Yt] 52 | U = U + [Ut] 53 | V = V + [Vt] 54 | 55 | fp.close() 56 | return np.array(Y), np.array(U), np.array(V) 57 | 58 | 59 | def yuv2rgb(Y, U, V): 60 | """ 61 | Convert YUV to RGB. 
62 | """ 63 | 64 | if not Y.shape == U.shape: 65 | U = U.repeat(2, axis=1).repeat(2, axis=2).astype(np.float64) 66 | V = V.repeat(2, axis=1).repeat(2, axis=2).astype(np.float64) 67 | 68 | Y = Y.astype(np.float64) 69 | U = U.astype(np.float64) 70 | V = V.astype(np.float64) 71 | U -= 128.0 72 | V -= 128.0 73 | 74 | rr = 1.001574765442552 * Y + 0.002770649292941 * U + 1.574765442551769 * V 75 | gg = 0.999531875325065 * Y - 0.188148872370914 * U - 0.468124674935631 * V 76 | bb = 1.000000105739993 * Y + 1.855609881994441 * U + 1.057399924810358e-04 * V 77 | 78 | rr = rr.clip(0, 255).round().astype(np.uint8) 79 | gg = gg.clip(0, 255).round().astype(np.uint8) 80 | bb = bb.clip(0, 255).round().astype(np.uint8) 81 | 82 | return np.stack((rr, gg, bb), axis=1) 83 | -------------------------------------------------------------------------------- /models/VQ_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.cuda 4 | import torch.nn.functional as F 5 | from torchvision.models import densenet 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self): 10 | super(Model, self).__init__() 11 | self.leaky = 0.1 12 | 13 | self.group_layers = nn.Sequential( 14 | nn.Conv2d(6, 32, 3, stride=1, padding=1, groups=2), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(32, 64, 3, stride=1, padding=1, groups=2), 17 | nn.ReLU(inplace=True) 18 | ) 19 | densenet_layers = densenet.DenseNet(num_init_features=96, growth_rate=16, block_config=(4, 8, 16), 20 | drop_rate=0.5) 21 | densenet_layers.features[0] = nn.Conv2d(64, 96, kernel_size=7, stride=2, padding=3, bias=False) 22 | 23 | self.shared_layers = nn.Sequential( 24 | densenet_layers.features, 25 | nn.Conv2d(360, 512, 3, stride=1, padding=1, bias=False), 26 | nn.BatchNorm2d(512), 27 | nn.LeakyReLU(self.leaky, inplace=True), 28 | nn.Conv2d(512, 256, 3, stride=1, padding=1, bias=False), 29 | nn.BatchNorm2d(256), 30 | nn.LeakyReLU(self.leaky, inplace=True), 31 | ) 32 | self.score_layers = nn.Sequential( 33 | nn.Conv2d(448, 256, 3, stride=2, padding=1, bias=False), 34 | nn.BatchNorm2d(256), 35 | nn.MaxPool2d(2, 2), 36 | nn.LeakyReLU(self.leaky, inplace=True), 37 | nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False), 38 | nn.BatchNorm2d(128), 39 | nn.MaxPool2d(2, 2), 40 | nn.AdaptiveAvgPool2d(1) 41 | ) 42 | self.em_layers = nn.Sequential( 43 | nn.ConvTranspose2d(256, 128, 4, stride=2, padding=2, bias=False), 44 | nn.BatchNorm2d(128), 45 | nn.LeakyReLU(self.leaky, inplace=True), 46 | nn.ConvTranspose2d(128, 32, 4, stride=2, padding=2, bias=False), 47 | nn.BatchNorm2d(32), 48 | nn.ConvTranspose2d(32, 1, 2, stride=2, padding=1, bias=False), 49 | ) 50 | self.fc = nn.Sequential( 51 | nn.Linear(128, 16, bias=False), 52 | nn.BatchNorm1d(16), 53 | nn.Linear(16, 1) 54 | ) 55 | self.softmax = nn.Softmax(dim=-1) 56 | 57 | def forward(self, x, y): 58 | """ 59 | :param x: Impaired viewports. shape: (batch_size, channels, height, width) 60 | :param y: Viewport error map with the same shape of x. 
61 | """ 62 | 63 | x = torch.cat((x, y), dim=1) 64 | x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) 65 | 66 | batch_size = x.shape[0] 67 | x = self.group_layers(x) 68 | x = self.shared_layers(x) 69 | 70 | em = self.em_layers(x) 71 | 72 | size = em.size() 73 | em = em.view(size[0], size[1], -1) 74 | em = self.softmax(em) 75 | em = em.view(size) 76 | 77 | z = F.interpolate(y, x.shape[-2:], mode='bilinear', align_corners=False) 78 | z = z.repeat(1, 64, 1, 1) 79 | x = torch.cat((x, z), dim=1) 80 | x = self.score_layers(x) 81 | x = x.view(batch_size, -1) 82 | x = self.fc(x) 83 | return x, em 84 | -------------------------------------------------------------------------------- /utils/sphere.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from functools import lru_cache 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | @lru_cache(maxsize=1) 9 | def viewport2sph_coord(port_w, port_h, fov_x, fov_y): 10 | """ 11 | Generate the meshgrid of the viewport and transform it to sphere coordinate. 12 | :param port_w: The width of the viewport in pixel. 13 | :param port_h: The height of the viewport in pixel. 14 | :param fov_x: Horizontal FoV in degree. 15 | :param fov_y: Vertical FoV in degree. 16 | :return: Sphere 3d coordinates of points of the meshgrid. Numpy ndarray with shape of (3, N). 17 | """ 18 | 19 | u_mesh, v_mesh = np.meshgrid(range(port_w), range(port_h)) 20 | u_mesh, v_mesh = u_mesh.flatten(), v_mesh.flatten() 21 | 22 | u_mesh = u_mesh.astype(np.float64) + 0.5 23 | v_mesh = v_mesh.astype(np.float64) + 0.5 24 | 25 | fov_x_rad = math.pi * fov_x / 180 26 | fov_y_rad = math.pi * fov_y / 180 27 | fx = port_w / (2 * math.tan(fov_x_rad / 2)) 28 | fy = port_h / (2 * math.tan(fov_y_rad / 2)) 29 | 30 | K = np.asmatrix([[fx, 0, port_w / 2], [0, - fy, port_h / 2], [0, 0, 1]]) 31 | 32 | e = np.asmatrix([u_mesh, v_mesh, np.ones_like(u_mesh)]) 33 | q = K.I * e 34 | q_normed = q / np.linalg.norm(q, axis=0, keepdims=True) 35 | P = np.diag([1, 1, -1]) * q_normed 36 | return np.asarray(P) 37 | 38 | 39 | @lru_cache(maxsize=1) 40 | def cal_alignment_grid(viewport_resolution, lat, lon, P): 41 | """ 42 | Calculate the grid for viewport alignment according to the center of the viewport. 43 | :param viewport_resolution: Tuple. (height_of_viewport, width_of_viewport) 44 | :param lat: Latitude of the center of the viewport (i.e., head movement position) in degree. 1-D array 45 | :param lon: Longitude of the center of the viewport (i.e., head movement position) in degree. 1-D array 46 | :param P: Viewport meshgrid in sphere coordinate. Numpy ndarray with shape of (3, N). 47 | :return: Grid for interpolation. Tensor with shape (viewport_num, *viewport_resolution, 2).
48 | """ 49 | viewport_num = lat.shape[0] 50 | 51 | phi = lat * math.pi / 180 52 | tht = -lon * math.pi / 180 53 | 54 | # Rotation matrix 55 | R = torch.stack( 56 | (torch.stack((torch.cos(tht), torch.sin(tht) * torch.sin(phi), torch.sin(tht) * torch.cos(phi))), 57 | torch.stack((torch.zeros_like(phi), torch.cos(phi), - torch.sin(phi))), 58 | torch.stack((-torch.sin(tht), torch.cos(tht) * torch.sin(phi), torch.cos(tht) * torch.cos(phi))))) 59 | 60 | P = P.to(R) 61 | E = torch.matmul(R.permute(0, 2, 1), P) 62 | 63 | lat = 90 - torch.acos(E[1, :]) * 180 / math.pi 64 | lon = torch.atan2(E[0, :], -E[2, :]) * 180 / math.pi 65 | lat = lat.view((viewport_num, *viewport_resolution)) 66 | lon = lon.view((viewport_num, *viewport_resolution)) 67 | 68 | pix_height = -lat / 90 69 | pix_width = lon / 180 70 | grid = torch.stack((pix_width, pix_height)) 71 | grid = grid.permute(1, 2, 3, 0).to(torch.float) 72 | 73 | return grid 74 | 75 | 76 | def viewport_alignment(img, p_lat, t_lon, viewport_resolution=(600, 540)): 77 | """ 78 | Apply viewport alignment. 79 | :param img: Tensor of the frame. 80 | :param p_lat: Latitude of the center of the viewport (i.e., head movement position) in degree. 1-D array. 81 | :param t_lon: Longitude of the center of the viewport (i.e., head movement position) in degree. 1-D array. 82 | :param viewport_resolution: Tuple. (height_of_viewport, width_of_viewport). 83 | :return viewports. (viewport_num, 3, *viewport_resolution). 84 | """ 85 | viewport_num = p_lat.shape[0] 86 | port_h, port_w = viewport_resolution 87 | 88 | P = torch.tensor(viewport2sph_coord(port_w, port_h, 71, 74).astype(np.float32)) 89 | 90 | grid = cal_alignment_grid(viewport_resolution, p_lat, t_lon, P) 91 | viewport = F.grid_sample(img.expand(viewport_num, -1, -1, -1), grid) 92 | 93 | return viewport 94 | -------------------------------------------------------------------------------- /utils/proposal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import pi 3 | import torch 4 | 5 | 6 | def generate_anchors(shape): 7 | """ 8 | Generate anchors on the sphere. 9 | :param shape: Spatial shape of the feature map: (height, width). 10 | """ 11 | 12 | # Enumerate shifts in feature space 13 | stride = np.array((180, 360)) / shape 14 | shifts_lat = np.arange(0, shape[0]) * stride[0] + 0.5 * stride[0] - 90 15 | shifts_lon = np.arange(0, shape[1]) * stride[1] + 0.5 * stride[1] - 180 16 | shifts_lon, shifts_lat = np.meshgrid(shifts_lon, np.flip(shifts_lat)) 17 | 18 | # Reshape to get a list 19 | anchors = np.stack([shifts_lat, shifts_lon], axis=2).reshape([-1, 2]) 20 | 21 | return anchors.astype(np.float32) 22 | 23 | 24 | def softer_nms(hm_points, weights, threshold=13.75, top_k=256, proposal_count=20): 25 | """ 26 | Apply non-maximum suppression at test time to avoid proposing too many overlapping viewports. 27 | :param hm_points: All predicted HM points on the sphere, Shape: (N, 2). 28 | :param weights: Predicted weights corresponding to the HM points, Shape: (N, ). 29 | :param threshold: (float) The threshold for suppressing close HM in degree. 30 | :param top_k: (int) The maximum number of points to consider. 31 | :param proposal_count: (int) The maximum number of points to propose. 32 | :return The proposed HM points with weights. 
33 | """ 34 | 35 | proposed_points = [] 36 | proposed_weights = [] 37 | if hm_points.numel() == 0: 38 | return proposed_points, proposed_weights 39 | 40 | lat = hm_points[:, 0] * pi / 180 41 | lon = hm_points[:, 1] * pi / 180 42 | _, idx = weights.sort(0) # sort in ascending order 43 | idx = idx[-top_k:] # indices of the top-k largest values 44 | 45 | lat1, lat2 = torch.meshgrid((lat[idx], lat[idx])) 46 | lon1, lon2 = torch.meshgrid((lon[idx], lon[idx])) 47 | 48 | # Great-circle distance 49 | distances = torch.acos( 50 | torch.cos(lat1) * torch.cos(lat2) * torch.cos(lon1 - lon2) + torch.sin(lat1) * torch.sin(lat2)) 51 | 52 | distances[torch.eye(distances.shape[0], dtype=torch.uint8)] = 0 53 | 54 | while idx.numel() > 0: 55 | lt_mask = distances[-1].lt(threshold * pi / 180) 56 | lt_weights = weights[idx[lt_mask]] 57 | lt_points = hm_points[idx[lt_mask], :] 58 | 59 | new_point = torch.sum(lt_points * lt_weights[:, None], dim=0) / torch.sum(lt_weights) 60 | proposed_points.append(new_point) 61 | proposed_weights.append(torch.sum(lt_weights)) 62 | 63 | ge_mask = 1 - lt_mask 64 | idx = idx[ge_mask] 65 | 66 | if torch.any(ge_mask): 67 | distances = distances[ge_mask] 68 | distances = distances[:, ge_mask] 69 | else: 70 | break 71 | 72 | if idx.size(0) == 1 or len(proposed_points) == proposal_count: 73 | break 74 | 75 | return proposed_points, proposed_weights 76 | 77 | 78 | def proposal_layer(weights, offsets, proposal_count, nms_threshold, anchors, mask): 79 | """ 80 | Receives anchor weights and selects a subset to pass as proposals to the second stage. 81 | Filtering is done based on anchor weights and non-maximum suppression to remove overlaps. 82 | It also applies HM refinement offsets to anchors. 83 | :param weights: Predicted weights corresponding to all anchors. 84 | :param offsets: Predicted offsets corresponding to all anchors. 85 | :param proposal_count: (int) The maximum number of points to propose. 86 | :param nms_threshold: (float) The threshold for suppressing close HM in degree. 87 | :param anchors: All anchor points. 88 | :param mask: The mask to down sample anchors near the poles. 89 | :return The proposed HM points with normalized weights.
90 | """ 91 | 92 | # Currently only supports batch size 1 93 | weights = weights.squeeze(0) 94 | offsets = offsets.squeeze(0) 95 | 96 | weights = weights.view(-1) 97 | offsets = offsets * 180 / pi 98 | hm_points = anchors + offsets 99 | if mask is not None: 100 | weights = weights[mask] 101 | hm_points = hm_points[mask] 102 | 103 | # Fix boundary at +-180 104 | ids = hm_points[:, 1] > 180 105 | hm_points[ids, 1] = hm_points[ids, 1] - 360 106 | ids = hm_points[:, 1] < -180 107 | hm_points[ids, 1] = hm_points[ids, 1] + 360 108 | 109 | # Non-max suppression 110 | if proposal_count is not None: 111 | hm_points, weights = softer_nms(hm_points, weights, nms_threshold, proposal_count=proposal_count) 112 | hm_points = torch.stack(hm_points) 113 | weights = torch.stack(weights) 114 | if float(weights.sum()) > 0: 115 | weights /= weights.sum() 116 | 117 | return hm_points, weights 118 | -------------------------------------------------------------------------------- /models/VP_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from s2cnn import s2_near_identity_grid, S2Convolution, SO3Convolution, \ 7 | so3_near_identity_grid 8 | 9 | 10 | class Model(nn.Module): 11 | def __init__(self): 12 | super(Model, self).__init__() 13 | 14 | self.leaky_alpha = 0.1 15 | 16 | # S2 layer 17 | grid = s2_near_identity_grid(max_beta=np.pi / 64, n_alpha=4, n_beta=2) 18 | self.layer0 = nn.Sequential( 19 | S2Convolution(3, 16, 128, 64, grid), 20 | nn.GroupNorm(1, 16), 21 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 22 | ) 23 | 24 | self.flow_layer0 = nn.Sequential( 25 | S2Convolution(2, 16, 128, 64, grid), 26 | nn.GroupNorm(1, 16), 27 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 28 | ) 29 | 30 | grid = so3_near_identity_grid(max_beta=np.pi / 32, max_gamma=0, n_alpha=4, n_beta=2, n_gamma=1) 31 | self.layer1, self.flow_layer1 = ( 32 | nn.Sequential( 33 | SO3Convolution(16, 16, 64, 32, grid), 34 | nn.GroupNorm(1, 16), 35 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 36 | SO3Convolution(16, 32, 32, 32, grid), 37 | nn.GroupNorm(2, 32), 38 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 39 | ) 40 | for _ in range(2) 41 | ) 42 | 43 | grid = so3_near_identity_grid(max_beta=np.pi / 16, max_gamma=0, n_alpha=4, n_beta=2, n_gamma=1) 44 | self.layer2, self.flow_layer2 = ( 45 | nn.Sequential( 46 | SO3Convolution(32, 32, 32, 16, grid), 47 | nn.GroupNorm(2, 32), 48 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 49 | SO3Convolution(32, 64, 16, 16, grid), 50 | nn.GroupNorm(4, 64), 51 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 52 | ) 53 | for _ in range(2) 54 | ) 55 | 56 | grid = so3_near_identity_grid(max_beta=np.pi / 8, max_gamma=0, n_alpha=4, n_beta=2, n_gamma=1) 57 | self.layer3, self.flow_layer3 = ( 58 | nn.Sequential( 59 | SO3Convolution(64, 64, 16, 8, grid), 60 | nn.GroupNorm(4, 64), 61 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 62 | SO3Convolution(64, 128, 8, 8, grid), 63 | nn.GroupNorm(8, 128), 64 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 65 | ) 66 | for _ in range(2) 67 | ) 68 | 69 | grid = so3_near_identity_grid(max_beta=np.pi / 16, max_gamma=0, n_alpha=4, n_beta=2, n_gamma=1) 70 | self.layer4 = nn.Sequential( 71 | SO3Convolution(256, 128, 8, 8, grid), 72 | nn.GroupNorm(8, 128), 73 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 74 | ) 75 | 76 | self.weight_layer = nn.Sequential( 77 | nn.Conv2d(129, 1, kernel_size=1, stride=1, bias=False), 78 | ) 79 | 80 | 
self.refine_layer = nn.Sequential( 81 | nn.Conv2d(129, 2, kernel_size=1, stride=1, bias=False), 82 | ) 83 | 84 | self.motion_layer1 = nn.Sequential( 85 | nn.Conv2d(256, 32, 3, stride=2, padding=1, bias=True), 86 | nn.LeakyReLU(self.leaky_alpha, inplace=True), 87 | nn.Conv2d(32, 8, 3, stride=2, padding=1, bias=False), 88 | ) 89 | self.motion_layer2 = nn.Linear(128, 2, bias=False) 90 | self.control_layer = nn.Sequential( 91 | nn.Conv2d(128, 129, 1, 1, 0, bias=False), 92 | nn.Sigmoid() 93 | ) 94 | 95 | self.softmax = nn.Softmax(dim=-1) 96 | 97 | def forward(self, img, flow, cb): 98 | batch_size = img.shape[0] 99 | 100 | for layer in (self.layer0, self.layer1, self.layer2, self.layer3): 101 | img = layer(img) 102 | 103 | for layer in (self.flow_layer0, self.flow_layer1, self.flow_layer2, self.flow_layer3): 104 | flow = layer(flow) 105 | 106 | spatial_feat = self.layer4(torch.cat((img, flow), dim=1)) 107 | spatial_feat = spatial_feat.mean(-1) 108 | 109 | motion = torch.cat((spatial_feat, flow.mean(-1)), dim=1) 110 | 111 | motion = self.motion_layer1(motion) 112 | motion = motion.reshape(batch_size, -1) 113 | 114 | m_control = motion.detach().unsqueeze(-1).unsqueeze(-1) 115 | m_control = self.control_layer(m_control) 116 | 117 | cb = F.adaptive_avg_pool2d(cb, spatial_feat.shape[-2:]) 118 | spatial_feat = torch.cat((spatial_feat, cb), dim=1) 119 | spatial_feat = spatial_feat * m_control 120 | 121 | motion = self.motion_layer2(motion) 122 | 123 | # HM refinement. 124 | pred_offset = self.refine_layer(spatial_feat) 125 | pred_offset = pred_offset.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2) 126 | 127 | motion = F.softmax(motion, dim=1) 128 | 129 | pred_weight = self.weight_layer(spatial_feat) 130 | 131 | size = pred_weight.size() 132 | pred_weight = pred_weight.view(size[0], size[1], -1) 133 | pred_weight = self.softmax(pred_weight) 134 | pred_weight = pred_weight.view(size) 135 | 136 | return pred_weight, pred_offset, motion 137 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.cuda 5 | import torch.utils.data 6 | import logging 7 | import copy 8 | import skimage.transform 9 | import scipy.stats 10 | from tqdm import tqdm 11 | 12 | from utils.proposal import generate_anchors, proposal_layer 13 | from utils.sphere import viewport_alignment 14 | from flownet2.models import FlowNet2 15 | from dataset.dataset_VQA_ODV import DS_VQA_ODV, VQA_ODV_Transform 16 | from models import VP_net, VQ_net 17 | 18 | 19 | def main(log_dir, batch_size, num_workers, flownet_ckpt, test_start_frame, test_interval): 20 | arguments = copy.deepcopy(locals()) 21 | 22 | if not torch.cuda.is_available(): 23 | raise RuntimeError('At least 1 GPU is needed by FlowNet2.') 24 | device_main = torch.device('cuda:0') 25 | 26 | # For viewport alignment on 8K frame, more than 6 GB GPU memory is needed, 27 | # and thus it needs a different GPU device or fallback to CPU 28 | if torch.cuda.device_count() > 1: 29 | device_alignment = torch.device('cuda:1') 30 | else: 31 | device_alignment = torch.device('cpu') 32 | torch.backends.cudnn.benchmark = True 33 | 34 | logger = logging.getLogger("test") 35 | logger.handlers = [] 36 | logger.setLevel(logging.DEBUG) 37 | ch = logging.StreamHandler() 38 | ch.setLevel(logging.INFO) 39 | logger.addHandler(ch) 40 | 41 | logger.info("%s", repr(arguments)) 42 | 43 | bandwidth = 128 44 | test_set = 
DS_VQA_ODV(root=os.path.join(log_dir, "VQA_ODV"), dataset_type='test', tr_te_file='tr_te_VQA_ODV.txt', 45 | ds_list_file='VQA_ODV.txt', test_interval=test_interval, test_start_frame=test_start_frame, 46 | transform=VQA_ODV_Transform(bandwidth=bandwidth, down_resolution=(1024, 2048), to_rgb=True)) 47 | 48 | anchor_shape = (16, 16) 49 | anchors = torch.tensor(generate_anchors(np.array(anchor_shape))) 50 | 51 | # Gaussian center bias 52 | cb = np.load(os.path.join(log_dir, 'cb256.npy')).astype(np.float32)[np.newaxis, np.newaxis, ...] 53 | cb = torch.tensor(cb).to(device_main) 54 | # Mask for anchors 55 | anchor_mask = np.load(os.path.join(log_dir, 'anchor_mask.npy')).astype(np.int64) 56 | anchor_mask = torch.tensor(anchor_mask) 57 | 58 | vpnet = VP_net.Model() 59 | vpnet.to(device_main) 60 | vpnet.load_state_dict(torch.load(os.path.join(log_dir, 'vp_state.pkl'))) 61 | logger.info("Successfully loaded VP-net pre-trained model.") 62 | 63 | vqnet = VQ_net.Model() 64 | vqnet.to(device_main) 65 | vqnet.load_state_dict(torch.load(os.path.join(log_dir, 'vq_state.pkl'))) 66 | logger.info("Successfully loaded VQ-net pre-trained model.") 67 | 68 | class FlowNetParams: 69 | rgb_max = 255.0 70 | fp16 = False 71 | 72 | flownet = FlowNet2(args=FlowNetParams()) 73 | flownet.to(device_main) 74 | 75 | if isinstance(flownet_ckpt, str): 76 | flownet_ckpt = torch.load(flownet_ckpt) 77 | flownet.load_state_dict(flownet_ckpt['state_dict']) 78 | logger.info("Successfully loaded FlowNet2 pre-trained model.") 79 | flownet.eval() 80 | 81 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, 82 | shuffle=False, pin_memory=True, drop_last=False) 83 | 84 | pred = [] 85 | targets = [] 86 | 87 | vpnet.eval() 88 | vqnet.eval() 89 | 90 | for batch_idx, img_tuple in enumerate(tqdm(test_loader)): 91 | with torch.no_grad(): 92 | img_s2, img_original, img_down, img_gap_s2, gap_down, ref_original, target = img_tuple 93 | 94 | gap_down = gap_down.to(device_main) 95 | img_down = img_down.to(device_main) 96 | gap_down = gap_down.view((-1, *gap_down.shape[-3:])) 97 | img_down = img_down.view((-1, *img_down.shape[-3:])) 98 | 99 | # Optical flow 100 | flow = torch.stack((gap_down, img_down), dim=0).permute(1, 2, 0, 3, 4) 101 | flow = flownet(flow) 102 | flow = flow.cpu().numpy().transpose((2, 3, 1, 0)) 103 | flow = skimage.transform.resize(flow, (bandwidth * 2, bandwidth * 2) + flow.shape[-2:], order=1, 104 | anti_aliasing=True, mode='reflect', preserve_range=True).astype(np.float32) 105 | flow_s2 = torch.tensor(flow.transpose((3, 2, 0, 1))) 106 | flow_s2 = flow_s2.to(device_main) 107 | 108 | # VP net 109 | img_s2 = img_s2.to(device_main) 110 | img_gap_s2 = img_gap_s2.to(device_main) 111 | img_s2 = img_s2.view((-1, *img_s2.shape[-3:])) 112 | img_gap_s2 = img_gap_s2.view((-1, *img_gap_s2.shape[-3:])) 113 | 114 | vp_hm_weight, vp_hm_offset, _ = vpnet(img_s2, flow_s2, cb) 115 | 116 | # Viewport softer NMS 117 | hm_after_nms, hm_weight = proposal_layer(vp_hm_weight, vp_hm_offset, 20, 7.5, anchors.to(vp_hm_offset), 118 | mask=anchor_mask) 119 | 120 | # Viewport alignment 121 | hm_after_nms = hm_after_nms.to(device_alignment) 122 | 123 | img_original = img_original.to(device_alignment) 124 | img_original = img_original.view((-1, *img_original.shape[-3:])) 125 | img_viewport = viewport_alignment(img_original, hm_after_nms[:, 0], hm_after_nms[:, 1]) 126 | del img_original 127 | img_viewport = img_viewport.to(device_main) 128 | 129 | ref_original = ref_original.to(device_alignment) 130 | 
ref_original = ref_original.view((-1, *ref_original.shape[-3:])) 131 | ref_viewport = viewport_alignment(ref_original, hm_after_nms[:, 0], hm_after_nms[:, 1]) 132 | del ref_original 133 | ref_viewport = ref_viewport.to(device_main) 134 | 135 | # VQ net 136 | vq_score, _ = vqnet(img_viewport, ref_viewport - img_viewport) 137 | vq_score = vq_score.flatten() 138 | vq_score = (vq_score * hm_weight).sum(dim=0, keepdim=True) 139 | 140 | pred.append(float(vq_score)) 141 | 142 | target = target.mean(dim=1).reshape((-1,)) 143 | targets.append(target.numpy()) 144 | 145 | pred = np.array(pred) 146 | targets = np.concatenate(targets, 0) 147 | video_cnt = len(test_set.cum_frame_num) 148 | pred = [pred[test_set.cum_frame_num_prev[i]:test_set.cum_frame_num[i]].mean() for i in range(video_cnt)] 149 | targets = [targets[test_set.cum_frame_num_prev[i]:test_set.cum_frame_num[i]].mean() for i in range(video_cnt)] 150 | np.savetxt(os.path.join(log_dir, 'test_pred_scores.txt'), np.array(pred)) 151 | np.savetxt(os.path.join(log_dir, 'test_targets.txt'), np.array(targets)) 152 | srocc, _ = scipy.stats.spearmanr(pred, targets) 153 | 154 | logger.info("SROCC:{:.4}".format(srocc)) 155 | 156 | 157 | if __name__ == "__main__": 158 | import argparse 159 | 160 | parser = argparse.ArgumentParser() 161 | 162 | parser.add_argument("--log_dir", type=str, required=True) 163 | parser.add_argument("--flownet_ckpt", type=str, required=True) 164 | parser.add_argument("--batch_size", type=int, default=1) 165 | parser.add_argument("--num_workers", type=int, default=4) 166 | parser.add_argument("--test_start_frame", type=int, default=21) 167 | parser.add_argument("--test_interval", type=int, default=45) 168 | 169 | args = parser.parse_args() 170 | main(**args.__dict__) 171 | -------------------------------------------------------------------------------- /dataset/dataset_VQA_ODV.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import torch.utils.data 5 | import scipy.ndimage.interpolation as interp 6 | import skimage.transform 7 | import warnings 8 | from utils import yuv 9 | 10 | 11 | class DownSample: 12 | def __init__(self, down_resolution): 13 | self.down_resolution = down_resolution 14 | 15 | def __call__(self, Y, U, V): 16 | half_resolution = [i // 2 for i in self.down_resolution] 17 | Y_d = skimage.transform.resize(Y.transpose((1, 2, 0)), self.down_resolution, order=1, anti_aliasing=True, 18 | mode='reflect', preserve_range=True) 19 | U_d = skimage.transform.resize(U.transpose((1, 2, 0)), half_resolution, order=1, anti_aliasing=True, 20 | mode='reflect', preserve_range=True) 21 | V_d = skimage.transform.resize(V.transpose((1, 2, 0)), half_resolution, order=1, anti_aliasing=True, 22 | mode='reflect', preserve_range=True) 23 | 24 | return Y_d.transpose((2, 0, 1)).round().astype(np.uint8), U_d.transpose((2, 0, 1)).round().astype(np.uint8), \ 25 | V_d.transpose((2, 0, 1)).round().astype(np.uint8) 26 | 27 | def __repr__(self): 28 | return self.__class__.__name__ + '(down_resolution={0})'.format(self.down_resolution) 29 | 30 | 31 | class SampleSGrid: 32 | def __init__(self, bandwidth): 33 | self.bandwidth = bandwidth 34 | self.euler_grid, _ = self.make_sgrid(bandwidth) 35 | 36 | def __call__(self, Y, U, V, resolution): 37 | height, width = resolution 38 | height_half, width_half = height // 2, width // 2 39 | theta, phi = self.euler_grid 40 | 41 | pix_height = theta[:, 0] / np.pi * height 42 | pix_width = phi[0, :] / (np.pi * 2) * width
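# theta/phi above come from the SOFT sampling grid (colatitude in [0, pi], longitude in [0, 2*pi));
# they are scaled here to fractional row/column coordinates of the full-resolution luma plane, and
# just below to those of the half-resolution chroma planes, before being fed to map_coordinates().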
43 | pix_width, pix_height = np.meshgrid(pix_width, pix_height) 44 | pix_height_half = theta[:, 0] / np.pi * height_half 45 | pix_width_half = phi[0, :] / (np.pi * 2) * width_half 46 | pix_width_half, pix_height_half = np.meshgrid(pix_width_half, pix_height_half) 47 | 48 | Y_im = interp.map_coordinates(Y[0, ...], [pix_height, pix_width], order=1) 49 | U_im = interp.map_coordinates(U[0, ...], [pix_height_half, pix_width_half], order=1) 50 | V_im = interp.map_coordinates(V[0, ...], [pix_height_half, pix_width_half], order=1) 51 | 52 | return Y_im[np.newaxis, ...], U_im[np.newaxis, ...], V_im[np.newaxis, ...] 53 | 54 | @staticmethod 55 | def make_sgrid(b): 56 | from lie_learn.spaces import S2 57 | 58 | theta, phi = S2.meshgrid(b=b, grid_type='SOFT') 59 | sgrid = S2.change_coordinates(np.c_[theta[..., None], phi[..., None]], p_from='S', p_to='C') 60 | sgrid = sgrid.reshape((-1, 3)) 61 | 62 | return (theta, phi), sgrid 63 | 64 | def __repr__(self): 65 | return self.__class__.__name__ + '(bandwidth={0})'.format(self.bandwidth) 66 | 67 | 68 | class VQA_ODV_Transform: 69 | def __init__(self, bandwidth, down_resolution, to_rgb=True): 70 | self.to_rgb = to_rgb 71 | self.sgrid_transform = SampleSGrid(bandwidth) 72 | self.down_transform = DownSample(down_resolution) 73 | 74 | def __call__(self, file_path, resolution, frame_index): 75 | ori = yuv.yuv_import(file_path, resolution, 1, frame_index) 76 | im = self.sgrid_transform(*ori, resolution) 77 | down = self.down_transform(*ori) 78 | if self.to_rgb: 79 | im = yuv.yuv2rgb(*im) 80 | ori = yuv.yuv2rgb(*ori) 81 | down = yuv.yuv2rgb(*down) 82 | return im, ori, down 83 | 84 | 85 | class DS_VQA_ODV(torch.utils.data.Dataset): 86 | 87 | def __init__(self, root, dataset_type, ds_list_file, transform, tr_te_file=None, flow_gap=2, test_start_frame=21, 88 | test_interval=45): 89 | """ 90 | VQA-ODV initialization. 91 | :param root: Directory of the dataset information files. 92 | :param dataset_type: Training set or test set. NOTE: ONLY TEST SET IS SUPPORTED AT PRESENT. 93 | :param ds_list_file: Name of the list file for all impaired sequences. 94 | :param transform: The class for transformation. Should be an instance of VQA_ODV_Transform. 95 | :param tr_te_file: Name of the file for splitting training and test scenes. 96 | :param flow_gap: Frame gap for optical flow extraction. 97 | :param test_start_frame: The start frame for each sequence in test (for dropping frames and saving time). 98 | :param test_interval: The interval for each sequence in test (for dropping frames and saving time). 99 | """ 100 | 101 | self.type = dataset_type 102 | if self.type not in ("train", "test"): 103 | raise ValueError("Invalid dataset") 104 | 105 | self.logger = logging.getLogger("{}.dataset".format(self.type)) 106 | self.root = os.path.expanduser(root) 107 | self.ds_list_file = ds_list_file 108 | 109 | assert isinstance(transform, VQA_ODV_Transform) 110 | self.transform = transform 111 | self.flow_gap = flow_gap 112 | self.test_start_frame = test_start_frame 113 | self.test_interval = test_interval 114 | if self.test_start_frame < self.flow_gap: 115 | warnings.warn( 116 | "The value of test_start_frame should not be less than flow_gap. Set test_start_frame equal to flow_gap.", 117 | Warning) 118 | self.test_start_frame = self.flow_gap 119 | if self.test_interval < 1: 120 | warnings.warn( 121 | "The value of test_interval should not be less than 1. 
Set test_interval equal to 1.", 122 | Warning) 123 | self.test_interval = 1 124 | 125 | self.scenes = list(range(60)) 126 | self.train_scenes, self.test_scenes = self.divide_tr_te_wrt_ref(self.scenes, 127 | tr_te_file=os.path.join(self.root, tr_te_file)) 128 | 129 | if self.type == 'train': 130 | self.data_dict = self.make_video_list(self.train_scenes) 131 | else: 132 | self.data_dict = self.make_video_list(self.test_scenes) 133 | self.frame_num_list = np.array(self.data_dict['frame_num_list'], dtype=np.float32) 134 | self.frame_num_list = np.ceil((self.frame_num_list - self.test_start_frame) / self.test_interval).astype( 135 | np.int) 136 | self.cum_frame_num = np.cumsum(self.frame_num_list) 137 | self.cum_frame_num_prev = np.zeros_like(self.cum_frame_num) 138 | self.cum_frame_num_prev[1:] = self.cum_frame_num[:-1] 139 | self.scores = self.data_dict['score_list'] 140 | 141 | self.files = self.data_dict['distort_yuv_list'] 142 | self.ref_files = self.data_dict['ref_yuv_list'] 143 | 144 | def __getitem__(self, index): 145 | if self.type == 'train': 146 | raise NotImplementedError 147 | else: 148 | video_index = np.searchsorted(self.cum_frame_num_prev, index, side='right') - 1 149 | frame_index = (index - self.cum_frame_num_prev[video_index]) * self.test_interval + self.test_start_frame 150 | 151 | video_path = self.files[video_index] 152 | ref_path = self.ref_files[video_index] 153 | resolution = self.data_dict['resolution_list'][video_index] 154 | 155 | img_gap, _, gap_down = self.transform(file_path=video_path, resolution=resolution, 156 | frame_index=frame_index - self.flow_gap) 157 | img, img_original, img_down = self.transform(file_path=video_path, resolution=resolution, 158 | frame_index=frame_index) 159 | _, ref, _ = self.transform(file_path=ref_path, resolution=resolution, frame_index=frame_index) 160 | 161 | target = np.array([self.scores[video_index]]) 162 | self.logger.debug('[DATA] {}, REF:{}, FRAME:{}, SCORE:{}'.format(video_path, ref_path, frame_index, target)) 163 | 164 | return img.astype(np.float32), img_original.astype(np.float32), img_down.astype(np.float32), \ 165 | img_gap.astype(np.float32), gap_down.astype(np.float32), ref.astype(np.float32), \ 166 | target.astype(np.float32) 167 | 168 | def __len__(self): 169 | if self.type == 'train': 170 | return len(self.files) 171 | else: 172 | return self.cum_frame_num[-1] 173 | 174 | def divide_tr_te_wrt_ref(self, scenes, train_size=0.8, tr_te_file=None): 175 | """ 176 | Divide data with respect to scenes.
177 | """ 178 | tr_te_file_loaded = False 179 | if tr_te_file is not None and os.path.isfile(tr_te_file): 180 | # Load tr_te_file and divide scenes accordingly 181 | tr_te_file_loaded = True 182 | with open(tr_te_file, 'r') as f: 183 | train_scenes = f.readline().strip().split() 184 | train_scenes = [int(elem) for elem in train_scenes] 185 | test_scenes = f.readline().strip().split() 186 | test_scenes = [int(elem) for elem in test_scenes] 187 | 188 | n_train_refs = len(train_scenes) 189 | n_test_refs = len(test_scenes) 190 | train_size = (len(train_scenes) / 191 | (len(train_scenes) + len(test_scenes))) 192 | else: 193 | # Divide scenes randomly 194 | # Get the numbers of training and testing scenes 195 | n_scenes = len(scenes) 196 | n_train_refs = int(np.ceil(n_scenes * train_size)) 197 | n_test_refs = n_scenes - n_train_refs 198 | 199 | # Randomly divide scenes 200 | rand_seq = np.random.permutation(n_scenes) 201 | scenes_sh = [scenes[elem] for elem in rand_seq] 202 | train_scenes = sorted(scenes_sh[:n_train_refs]) 203 | test_scenes = sorted(scenes_sh[n_train_refs:]) 204 | 205 | # Write train-test idx list into file 206 | if tr_te_file is not None: 207 | fpath, fname = os.path.split(tr_te_file) 208 | if not os.path.isdir(fpath): 209 | os.makedirs(fpath) 210 | with open(tr_te_file, 'w') as f: 211 | for idx in range(n_train_refs): 212 | f.write('%d ' % train_scenes[idx]) 213 | f.write('\n') 214 | for idx in range(n_scenes - n_train_refs): 215 | f.write('%d ' % test_scenes[idx]) 216 | f.write('\n') 217 | 218 | self.logger.debug( 219 | ' - Refs.: training = {} / testing = {} (Ratio = {:.2f})'.format(n_train_refs, n_test_refs, 220 | train_size)) 221 | if tr_te_file_loaded: 222 | self.logger.debug(' (Loaded {})'.format(tr_te_file)) 223 | else: 224 | self.logger.debug('') 225 | 226 | return train_scenes, test_scenes 227 | 228 | def make_video_list(self, scenes, show_info=True): 229 | # Get reference / distorted video file lists: 230 | # distort_yuv_list and score_list 231 | distort_yuv_list, ref_yuv_list, ref_index_list, score_list, resolution_list, frame_num_list = [], [], [], [], [], [] 232 | with open(os.path.join(self.root, self.ds_list_file), 'r') as listFile: 233 | for line in listFile: 234 | # Each line: ref_idx <unused> ref_path dist_path DMOS width height frame_number 235 | scn_idx, _, ref, dis, score, width, height, frame_num = line.split() 236 | scn_idx = int(scn_idx) 237 | if scn_idx in scenes: 238 | distort_yuv_list.append(dis) 239 | ref_yuv_list.append(ref) 240 | ref_index_list.append(scn_idx) 241 | score_list.append(float(score)) 242 | resolution_list.append((int(height), int(width))) 243 | frame_num_list.append(int(frame_num)) 244 | 245 | score_list = np.array(score_list, dtype='float32') 246 | # DMOS -> reverse subjective scores by default 247 | score_list = 1.0 - score_list 248 | n_videos = len(distort_yuv_list) 249 | 250 | if show_info: 251 | scenes.sort() 252 | self.logger.debug(' - Scenes: {}'.format(', '.join([str(i) for i in scenes]))) 253 | self.logger.debug(' - Number of videos: {:,}'.format(n_videos)) 254 | self.logger.debug( 255 | ' - DMOS range: [{:.2f}, {:.2f}]'.format(np.min(score_list), np.max(score_list))) 256 | self.logger.debug(' (Scale reversed)') 257 | 258 | return { 259 | 'scenes': scenes, 260 | 'n_videos': n_videos, 261 | 'distort_yuv_list': distort_yuv_list, 262 | 'ref_yuv_list': ref_yuv_list, 263 | 'ref_index_list': ref_index_list, 264 | 'score_list': score_list, 265 | 'resolution_list': resolution_list, 266 | 'frame_num_list': frame_num_list 267 | } 268 |
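# Expected format of the ds_list_file (e.g. VQA_ODV.txt), inferred from the parsing in
# make_video_list() above rather than from any official documentation: each line holds eight
# whitespace-separated fields, of which only the second is ignored by this loader:
#   ref_idx  <unused>  reference_yuv_path  distorted_yuv_path  DMOS  width  height  frame_count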
--------------------------------------------------------------------------------
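As a closing note on the utilities above, the sketch below shows how `utils.yuv` and `utils.sphere.viewport_alignment` can be combined to extract viewports from a single equirectangular frame. It is a minimal, hypothetical example rather than part of the repository: the YUV path, resolution and viewing directions are placeholders, and in the real pipeline these steps are driven by `test.py` with viewport centers proposed by VP-net.

```
import numpy as np
import torch

from utils import yuv
from utils.sphere import viewport_alignment

# Placeholder YUV420 equirectangular sequence and its (height, width); replace with a real file.
yuv_path = 'example_erp.yuv'
resolution = (2048, 4096)

# Read one frame and convert it to an RGB tensor of shape (1, 3, H, W).
Y, U, V = yuv.yuv_import(yuv_path, resolution, num_frames=1, start_frame=0)
rgb = yuv.yuv2rgb(Y, U, V).astype(np.float32)
frame = torch.from_numpy(rgb)

# Two example viewing directions (latitude, longitude) in degrees.
lat = torch.tensor([0.0, 30.0])
lon = torch.tensor([0.0, -90.0])

# Extract two viewports at the default 600x540 resolution: output shape (2, 3, 600, 540).
viewports = viewport_alignment(frame, lat, lon)
print(viewports.shape)
```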