├── utils ├── __init__.py ├── __pycache__ │ ├── utils.cpython-36.pyc │ └── __init__.cpython-36.pyc └── utils.py ├── fov_selection ├── 375.png ├── SphereDist.m ├── erp2sph.m ├── demo.m ├── im2fov.m ├── getcoords3.m ├── select_points.m ├── cut_patch.m └── readme ├── datasets ├── __pycache__ │ ├── oiqa_gl.cpython-36.pyc │ └── cviqd_gl.cpython-36.pyc └── cviqd_gl.py ├── model ├── __pycache__ │ └── final_model.cpython-36.pyc └── final_model.py ├── LICENSE ├── README.md └── main.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /fov_selection/375.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/fov_selection/375.png -------------------------------------------------------------------------------- /fov_selection/SphereDist.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/fov_selection/SphereDist.m -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oiqa_gl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/datasets/__pycache__/oiqa_gl.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cviqd_gl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/datasets/__pycache__/cviqd_gl.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/final_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weizhou-geek/VGCN-PyTorch/HEAD/model/__pycache__/final_model.cpython-36.pyc -------------------------------------------------------------------------------- /fov_selection/erp2sph.m: -------------------------------------------------------------------------------- 1 | function [phi theta] = erp2sph(m,n,W,H) 2 | u = (m+0.5)/W; 3 | v = (n+0.5)/H; 4 | phi = (u-0.5)*2*pi; 5 | theta = (0.5 - v)*pi; -------------------------------------------------------------------------------- /fov_selection/demo.m: -------------------------------------------------------------------------------- 1 | clear all; 2 | clc; 3 | 4 | img_dis_rgb=imread('375.png'); 5 | [phi theta]=select_points(img_dis_rgb); 6 | spoint_radian = [phi' theta']; 7 | img_dis_rgb=imresize(img_dis_rgb,[512 1024]); 8 | im2fov(img_dis_rgb,spoint_radian,'375'); 
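9 | % Note: select_points.m picks 20 viewport centres (phi, theta) on the sphere by 10 | % detecting SURF keypoints, smoothing the keypoint map with a Gaussian filter and 11 | % greedily taking maxima that are at least pi/6 apart in spherical distance; 12 | % im2fov.m then cuts a round(M/2) x round(M/2) (here 256x256) viewport around each 13 | % centre and saves them as 375_fov1.png ... 375_fov20.png.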
-------------------------------------------------------------------------------- /fov_selection/im2fov.m: -------------------------------------------------------------------------------- 1 | function im2fov(img,spoint,i) 2 | [M,~,~]=size(img); 3 | fov_size=round(M/2); 4 | parfor k=1:length(spoint) 5 | img_fov=cut_patch(img,spoint(k,1),spoint(k,2),fov_size); 6 | imwrite(uint8(img_fov),[i,'_fov',num2str(k),'.png']); 7 | end -------------------------------------------------------------------------------- /fov_selection/getcoords3.m: -------------------------------------------------------------------------------- 1 | function coords = getcoords3(lon,lat,anglex,angley) 2 | faceSizex = 2*tan(anglex/2); 3 | faceSizey = 2*tan(angley/2); 4 | x = cos(lat)*cos(lon); 5 | y = cos(lat)*sin(lon); 6 | z = sin(lat); 7 | %thetapoint=[a;b]; 8 | point=[x;y;z]; 9 | %tangentvector = [-b;a]; 10 | 11 | vector1 = [-sin(lon);cos(lon);0]; 12 | vector2 = [sin(lat)*cos(lon);sin(lat)*sin(lon);-cos(lat)]; 13 | coords = zeros(3,4); 14 | coords(:,1)=point -vector1*faceSizex/2-vector2*faceSizey/2; 15 | coords(:,2)=point +vector1*faceSizex/2-vector2*faceSizey/2; 16 | coords(:,3)=point -vector1*faceSizex/2+vector2*faceSizey/2; 17 | coords(:,4)=point +vector1*faceSizex/2+vector2*faceSizey/2; 18 | end -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Wei Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from functools import partial 6 | import pickle 7 | 8 | 9 | def save_model(model, checkpoint, num, is_epoch=True): 10 | if not os.path.exists(checkpoint): 11 | os.makedirs(checkpoint, exist_ok=True) 12 | if is_epoch: 13 | torch.save(model.state_dict(), os.path.join(checkpoint, 'epoch_{:04d}.pth'.format(num))) 14 | else: 15 | torch.save(model.state_dict(), os.path.join(checkpoint, 'iteration_{:09d}.pth'.format(num))) 16 | 17 | def load_model(model, resume): 18 | # pickle.load = partial(pickle.load, encoding="latin1") 19 | # pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") 20 | # # model_dict = torch.load(resume, map_location=lambda storage, loc: storage, pickle_module=pickle)['model'] 21 | # model_dict = torch.load(resume, map_location=lambda storage, loc: storage, pickle_module=pickle) 22 | pretrained_dict = torch.load(resume) 23 | model_dict = model.state_dict() 24 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 25 | model_dict.update(pretrained_dict) 26 | # model.load_state_dict({k.replace('module.', ""): v for k, v in torch.load(model_dict).items()}) 27 | model.load_state_dict(model_dict) 28 | 29 | 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VGCN-PyTorch 2 | 3 | Thanks for your attention. In this repo, we provide the code for the paper [[Blind Omnidirectional Image Quality Assessment with Viewport Oriented Graph Convolutional Networks]](https://ieeexplore.ieee.org/document/9163077). 4 | 5 | ## Prerequisites 6 | + scipy==1.2.1 7 | + opencv_python==4.1.0.25 8 | + numpy==1.16.4 9 | + torchvision==0.3.0 10 | + torch==1.1.0 11 | + Pillow==6.2.0 12 | 13 | ## Install 14 | Install all the dependencies listed in Prerequisites, e.g. `pip install scipy==1.2.1 opencv_python==4.1.0.25 numpy==1.16.4 torchvision==0.3.0 torch==1.1.0 Pillow==6.2.0` 15 | 16 | ## Prepare Data 17 | + Obtain [cviqd_local_epoch.pth](https://drive.google.com/file/d/1ROT4InmAEKUisfNbMHwWpWb0nvlDhoSe/view?usp=sharing), [cviqd_global_epoch.pth](https://drive.google.com/file/d/1ggxGi2uvmL3n0BtYLC-HCrWbhna2TkFQ/view?usp=sharing), and [cviqd_model.pth](https://drive.google.com/file/d/19WJHBkogveax0b3IgpWeRco5xXgKQvFl/view?usp=sharing) 18 | + Download [database](https://drive.google.com/drive/folders/1LqQFIms_46s7uybos83-5EgMAH2r6OCy?usp=sharing) 19 | 20 | ## FoV Selection 21 | ``` 22 | matlab -nodisplay -r "run('fov_selection/demo.m'); exit" 23 | ``` 24 | 25 | ## Training 26 | ``` 27 | python main.py --root1 cviqd_local_epoch.pth --root2 cviqd_global_epoch.pth --save test 28 | ``` 29 | 30 | ## Testing 31 | ``` 32 | python main.py --resume cviqd_model.pth --skip_training 33 | ``` 34 | 35 | ## Citation 36 | If you find this work useful, please cite our paper: 
37 | 38 | ``` 39 | @article{xu2020blind, 40 | title={Blind omnidirectional image quality assessment with viewport oriented graph convolutional networks}, 41 | author={Xu, Jiahua and Zhou, Wei and Chen, Zhibo}, 42 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 43 | year={2020}, 44 | publisher={IEEE} 45 | } 46 | ``` 47 | 48 | 49 | -------------------------------------------------------------------------------- /fov_selection/select_points.m: -------------------------------------------------------------------------------- 1 | function [phi theta]=select_points(img) 2 | I = rgb2gray(img); 3 | img = double(I); 4 | 5 | %%% auto downsampling %%% 6 | [M,N]=size(I); 7 | f = max(1,round(min(M,N)/256)); 8 | if(f>1) 9 | lpf = ones(f,f); 10 | lpf = lpf/sum(lpf(:)); 11 | img = imfilter(img,lpf,'symmetric','same'); 12 | img = img(1:f:end,1:f:end); 13 | end 14 | 15 | %%% detect keypoints with padding %%% 16 | img = [img(:,end-34:end,:) img(:,:,:) img(:,1:35,:)]; 17 | points = detectSURFFeatures(uint8(img)); 18 | 19 | %%% point map %%% 20 | [m,n] = size(img); 21 | point_map = zeros(m,n); 22 | for i=1:length(points) 23 | point_map(round(points.Location(i,2)'),round(points.Location(i,1)'))=1; 24 | end 25 | 26 | %%% filter %%% 27 | sigma = 10; 28 | gausFilter = fspecial('gaussian', [71,71], sigma); 29 | gaus_point = filter2(gausFilter, point_map, 'same'); 30 | gaus_point = gaus_point(:,36:end-35); 31 | scale=255/max(max(gaus_point)); 32 | gaus_point = gaus_point*scale; 33 | 34 | %%% select 20 point %%% 35 | [H,W]=size(gaus_point); 36 | distance = SphereDist([0;0],[pi/6;0]); 37 | [M(1),idx(1)] = max(gaus_point(:)); 38 | [I_row(1), I_col(1)] = ind2sub(size(gaus_point),idx(1)); 39 | [phi(1) theta(1)] = erp2sph(I_col(1),I_row(1),W,H); 40 | gaus_point(I_row(1), I_col(1))=0; 41 | i=2; 42 | while i<21 43 | [max_num,idx_num] = max(gaus_point(:)); 44 | [idx_row, idx_col] = ind2sub(size(gaus_point),idx_num); 45 | [idx_phi, idx_theta] = erp2sph(idx_col(1),idx_row(1),W,H); 46 | gaus_point(idx_row, idx_col)=0; 47 | dist = min(SphereDist([idx_phi;idx_theta],[phi;theta])); 48 | % dist = min(sqrt((idx_row-I_row).^2 + (idx_col-I_col).^2)); 49 | if dist > distance 50 | I_row(i) = idx_row; 51 | I_col(i) = idx_col; 52 | phi(i) = idx_phi; 53 | theta(i) = idx_theta; 54 | i = i + 1; 55 | end 56 | end 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /fov_selection/cut_patch.m: -------------------------------------------------------------------------------- 1 | function img_fov=cut_patch(a,longitude,latitude,fovsize) 2 | 3 | [sizeiny,sizeinx,pixelz]=size(a); 4 | sizeoutx=fovsize; 5 | sizeouty=fovsize; 6 | 7 | anglex=(180*fovsize/sizeiny)*pi/180; 8 | angley=(180*fovsize/sizeiny)*pi/180; %angley=pi-angley',angley'=2*atan(sizeinx/sizeiny*tan(anglex/2)) 9 | faceSizex = 2*tan(anglex/2); 10 | faceSizey = 2*tan(angley/2); 11 | img_fov=zeros(sizeouty,sizeoutx,3); 12 | 13 | 14 | lat=latitude; 15 | lon=longitude; 16 | coords = getcoords3(lon,lat,anglex,angley); 17 | for ii=1:sizeoutx 18 | for jj=1:sizeouty 19 | c = 1.0 * ii / sizeoutx; 20 | d = 1.0 * jj / sizeouty; 21 | 22 | x = (1-c)*(1-d)*coords(1,1)+c*(1-d)*coords(1,2)+(1-c)*d*coords(1,3)+c*d*coords(1,4); 23 | y = (1-c)*(1-d)*coords(2,1)+c*(1-d)*coords(2,2)+(1-c)*d*coords(2,3)+c*d*coords(2,4); 24 | z = (1-c)*(1-d)*coords(3,1)+c*(1-d)*coords(3,2)+(1-c)*d*coords(3,3)+c*d*coords(3,4); 25 | 26 | r = sqrt(x^2+y^2+z^2); 27 | 28 | theta=asin(z/r); 29 | if(x<0&&y<=0) 30 | 
phi=atan(y/x)-pi; 31 | elseif(x<0&&y>0) 32 | phi=atan(y/x)+pi; 33 | else 34 | phi=atan(y/x); 35 | end 36 | theta=(pi/2-theta)*sizeiny/pi; 37 | phi=(phi+pi)*sizeinx/2/pi; 38 | thetaf=floor(theta); 39 | phif=floor(phi); 40 | p=theta-thetaf; 41 | q=phi-phif; 42 | if thetaf==0 43 | thetaf=1; 44 | p=0; 45 | end 46 | if thetaf>=sizeiny 47 | thetaf=sizeiny; 48 | img_fov(jj,ii,:)=(1-q)*a(thetaf,mod(phif-1,sizeinx)+1,:)+q*a(thetaf,mod(phif,sizeinx)+1,:); 49 | else 50 | img_fov(jj,ii,:)=(1-p)*(1-q)*a(thetaf,mod(phif-1,sizeinx)+1,:)+(1-p)*q*a(thetaf,mod(phif,sizeinx)+1,:)+p*(1-q)*a(thetaf+1,mod(phif-1,sizeinx)+1,:)+p*q*a(thetaf+1,mod(phif,sizeinx)+1,:); 51 | end 52 | end 53 | end 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /fov_selection/readme: -------------------------------------------------------------------------------- 1 | VGCN Software release. 2 | ======================================================================= 3 | COPYRIGHT NOTICE STARTS WITH THIS LINE------------ Copyright (c) 2021 University of Science and Technology of China All rights reserved. 4 | 5 | Permission is hereby granted, without written agreement and without license or royalty fees, to use, copy, modify, and distribute this code (the source files) and its documentation for any purpose, provided that the copyright notice in its entirety appear in all copies of this code, and the original source of this code, Immersive Media Computing Lab (IMCL) at University of Science and Technology of China (USTC), is acknowledged in any publication that reports research using this code. The research is to be cited in the bibliography as: 6 | 7 | 1. Jiahua Xu, Wei Zhou, Zhibo Chen*, "Blind Omnidirectional Image Quality Assessment with Viewport Oriented Graph Convolutional Networks." 8 | 9 | IN NO EVENT SHALL UNIVERSITY OF SCIENCE AND TECHNOLOGY OF CHINA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF THIS DATABASE AND ITS DOCUMENTATION, EVEN IF UNIVERSITY OF SCIENCE AND TECHNOLOGY OF CHINA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 11 | UNIVERSITY OF SCIENCE AND TECHNOLOGY OF CHINA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE DATABASE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND UNIVERSITY OF SCIENCE AND TECHNOLOGY OF CHINA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 12 | 13 | -----------COPYRIGHT NOTICE ENDS WITH THIS LINE------------% 14 | ======================================================================= 15 | 16 | Author : Jiahua Xu Version : 1.0 17 | 18 | Kindly report any suggestions or corrections to xujiahua@mail.ustc.edu.cn 19 | 20 | ======================================================================= 21 | 22 | This is a demonstration of the viewport selection stage of the Viewport Oriented Graph Convolutional Network (VGCN). The algorithm is described in: 23 | 24 | Jiahua Xu, Wei Zhou, Zhibo Chen*, "Blind Omnidirectional Image Quality Assessment with Viewport Oriented Graph Convolutional Networks." Early Access, IEEE Trans. on Circuits and Systems for Video Technology, 2020. 25 | 26 | You can change this program as you like and use it anywhere, but please refer to its original source (cite our paper and our web page at http://staff.ustc.edu.cn/~chenzhibo/resources.html, 2021). 
27 | 28 | ======================================================================== 29 | 30 | Run demo.m 31 | -------------------------------------------------------------------------------- /datasets/cviqd_gl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import six 3 | import torch 4 | import random 5 | import numbers 6 | import numpy as np 7 | import cv2 8 | import math 9 | import scipy.io as scio 10 | 11 | import torch.utils.data as data 12 | from torchvision import transforms 13 | import torchvision.transforms.functional as tF 14 | 15 | import sys 16 | if sys.version_info[0] == 2: 17 | import cPickle as pickle 18 | else: 19 | import pickle 20 | 21 | 22 | def get_dataset(is_training): 23 | img_path_list = [] 24 | img_name_list = [] 25 | datasets_list = [] 26 | sets_path = get_setspath(is_training) 27 | print(sets_path) 28 | labels_path = get_labelspath(is_training) 29 | wholeimg_path = get_imgpath(is_training) 30 | transform = get_transform() 31 | for set_path in sets_path: 32 | subset_names = os.listdir(set_path) 33 | for subset_name in subset_names: 34 | subset_path = os.path.join(set_path, subset_name) 35 | img_name_list.append(subset_name) 36 | img_path_list.append(subset_path) 37 | 38 | datasets_list.append( 39 | ImgDataset( 40 | img_path=img_path_list, 41 | img_name=img_name_list, 42 | transform=transform, 43 | is_training=is_training, 44 | label_path=labels_path[0], 45 | wholeimg_path=wholeimg_path[0] 46 | ) 47 | ) 48 | return data.ConcatDataset(datasets_list) 49 | 50 | 51 | def get_setspath(is_training): 52 | sets_root = './database/' 53 | if is_training: 54 | sets = [ 55 | 'cviqd_resize_imgtrain' 56 | ] 57 | else: 58 | sets = [ 59 | 'cviqd_resize_imgtest' 60 | ] 61 | return [os.path.join(sets_root, set) for set in sets] 62 | 63 | def get_imgpath(is_training): 64 | sets_root = './database/' 65 | if is_training: 66 | sets = [ 67 | 'cviqd_all_imgtrain' 68 | ] 69 | else: 70 | sets = [ 71 | 'cviqd_all_imgtest' 72 | ] 73 | return [os.path.join(sets_root, set) for set in sets] 74 | 75 | def get_labelspath(is_training): 76 | sets_root = './database/' 77 | if is_training: 78 | sets = [ 79 | 'cviqd_fovall_label' 80 | ] 81 | else: 82 | sets = [ 83 | 'cviqd_fovall_label' 84 | ] 85 | return [os.path.join(sets_root, set) for set in sets] 86 | 87 | 88 | def get_transform(): 89 | return transforms.Compose([ 90 | transforms.ToTensor() 91 | ]) 92 | 93 | 94 | class ImgDataset(data.Dataset): 95 | def __init__(self, img_path, img_name, transform, is_training, label_path, wholeimg_path, shuffle=False): 96 | self.img_path = img_path 97 | self.img_name = img_name 98 | self.nSamples = len(self.img_path) 99 | self.indices = range(self.nSamples) 100 | if shuffle: 101 | random.shuffle(self.indices) 102 | self.transform = transform 103 | self.is_training = is_training 104 | self.label_path = label_path 105 | self.wholeimg_path = wholeimg_path 106 | 107 | def __getitem__(self, index): 108 | imgpath = self.img_path[index] 109 | imagename = self.img_name[index] 110 | img_group = [] 111 | sub_names = os.listdir(imgpath) 112 | for sub_name in sub_names: 113 | subimg_path = os.path.join(imgpath, sub_name) 114 | img = cv2.imread(subimg_path) 115 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 116 | img = np.transpose(img, (2, 0, 1)) 117 | img_group.append(img) 118 | img_group = np.array(img_group) 119 | 120 | labelname = imagename + '.mat' 121 | labelname = os.path.join(self.label_path, labelname) 122 | label_content = scio.loadmat(labelname) 123 | label = 
label_content['score'] 124 | label = label[0] 125 | 126 | wholeimgname = imagename + '.png' 127 | wimgname = os.path.join(self.wholeimg_path, wholeimgname) 128 | wimg = cv2.imread(wimgname) 129 | wimg = cv2.resize(wimg, (512, 1024), interpolation=cv2.INTER_CUBIC) 130 | wimg = cv2.cvtColor(wimg, cv2.COLOR_BGR2RGB) 131 | wimg = np.transpose(wimg, (2, 0, 1)) 132 | A = label_content['A'] 133 | 134 | data = torch.from_numpy(img_group).float() 135 | label = torch.from_numpy(label).float() 136 | wimg = torch.from_numpy(wimg).float() 137 | A = torch.from_numpy(A).float() 138 | 139 | return data, label, imagename, A, wimg 140 | 141 | def __len__(self): 142 | return self.nSamples 143 | 144 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | import math 6 | import numpy as np 7 | import cv2 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as LS 10 | from torch.autograd import Variable 11 | from torchvision import models 12 | import scipy.io as scio 13 | from scipy import stats 14 | import torch.nn as nn 15 | from torchvision import models 16 | import random 17 | 18 | import utils 19 | from datasets.cviqd_gl import get_dataset 20 | from model.final_model import VGCN 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES'] = "7" 23 | # Training settings 24 | parser = argparse.ArgumentParser(description='VR Image Quality Assessment') 25 | parser.add_argument('--start_epoch', type=int, default=1) 26 | parser.add_argument('--total_epochs', type=int, default=20) 27 | parser.add_argument('--total_iterations', type=int, default=10000) 28 | parser.add_argument('--batch_size', '-b', type=int, default=12, help="Batch size") 29 | parser.add_argument('--lr', type=float, default=1e-2, metavar=' LR', help='learning rate (default: 0.01)') 30 | parser.add_argument('--number_workers', '-nw', '--num_workers', type=int, default=4) 31 | parser.add_argument('--save', '-s', default='work', type=str, help='directory for saving') 32 | parser.add_argument('--skip_training', default=False, action='store_true') 33 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 34 | parser.add_argument('--root1', default='', type=str, metavar='PATH', help='path to pretrained local branch') 35 | parser.add_argument('--root2', default='', type=str, metavar='PATH', help='path to pretrained global branch') 36 | 37 | main_dir = os.path.dirname(os.path.realpath(__file__)) 38 | os.chdir(main_dir) 39 | 40 | args = parser.parse_args() 41 | 42 | # seed = [random.randint(0, 10000) for _ in range(4)] 43 | seed = [7021, 9042, 9042, 8264] 44 | torch.manual_seed(seed[0]) 45 | torch.cuda.manual_seed_all(seed[1]) 46 | np.random.seed(seed[2]) 47 | random.seed(seed[3]) 48 | # torch.backends.cudnn.benchmark = False 49 | # torch.backends.cudnn.deterministic = True 50 | # print(seed) 51 | 52 | 53 | kwargs = {'num_workers': args.number_workers} 54 | if not args.skip_training: 55 | train_set = get_dataset(is_training=True) 56 | train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 57 | 58 | test_set = get_dataset(is_training=False) 59 | test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, **kwargs) 60 | 61 | model = VGCN(root1=args.root1, root2=args.root2).cuda() 62 | 63 | 64 | OIQA_params = list(map(id, 
model.OIQA_branch.parameters())) 65 | DBCNN_params = list(map(id, model.DBCNN_branch.parameters())) 66 | base_params = filter(lambda p: id(p) not in OIQA_params + DBCNN_params, model.parameters()) 67 | optimizer = optim.Adam([ 68 | {'params': base_params}, 69 | {'params': model.OIQA_branch.parameters(), 'lr': 1e-5}, 70 | {'params': model.DBCNN_branch.parameters(), 'lr': 1e-5}], lr=args.lr) 71 | 72 | # scheduler = LS.MultiStepLR(optimizer, milestones=[10, 30, 60], gamma=0.1) 73 | 74 | 75 | def train(epoch, iteration): 76 | model.train() 77 | # scheduler.step() 78 | end = time.time() 79 | log = [0 for _ in range(1)] 80 | for batch_idx, batch in enumerate(train_loader): 81 | data, label, _, A, wimg = batch 82 | data = Variable(data.cuda()) 83 | label = Variable(label.cuda()) 84 | A = Variable(A.cuda()) 85 | wimg = Variable(wimg.cuda()) 86 | optimizer.zero_grad() 87 | _, _, batch_info = model(data, wimg, label, A, requires_loss=True) 88 | batch_info.backward() 89 | optimizer.step() 90 | # print(batch_info) 91 | 92 | log = [log[i] + batch_info.item() * len(data) for i in range(1)] 93 | iteration += 1 94 | 95 | log = [log[i] / len(train_loader.dataset) for i in range(1)] 96 | epoch_time = time.time() - end 97 | end = time.time() 98 | print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, log[0])) 99 | print('LogTime: {:.4f}s'.format(epoch_time)) 100 | return log 101 | 102 | 103 | def eval(): 104 | 105 | model.eval() 106 | log = 0 107 | score_list = [] 108 | label_list = [] 109 | name_list = [] 110 | 111 | for batch_idx, batch in enumerate(test_loader): 112 | data, label, imgname, A, wimg = batch 113 | data = Variable(data.cuda()) 114 | label = Variable(label.cuda()) 115 | A = Variable(A.cuda()) 116 | wimg = Variable(wimg.cuda()) 117 | 118 | score, label = model(data, wimg, label, A, requires_loss=False) 119 | 120 | score = score.cpu().detach().numpy() 121 | label = label.cpu().detach().numpy() 122 | res = (score - label)*(score - label) 123 | score_list.append(score) 124 | label_list.append(label) 125 | name_list.append(imgname[0]) 126 | 127 | ## release memory 128 | torch.cuda.empty_cache() 129 | 130 | log += res 131 | 132 | log = log / len(test_loader) 133 | 134 | print('Average LOSS: %.2f' % (log)) 135 | score_list = np.reshape(np.asarray(score_list), (-1,)) 136 | label_list = np.reshape(np.asarray(label_list), (-1,)) 137 | name_list = np.reshape(np.asarray(name_list), (-1,)) 138 | scio.savemat('cviqd_VGCN.mat', {'score': score_list, 'label': label_list, 'name': name_list}) 139 | srocc = stats.spearmanr(label_list, score_list)[0] 140 | plcc = stats.pearsonr(label_list, score_list)[0] 141 | print('SROCC: %.4f, PLCC: %.4f\n' % (srocc, plcc)) 142 | return srocc, plcc 143 | 144 | 145 | if not args.skip_training: 146 | if args.resume: 147 | utils.load_model(model, args.resume) 148 | print('Train Load pre-trained model!') 149 | best = 0 150 | for epoch in range(args.start_epoch, args.total_epochs+1): 151 | iteration = (epoch-1) * len(train_loader) + 1 152 | log = train(epoch, iteration) 153 | log2 = eval() 154 | 155 | srocc = log2[0] 156 | plcc = log2[1] 157 | current_cc = srocc + plcc 158 | if current_cc > best: 159 | best = current_cc 160 | checkpoint = os.path.join(args.save, 'checkpoint') 161 | utils.save_model(model, checkpoint, epoch, is_epoch=True) 162 | 163 | else: 164 | print('Test Load pre-trained model!') 165 | utils.load_model(model, args.resume) 166 | eval() 167 | -------------------------------------------------------------------------------- /model/final_model.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from collections import OrderedDict 7 | from torch.nn import Parameter 8 | import math 9 | from torchvision import models 10 | 11 | class GraphConvolution(nn.Module): 12 | """ 13 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 14 | """ 15 | 16 | def __init__(self, in_features, out_features, bias=False): 17 | super(GraphConvolution, self).__init__() 18 | self.in_features = in_features 19 | self.out_features = out_features 20 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 21 | if bias: 22 | self.bias = Parameter(torch.Tensor(1, 1, out_features)) 23 | else: 24 | self.register_parameter('bias', None) 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | nn.init.xavier_normal_(self.weight.data) 29 | if self.bias is not None: 30 | nn.init.constant_(self.bias.data, 0.1) 31 | 32 | def forward(self, input, adj): 33 | support = torch.matmul(input, self.weight) 34 | output = torch.matmul(adj, support) 35 | if self.bias is not None: 36 | return output + self.bias 37 | else: 38 | return output 39 | 40 | def __repr__(self): 41 | return self.__class__.__name__ + ' (' \ 42 | + str(self.in_features) + ' -> ' \ 43 | + str(self.out_features) + ')' 44 | 45 | class GCNNet(nn.Module): 46 | def __init__(self): 47 | super(GCNNet, self).__init__() 48 | 49 | self.gc1 = GraphConvolution(512, 256) 50 | self.bn1 = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True) 51 | self.gc2 = GraphConvolution(256, 128) 52 | self.bn2 = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True) 53 | self.gc3 = GraphConvolution(128, 64) 54 | self.bn3 = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True) 55 | self.gc4 = GraphConvolution(64, 32) 56 | self.bn4 = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True) 57 | self.gc5 = GraphConvolution(32, 1) 58 | self.relu = nn.Softplus() 59 | 60 | def para_init(self): 61 | for m in self.modules(): 62 | if isinstance(m, nn.BatchNorm1d): 63 | nn.init.constant_(m.weight, 1) 64 | nn.init.constant_(m.bias, 0) 65 | 66 | def norm_adj(self, matrix): 67 | D = torch.diag_embed(matrix.sum(2)) 68 | D = D ** 0.5 69 | D = D.inverse() 70 | # D(-1/2) * A * D(-1/2) 71 | normal = D.bmm(matrix).bmm(D) 72 | return normal.detach() 73 | 74 | def forward(self, feature, A): 75 | adj = self.norm_adj(A) 76 | gc1 = self.gc1(feature, adj) 77 | gc1 = self.bn1(gc1) 78 | gc1 = self.relu(gc1) 79 | gc2 = self.gc2(gc1, adj) 80 | gc2 = self.bn2(gc2) 81 | gc2 = self.relu(gc2) 82 | gc3 = self.gc3(gc2, adj) 83 | gc3 = self.bn3(gc3) 84 | gc3 = self.relu(gc3) 85 | gc4 = self.gc4(gc3, adj) 86 | gc4 = self.bn4(gc4) 87 | gc4 = self.relu(gc4) 88 | gc5 = self.gc5(gc4, adj) 89 | gc5 = self.relu(gc5) 90 | return gc5 91 | 92 | class OIQANet(nn.Module): 93 | def __init__(self, model): 94 | super(OIQANet, self).__init__() 95 | 96 | self.resnet = nn.Sequential(*list(model.children())[:-2]) 97 | self.maxpool = nn.MaxPool2d(8) 98 | self.GCN = GCNNet() 99 | self.fc = nn.Linear(20, 1) 100 | 101 | def para_init(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 104 | nn.init.xavier_normal_(m.weight.data) 105 | if m.bias is not None: 106 | m.bias.data.zero_() 107 | 108 | def loss_build(self, x_hat, x): 109 | distortion = F.mse_loss(x_hat, x, size_average=True) 110 | return distortion 111 | 112 | def 
forward(self, x, label, A, requires_loss): 113 | batch_size = x.size(0) 114 | y = x.view(-1, 3, 256, 256) 115 | all_feature = self.resnet(y) 116 | all_feature = self.maxpool(all_feature) 117 | feature = all_feature.view(batch_size, 20, -1, 1, 1) 118 | feature = feature.squeeze(3) 119 | feature = feature.squeeze(3) 120 | 121 | gc5 = self.GCN(feature, A) 122 | fc_in = gc5.view(gc5.size()[0], -1) 123 | score = torch.mean(fc_in, dim=1).unsqueeze(1) 124 | 125 | if requires_loss: 126 | return score, label, self.loss_build(score, label) 127 | else: 128 | return score 129 | 130 | 131 | def weight_init(net): 132 | for m in net.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | nn.init.kaiming_normal_(m.weight.data,nonlinearity='relu') 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.Linear): 137 | nn.init.kaiming_normal_(m.weight.data,nonlinearity='relu') 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | 143 | class SCNN(nn.Module): 144 | 145 | def __init__(self): 146 | """Declare all needed layers.""" 147 | super(SCNN, self).__init__() 148 | 149 | # Linear classifier. 150 | 151 | self.num_class = 39 152 | 153 | self.features = nn.Sequential(nn.Conv2d(3,48,3,1,1),nn.BatchNorm2d(48),nn.ReLU(inplace=True), 154 | nn.Conv2d(48,48,3,2,1),nn.BatchNorm2d(48),nn.ReLU(inplace=True), 155 | nn.Conv2d(48,64,3,1,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True), 156 | nn.Conv2d(64,64,3,2,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True), 157 | nn.Conv2d(64,64,3,1,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True), 158 | nn.Conv2d(64,64,3,2,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True), 159 | nn.Conv2d(64,128,3,1,1),nn.BatchNorm2d(128),nn.ReLU(inplace=True), 160 | nn.Conv2d(128,128,3,1,1),nn.BatchNorm2d(128),nn.ReLU(inplace=True), 161 | nn.Conv2d(128,128,3,2,1),nn.BatchNorm2d(128),nn.ReLU(inplace=True)) 162 | weight_init(self.features) 163 | self.pooling = nn.AvgPool2d(14,1) 164 | self.projection = nn.Sequential(nn.Conv2d(128,256,1,1,0), nn.BatchNorm2d(256), nn.ReLU(inplace=True), 165 | nn.Conv2d(256,256,1,1,0), nn.BatchNorm2d(256), nn.ReLU(inplace=True)) 166 | weight_init(self.projection) 167 | self.classifier = nn.Linear(256,self.num_class) 168 | weight_init(self.classifier) 169 | 170 | def forward(self, X): 171 | # return X 172 | N = X.size()[0] 173 | assert X.size() == (N, 3, 224, 224) 174 | X = self.features(X) 175 | assert X.size() == (N, 128, 14, 14) 176 | X = self.pooling(X) 177 | assert X.size() == (N, 128, 1, 1) 178 | X = self.projection(X) 179 | X = X.view(X.size(0), -1) 180 | X = self.classifier(X) 181 | assert X.size() == (N, self.num_class) 182 | return X 183 | 184 | class DBCNN(torch.nn.Module): 185 | 186 | def __init__(self, options): 187 | """Declare all needed layers.""" 188 | nn.Module.__init__(self) 189 | # Convolution and pooling layers of VGG-16. 190 | # self.features1 = torchvision.models.vgg16(pretrained=True).features 191 | self.features1 = models.vgg16(pretrained=False).features 192 | self.features1 = nn.Sequential(*list(self.features1.children())[:-1]) 193 | scnn = SCNN() 194 | # scnn = torch.nn.DataParallel(scnn).cuda() 195 | # scnn.load_state_dict({k.replace('module.', ""): v for k, v in torch.load(scnn_root).items()}) 196 | # scnn.load_state_dict(torch.load(scnn_root)) 197 | # print('load scnn model!') 198 | self.features2 = scnn.features 199 | 200 | # Linear classifier. 201 | self.fc = torch.nn.Linear(512 * 128, 1) 202 | 203 | if options['fc'] == True: 204 | # Freeze all previous layers. 
205 | for param in self.features1.parameters(): 206 | param.requires_grad = False 207 | for param in self.features2.parameters(): 208 | param.requires_grad = False 209 | # Initialize the fc layers. 210 | nn.init.kaiming_normal_(self.fc.weight.data) 211 | if self.fc.bias is not None: 212 | nn.init.constant_(self.fc.bias.data, val=0) 213 | 214 | def loss_build(self, x_hat, x): 215 | distortion = F.mse_loss(x_hat, x, size_average=True) 216 | return distortion 217 | 218 | def forward(self, X, label, requires_loss): 219 | """Forward pass of the network. 220 | """ 221 | N = X.size()[0] 222 | X1 = self.features1(X) 223 | H = X1.size()[2] 224 | W = X1.size()[3] 225 | assert X1.size()[1] == 512 226 | X2 = self.features2(X) 227 | H2 = X2.size()[2] 228 | W2 = X2.size()[3] 229 | assert X2.size()[1] == 128 230 | 231 | if (H != H2) | (W != W2): 232 | X2 = F.upsample_bilinear(X2, (H, W)) 233 | 234 | X1 = X1.view(N, 512, H * W) 235 | X2 = X2.view(N, 128, H * W) 236 | X = torch.bmm(X1, torch.transpose(X2, 1, 2)) / (H * W) # Bilinear 237 | assert X.size() == (N, 512, 128) 238 | X = X.view(N, 512 * 128) 239 | X = torch.sqrt(X + 1e-8) 240 | X = torch.nn.functional.normalize(X) 241 | X = self.fc(X) 242 | assert X.size() == (N, 1) 243 | if requires_loss: 244 | return X, label, self.loss_build(X, label) 245 | else: 246 | return X 247 | 248 | class VGCN(torch.nn.Module): 249 | 250 | def __init__(self, root1, root2): 251 | super(VGCN, self).__init__() 252 | 253 | res_net = models.resnet18(pretrained=False) 254 | self.OIQA_branch = OIQANet(res_net) 255 | if root1: 256 | pretrained_dict1 = torch.load(root1) 257 | oiqa_dict = self.OIQA_branch.state_dict() 258 | pretrained_dict1 = {k: v for k, v in pretrained_dict1.items() if k in oiqa_dict} 259 | oiqa_dict.update(pretrained_dict1) 260 | self.OIQA_branch.load_state_dict(oiqa_dict) 261 | print('OIQA_branch model load!') 262 | 263 | options = {'fc': []} 264 | options['fc'] = True 265 | self.DBCNN_branch = DBCNN(options=options) 266 | if root2: 267 | pretrained_dict2 = torch.load(root2) 268 | dbcnn_dict = self.DBCNN_branch.state_dict() 269 | pretrained_dict2 = {k: v for k, v in pretrained_dict2.items() if k in dbcnn_dict} 270 | dbcnn_dict.update(pretrained_dict2) 271 | self.DBCNN_branch.load_state_dict(dbcnn_dict) 272 | print('DBCNN_branch model load!') 273 | 274 | self.fc = nn.Linear(2, 1) 275 | 276 | def para_init(self): 277 | for m in self.modules(): 278 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 279 | nn.init.xavier_normal_(m.weight.data) 280 | if m.bias is not None: 281 | m.bias.data.zero_() 282 | 283 | def loss_build(self, x_hat, x): 284 | distortion = F.mse_loss(x_hat, x, size_average=True) 285 | return distortion 286 | 287 | def forward(self, fov, whole, label, A, requires_loss): 288 | 289 | score1 = self.OIQA_branch(fov, label, A, requires_loss=False) 290 | score2 = self.DBCNN_branch(whole, label, requires_loss=False) 291 | score_fuse = torch.cat((score1, score2), dim=1) 292 | score = self.fc(score_fuse) 293 | 294 | if requires_loss: 295 | return score, label, self.loss_build(score, label) 296 | else: 297 | return score, label 298 | 299 | --------------------------------------------------------------------------------
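A minimal smoke-test sketch (not part of the repository; assumes the repo root is on the Python path): it builds random tensors with the shapes produced by `datasets/cviqd_gl.py` -- 20 viewports of 3x256x256, the whole equirectangular image resized to 3x1024x512, and a 20x20 viewport adjacency matrix `A` that normally comes from the per-image `.mat` label file -- and runs one forward pass of `VGCN` from `model/final_model.py`.

```python
import torch
from model.final_model import VGCN

# Empty root1/root2 skip loading the pretrained local/global branch weights.
model = VGCN(root1='', root2='').eval()

fov   = torch.randn(1, 20, 3, 256, 256)  # 20 selected viewports per image
wimg  = torch.randn(1, 3, 1024, 512)     # whole image as resized in cviqd_gl.py
A     = torch.rand(1, 20, 20)            # placeholder viewport adjacency matrix
label = torch.zeros(1, 1)                # placeholder MOS

with torch.no_grad():
    score, _ = model(fov, wimg, label, A, requires_loss=False)
print(score.shape)  # expected: torch.Size([1, 1])
```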