├── Figures
│   └── architecture.png
├── LICENSE
├── README.md
├── data
│   └── pipal.py
├── download.sh
├── model
│   └── deform_regressor.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── requirements.txt
├── script
│   └── extract_feature.py
├── test.py
├── test.sh
├── train.py
├── train.sh
└── utils
    ├── process_image.py
    └── util.py

/Figures/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IIGROUP/AHIQ/9cc6a4eed821bb8f16a1aae00d420b24625bc07a/Figures/architecture.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 ShanShan Lao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention Helps CNN See Better: Hybrid Image Quality Assessment Network
2 | [CVPRW 2022] Code for the Hybrid Image Quality Assessment Network
3 | 
4 | [[paper]](https://arxiv.org/abs/2204.10485) [[code](https://github.com/IIGROUP/AHIQ)]
5 | 
6 | *This is the official repository for the NTIRE 2022 Perceptual Image Quality Assessment Challenge Track 1 (Full-Reference) competition.
7 | **We won first place in the competition, and the code has now been released.***
8 | 
9 | > **Abstract:** *Image quality assessment (IQA) algorithms aim to quantify the human perception of image quality. Unfortunately, there is a performance drop when assessing distorted images generated by generative adversarial networks (GANs) with seemingly realistic textures. In this work, we conjecture that this maladaptation lies in the backbone of IQA models, where patch-level prediction methods use independent image patches as input to calculate their scores separately but lack spatial relationship modeling among image patches. Therefore, we propose an Attention-based Hybrid Image Quality Assessment Network (AHIQ) to address this challenge and achieve better performance on the GAN-based IQA task. First, we adopt a two-branch architecture, including a vision transformer (ViT) branch and a convolutional neural network (CNN) branch, for feature extraction. The hybrid architecture combines the interaction information among image patches captured by the ViT with the local texture details extracted by the CNN. To make the features from the shallow CNN layers focus more on visually salient regions, deformable convolution is applied with the help of semantic information from the ViT branch. Finally, we use a patch-wise score prediction module to obtain the final score. Experiments show that our model outperforms state-of-the-art methods on four standard IQA datasets, and AHIQ ranked first on the Full-Reference (FR) track of the NTIRE 2022 Perceptual Image Quality Assessment Challenge.*
10 | 
11 | ## Overview
12 | <img src="Figures/architecture.png" alt="AHIQ architecture">
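
In code, the figure reduces to a handful of tensor operations. The following is a minimal, self-contained sketch of the forward pass, not the repository's implementation (see `model/deform_regressor.py` and `train.py` for the real modules and feature hooks); random tensors stand in for the hooked backbone features, and the shapes assume the default ViT-B/8 backbone on 224x224 crops.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d

B = 2
vit_dis = torch.randn(B, 768 * 5, 28, 28)  # 5 hooked ViT blocks, tokens reshaped to a grid
vit_ref = torch.randn(B, 768 * 5, 28, 28)
cnn_dis = torch.randn(B, 256 * 3, 56, 56)  # 3 hooked ResNet-50 stage-1 bottlenecks
cnn_ref = torch.randn(B, 256 * 3, 56, 56)

# 1) Deformable fusion: semantics from the *reference* ViT features predict the
#    offsets that steer a deformable convolution over the shallow CNN features.
conv_offset = nn.Conv2d(768 * 5, 2 * 3 * 3, 3, 1, 1)
deform = DeformConv2d(256 * 3, 256 * 3, 3, 1, 1)

def fuse(cnn_feat, vit_feat):
    vit_up = F.interpolate(vit_feat, size=cnn_feat.shape[-2:], mode="nearest")
    out = deform(cnn_feat, conv_offset(vit_up))
    return F.interpolate(out, size=(28, 28), mode="nearest")  # stand-in for the strided convs

cnn_dis, cnn_ref = fuse(cnn_dis, vit_ref), fuse(cnn_ref, vit_ref)

# 2) Patch-wise prediction: fuse the branches, compare distorted vs. reference,
#    then pool a per-patch score map with a learned per-patch attention map.
down = nn.Conv2d(768 * 5 + 256 * 3, 256, 1)
f_dis = down(torch.cat((vit_dis, cnn_dis), 1))
f_ref = down(torch.cat((vit_ref, cnn_ref), 1))
feat = torch.cat((f_dis - f_ref, f_dis, f_ref), 1)
feat = nn.Conv2d(256 * 3, 256, 3, padding=1)(feat)    # stand-in for feat_smoothing/conv1
score_map = nn.Conv2d(256, 1, 1)(feat)                # per-patch quality score
attn_map = torch.sigmoid(nn.Conv2d(256, 1, 1)(feat))  # per-patch weight
pred = (score_map * attn_map).sum((2, 3)) / attn_map.sum((2, 3))
print(pred.shape)  # torch.Size([2, 1])
```

The weighted average at the end is the point of the design: the attention map lets the network emphasize perceptually salient patches instead of treating every patch equally.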
13 | 
14 | ## Getting Started
15 | 
16 | ### Prerequisites
17 | - Linux
18 | - NVIDIA GPU + CUDA CuDNN
19 | - Python 3.7
20 | 
21 | ### Dependencies
22 | 
23 | We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/). All dependencies for defining the environment are provided in `requirements.txt`.
24 | 
25 | ### Pretrained Models
26 | You may manually download the pretrained models from
27 | [Google Drive](https://drive.google.com/drive/folders/1-8LKOEDYt-RzmM9IDV_oW73uRBLqeRB6?usp=sharing) and put them into `checkpoints/ahiq_pipal/`, or simply use
28 | ```
29 | sh download.sh
30 | ```
31 | 
32 | ### Instruction
33 | Use `sh train.sh` or `sh test.sh` to train or test the model. You can also adjust the options in `options/` as needed.
34 | 
35 | ## Acknowledgment
36 | The code borrows heavily from the IQT implementation by [anse3832](https://github.com/anse3832/IQT), and we really appreciate it.
37 | 
38 | ## Citation
39 | If you find our work or code helpful for your research, please consider citing:
40 | ```bibtex
41 | @article{lao2022attentions,
42 |   title = {Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network},
43 |   author = {Lao, Shanshan and Gong, Yuan and Shi, Shuwei and Yang, Sidi and Wu, Tianhe and Wang, Jiahao and Xia, Weihao and Yang, Yujiu},
44 |   journal = {arXiv preprint arXiv:2204.10485},
45 |   year = {2022}
46 | }
47 | ```
48 | 
--------------------------------------------------------------------------------
/data/pipal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import cv2
5 | 
6 | 
7 | class PIPAL(torch.utils.data.Dataset):
8 |     def __init__(self, ref_path, dis_path, txt_file_name, transform, resize=False, size=None, flip=False):
9 |         super(PIPAL, self).__init__()
10 |         self.ref_path = ref_path
11 |         self.dis_path = dis_path
12 |         self.txt_file_name = txt_file_name
13 |         self.transform = transform
14 |         self.flip = flip
15 |         self.resize = resize
16 |         self.size = size
17 |         ref_files_data, dis_files_data, score_data = [], [], []
18 |         with open(self.txt_file_name, 'r') as listFile:
19 |             for line in listFile:
20 |                 dis, score = line.strip().split(',')
21 |                 # the reference image shares the first five characters of the distorted name
22 |                 ref = dis[:5] + '.bmp'
23 |                 score = float(score)
24 |                 ref_files_data.append(ref)
25 |                 dis_files_data.append(dis)
26 |                 score_data.append(score)
27 | 
28 |         # reshape score_list (1xn -> nx1)
29 |         score_data = np.array(score_data)
30 |         score_data = self.normalization(score_data)
31 |         score_data = score_data.astype('float').reshape(-1, 1)
32 | 
33 |         self.data_dict = {'r_img_list': ref_files_data, 'd_img_list': dis_files_data, 'score_list': score_data}
34 | 
35 |     def normalization(self, data):
36 |         score_range = np.max(data) - np.min(data)
37 |         return (data - np.min(data)) / score_range
38 | 
39 |     def __len__(self):
40 |         return len(self.data_dict['r_img_list'])
41 | 
42 |     def __getitem__(self, idx):
43 |         # r_img: H x W x C -> C x H x W
44 |         r_img_name = self.data_dict['r_img_list'][idx]
45 |         r_img = cv2.imread(os.path.join(self.ref_path, r_img_name), cv2.IMREAD_COLOR)
46 |         r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
47 |         if self.flip:
48 |             r_img = np.fliplr(r_img).copy()
49 |         if self.resize:
50 |             r_img = cv2.resize(r_img, self.size)
51 |         r_img = np.array(r_img).astype('float32') / 255
52 |         r_img = (r_img - 0.5) / 0.5
53 |         r_img = np.transpose(r_img, (2, 0, 1))
54 | 
55 |         d_img_name = self.data_dict['d_img_list'][idx]
56 |         d_img = cv2.imread(os.path.join(self.dis_path, d_img_name), cv2.IMREAD_COLOR)
57 |         d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
58 |         if self.flip:
59 |             d_img = np.fliplr(d_img).copy()
60 |         if self.resize:
61 |             d_img = cv2.resize(d_img, self.size)
62 |         d_img = np.array(d_img).astype('float32') / 255
63 |         d_img = (d_img - 0.5) / 0.5
64 |         d_img = np.transpose(d_img, (2, 0, 1))
65 | 
66 |         score = self.data_dict['score_list'][idx]
67 |         sample = {
68 |             'r_img_org': r_img,
69 |             'd_img_org': d_img,
70 |             'score': score, 'd_img_name': d_img_name
71 |         }
72 |         if self.transform:
73 |             sample = self.transform(sample)
74 |         return sample
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | # pip install gdown
2 | 
3 | save_path='checkpoints/ahiq_pipal/'
4 | mkdir -p $save_path
5 | 
6 | # download the pretrained models
7 | gdown https://drive.google.com/uc?id=1Nk-IpjnDNXbWacoh3T69wkSYhYYWip2W -O $save_path
8 | gdown https://drive.google.com/uc?id=1Jr2nLnhMA0f0uPEjMG7sH-T4WfIEasXn -O $save_path
--------------------------------------------------------------------------------
/model/deform_regressor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.ops.deform_conv import DeformConv2d
5 | 
6 | class deform_fusion(nn.Module):
7 |     def __init__(self, opt, in_channels=768*5, cnn_channels=256*3, out_channels=256*3):
8 |         super().__init__()
9 |         # in_channels, out_channels, kernel_size, stride, padding
10 |         self.d_hidn = 512
11 |         if opt.patch_size == 8:
12 |             stride = 1
13 |         else:
14 |             stride = 2
15 |         self.conv_offset = nn.Conv2d(in_channels, 2*3*3, 3, 1, 1)
16 |         self.deform = DeformConv2d(cnn_channels, out_channels, 3, 1, 1)
17 |         self.conv1 = nn.Sequential(
18 |             nn.Conv2d(in_channels=out_channels, out_channels=self.d_hidn, kernel_size=3, padding=1, stride=2),
19 |             nn.ReLU(),
20 |             nn.Conv2d(in_channels=self.d_hidn, out_channels=out_channels, kernel_size=3, padding=1, stride=stride)
21 |         )
22 | 
23 |     def forward(self, cnn_feat, vit_feat):
24 |         vit_feat = F.interpolate(vit_feat, size=cnn_feat.shape[-2:], mode="nearest")
25 |         offset = self.conv_offset(vit_feat)
26 |         deform_feat = self.deform(cnn_feat, offset)
27 |         deform_feat = self.conv1(deform_feat)
28 | 
29 |         return deform_feat
30 | 
31 | class Pixel_Prediction(nn.Module):
32 |     def __init__(self, inchannels=768*5+256*3, outchannels=256, d_hidn=1024):
33 |         super().__init__()
34 |         self.d_hidn = d_hidn
35 |         self.down_channel = nn.Conv2d(inchannels, outchannels, kernel_size=1)
36 |         self.feat_smoothing = nn.Sequential(
37 |             nn.Conv2d(in_channels=256*3, out_channels=self.d_hidn, kernel_size=3, padding=1),
38 |             nn.ReLU(),
39 |             nn.Conv2d(in_channels=self.d_hidn, out_channels=512, kernel_size=3, padding=1)
40 |         )
41 | 
42 |         self.conv1 = nn.Sequential(
43 |             nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
44 |             nn.ReLU()
45 |         )
46 |         self.conv_attent = nn.Sequential(
47 |             nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1),
48 |             nn.Sigmoid()
49 |         )
50 |         self.conv = nn.Sequential(
51 |             nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1),
52 |         )
53 | 
54 |     def forward(self, f_dis, f_ref, cnn_dis, cnn_ref):
55 |         f_dis = torch.cat((f_dis, cnn_dis), 1)
56 |         f_ref = torch.cat((f_ref, cnn_ref), 1)
57 |         f_dis = self.down_channel(f_dis)
58 |         f_ref = self.down_channel(f_ref)
59 | 
60 |         f_cat = torch.cat((f_dis - f_ref, f_dis, f_ref), 1)
61 | 
62 |         feat_fused = self.feat_smoothing(f_cat)
63 |         feat = self.conv1(feat_fused)
64 |         f = self.conv(feat)
65 |         w = self.conv_attent(feat)
66 |         # attention-weighted average of the per-patch scores
67 |         pred = (f*w).sum(dim=2).sum(dim=2)/w.sum(dim=2).sum(dim=2)
68 | 
69 |         return pred
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IIGROUP/AHIQ/9cc6a4eed821bb8f16a1aae00d420b24625bc07a/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from utils import util
4 | import torch
5 | 
6 | class BaseOptions():
7 |     def __init__(self):
8 |         self._parser = argparse.ArgumentParser()
9 |         self._initialized = False
10 | 
11 |     def initialize(self):
12 |         # dataset path
13 |         self._parser.add_argument('--train_ref_path', type=str, default='Train_Ref/', help='path to training reference images')
14 |         self._parser.add_argument('--train_dis_path', type=str, default='Train_Dis/', help='path to training distorted images')
15 |         self._parser.add_argument('--val_ref_path', type=str, default='NTIRE2021/Ref/', help='path to validation reference images')
16 |         self._parser.add_argument('--val_dis_path', type=str, default='NTIRE2021/Dis/', help='path to validation distorted images')
17 |         self._parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
18 |         self._parser.add_argument('--train_list', type=str, default='PIPAL.txt', help='training data list')
19 |         self._parser.add_argument('--val_list', type=str, default='PIPAL_NTIRE_Valid_MOS.txt', help='validation data list')
20 |         # experiment
21 |         self._parser.add_argument('--name', type=str, default='ahiq_pipal',
22 |                                   help='name of the experiment. It decides where to store samples and models')
23 |         # device
24 |         self._parser.add_argument('--num_workers', type=int, default=8, help='total workers')
25 |         # model
26 |         self._parser.add_argument('--patch_size', type=int, default=8, help='patch size of Vision Transformer')
27 |         self._parser.add_argument('--load_epoch', type=int, default=-1, help='which epoch to load? set to -1 to use the latest cached model')
28 |         self._parser.add_argument('--ckpt', type=str, default='./checkpoints', help='models to be loaded')
29 |         self._parser.add_argument('--seed', type=int, default=1919, help='random seed')
30 |         # data process
31 |         self._parser.add_argument('--crop_size', type=int, default=224, help='image size')
32 |         self._parser.add_argument('--num_crop', type=int, default=1, help='random crop times')
33 |         self._parser.add_argument('--num_avg_val', type=int, default=5, help='ensemble ways of validation')
34 | 
35 |         self._parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids')
36 |         self._initialized = True
37 | 
38 |     def parse(self):
39 |         if not self._initialized:
40 |             self.initialize()
41 |         #self._opt = self._parser.parse_args()
42 |         self._opt = self._parser.parse_known_args()[0]
43 | 
44 |         # set train or test mode
45 |         self._opt.is_train = self.is_train
46 | 
47 |         # set and check load_epoch
48 |         self._set_and_check_load_epoch()
49 | 
50 |         # get and set gpus
51 |         self._get_set_gpus()
52 | 
53 |         args = vars(self._opt)
54 | 
55 |         # print args in terminal
56 |         self._print(args)
57 | 
58 |         # save args to file
59 |         self._save(args)
60 | 
61 |         return self._opt
62 | 
63 |     def _set_and_check_load_epoch(self):
64 |         models_dir = os.path.join(self._opt.checkpoints_dir, self._opt.name)
65 |         if os.path.exists(models_dir):
66 |             if self._opt.load_epoch == -1:
67 |                 load_epoch = 0
68 |                 for file in os.listdir(models_dir):
69 |                     if file.startswith("epoch"):
70 |                         load_epoch = max(load_epoch, int(file.split('.')[0].split('_')[1]))
71 |                 self._opt.load_epoch = load_epoch
72 |             else:
73 |                 found = False
74 |                 for file in os.listdir(models_dir):
75 |                     if file.startswith("epoch"):
76 |                         found = int(file.split('.')[0].split('_')[1]) == self._opt.load_epoch
77 |                         if found: break
78 |                 assert found, 'Model for epoch %i not found' % self._opt.load_epoch
79 |         else:
80 |             assert self._opt.load_epoch < 1, 'Model for epoch %i not found' % self._opt.load_epoch
81 |             self._opt.load_epoch = 0
82 | 
83 |     def _get_set_gpus(self):
84 |         # get gpu ids
85 |         str_ids = self._opt.gpu_ids.split(',')
86 |         self._opt.gpu_ids = []
87 |         for str_id in str_ids:
88 |             id = int(str_id)
89 |             if id >= 0:
90 |                 self._opt.gpu_ids.append(id)
91 | 
92 |         # set gpu ids
93 |         if len(self._opt.gpu_ids) > 0:
94 |             torch.cuda.set_device(self._opt.gpu_ids[0])
95 | 
96 |     def _print(self, args):
97 |         print('------------ Options -------------')
98 |         for k, v in sorted(args.items()):
99 |             print('%s: %s' % (str(k), str(v)))
100 |         print('-------------- End ----------------')
101 | 
102 |     def _save(self, args):
103 |         expr_dir = os.path.join(self._opt.checkpoints_dir, self._opt.name)
104 |         print(expr_dir)
105 |         util.mkdirs(expr_dir)
106 |         file_name = os.path.join(expr_dir, 'opt_%s.txt' % ('train' if self.is_train else 'test'))
107 |         with open(file_name, 'wt') as opt_file:
108 |             opt_file.write('------------ Options -------------\n')
109 |             for k, v in sorted(args.items()):
110 |                 opt_file.write('%s: %s\n' % (str(k), str(v)))
111 |             opt_file.write('-------------- End ----------------\n')
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | class TestOptions(BaseOptions):
4 |     def initialize(self):
5 |         BaseOptions.initialize(self)
6 |         self._parser.add_argument('--test_ref_path', type=str, default='Test_Ref/', help='path to reference images')
7 |         self._parser.add_argument('--test_dis_path', type=str, default='Test_Dis/', help='path to distorted images')
8 |         self._parser.add_argument('--test_list', type=str, default='test.txt', help='test data list')
9 | 
10 |         self._parser.add_argument('--batch_size', type=int, default=10, help='input batch size')
11 |         self._parser.add_argument('--test_file_name', type=str, default='results.txt', help='txt path to save results')
12 |         self._parser.add_argument('--n_ensemble', type=int, default=20, help='crop method for testing: five-point crop, nine-point crop, or repeated random crops')
13 |         self._parser.add_argument('--flip', action='store_true', help='flip images when testing')
14 |         self._parser.add_argument('--resize', action='store_true', help='resize images when testing')
15 |         self._parser.add_argument('--size', type=int, default=224, help='the resize shape')
16 |         self.is_train = False
17 | 
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 |     def initialize(self):
6 |         BaseOptions.initialize(self)
7 |         self._parser.add_argument('--n_epoch', type=int, default=200, help='total epochs for training')
8 |         self._parser.add_argument('--save_interval', type=int, default=5, help='interval for saving models')
9 |         self._parser.add_argument('--learning_rate', type=float, default=1e-4, help='initial learning rate')
10 |         self._parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay')
11 |         self._parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
12 |         self._parser.add_argument('--val_freq', type=int, default=1, help='validation frequency')
13 |         self._parser.add_argument('--T_max', type=int, default=50, help='cosine learning rate period (iterations)')
14 |         self._parser.add_argument('--eta_min', type=int, default=0, help='minimum learning rate')
15 | 
16 |         self.is_train = True
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | opencv-python
3 | torchvision==0.11.1
4 | timm==0.5.4
5 | torch==1.10.0
6 | scipy==1.7.3
7 | numpy==1.21.2
--------------------------------------------------------------------------------
/script/extract_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def get_resnet_feature(save_output):
4 |     feat = torch.cat(
5 |         (
6 |             save_output.outputs[0],
7 |             save_output.outputs[1],
8 |             save_output.outputs[2]
9 |         ),
10 |         dim=1
11 |     )
12 |     return feat
13 | 
14 | def get_vit_feature(save_output):
15 |     feat = torch.cat(
16 |         (
17 |             save_output.outputs[0][:, 1:, :],
18 |             save_output.outputs[1][:, 1:, :],
19 |             save_output.outputs[2][:, 1:, :],
20 |             save_output.outputs[3][:, 1:, :],
21 |             save_output.outputs[4][:, 1:, :],
22 |         ),
23 |         dim=2
24 |     )
25 |     return feat
26 | 
27 | def get_inception_feature(save_output):
28 |     feat = torch.cat(
29 |         (
30 |             save_output.outputs[0],
31 |             save_output.outputs[2],
32 |             save_output.outputs[4],
33 |             save_output.outputs[6],
34 |             save_output.outputs[8],
35 |             save_output.outputs[10]
36 |         ),
37 |         dim=1
38 |     )
39 |     return feat
40 | 
41 | def get_resnet152_feature(save_output):
42 |     feat = torch.cat(
43 |         (
44 |             save_output.outputs[3],
45 |             save_output.outputs[4],
46 |             save_output.outputs[6],
47 |             save_output.outputs[7],
48 |             save_output.outputs[8],
49 |             save_output.outputs[10]
50 |         ),
51 |         dim=1
52 |     )
53 |     return feat
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 | from scipy.stats import spearmanr, pearsonr
7 | import timm
8 | from timm.models.vision_transformer import Block
9 | from timm.models.resnet import BasicBlock, Bottleneck
10 | import time
11 | from torch.utils.data import DataLoader
12 | 
13 | from utils.util import setup_seed, set_logging, SaveOutput
14 | from script.extract_feature import get_resnet_feature, get_vit_feature
15 | from options.test_options import TestOptions
16 | from model.deform_regressor import deform_fusion, Pixel_Prediction
17 | from data.pipal import PIPAL
18 | from utils.process_image import ToTensor, RandHorizontalFlip, RandCrop, crop_image, Normalize, five_point_crop
19 | from torchvision import transforms
20 | 
21 | class Test:
22 |     def __init__(self, config):
23 |         self.opt = config
24 |         self.create_model()
25 |         self.init_saveoutput()
26 |         self.init_data()
27 |         self.load_model()
28 |         self.test()
29 | 
30 |     def create_model(self):
31 |         self.resnet50 = timm.create_model('resnet50', pretrained=True).cuda()
32 |         if self.opt.patch_size == 8:
33 |             self.vit = timm.create_model('vit_base_patch8_224', pretrained=True).cuda()
34 |         else:
35 |             self.vit = timm.create_model('vit_base_patch16_224', pretrained=True).cuda()
36 |         self.deform_net = deform_fusion(self.opt).cuda()
37 |         self.regressor = Pixel_Prediction().cuda()
38 | 
39 |     def init_saveoutput(self):
40 |         self.save_output = SaveOutput()
41 |         hook_handles = []
42 |         for layer in self.resnet50.modules():
43 |             if isinstance(layer, Bottleneck):
44 |                 handle = layer.register_forward_hook(self.save_output)
45 |                 hook_handles.append(handle)
46 |         for layer in self.vit.modules():
47 |             if isinstance(layer, Block):
48 |                 handle = layer.register_forward_hook(self.save_output)
49 |                 hook_handles.append(handle)
50 | 
51 |     def init_data(self):
52 |         test_dataset = PIPAL(
53 |             ref_path=self.opt.test_ref_path,
54 |             dis_path=self.opt.test_dis_path,
55 |             txt_file_name=self.opt.test_list,
56 |             resize=self.opt.resize,
57 |             size=(self.opt.size, self.opt.size),
58 |             flip=self.opt.flip,
59 |             transform=ToTensor(),
60 |         )
61 |         logging.info('number of test scenes: {}'.format(len(test_dataset)))
62 | 
63 |         self.test_loader = DataLoader(
64 |             dataset=test_dataset,
65 |             batch_size=self.opt.batch_size,
66 |             num_workers=self.opt.num_workers,
67 |             drop_last=True,
68 |             shuffle=False
69 |         )
70 | 
71 |     def load_model(self):
72 |         models_dir = self.opt.checkpoints_dir
73 |         if os.path.exists(models_dir):
74 |             if self.opt.load_epoch == -1:
75 |                 load_epoch = 0
76 |                 for file in os.listdir(models_dir):
77 |                     if file.startswith("epoch"):
78 |                         load_epoch = max(load_epoch, int(file.split('.')[0].split('_')[1]))
79 |                 self.opt.load_epoch = load_epoch
80 |             else:
81 |                 found = False
82 |                 for file in os.listdir(models_dir):
83 |                     if file.startswith("epoch"):
84 |                         found = int(file.split('.')[0].split('_')[1]) == self.opt.load_epoch
85 |                         if found: break
86 |                 assert found, 'Model for epoch %i not found' % self.opt.load_epoch
87 |             # only the regressor and the deformable fusion network are trained,
88 |             # so they are the only modules restored for testing
89 |             checkpoint = torch.load(os.path.join(models_dir, "epoch_" + str(self.opt.load_epoch) + ".pth"))
90 |             self.regressor.load_state_dict(checkpoint['regressor_model_state_dict'])
91 |             self.deform_net.load_state_dict(checkpoint['deform_net_model_state_dict'])
92 |         else:
93 |             assert self.opt.load_epoch < 1, 'Model for epoch %i not found' % self.opt.load_epoch
94 |             self.opt.load_epoch = 0
95 | 
96 |     def test(self):
97 |         # the backbones are frozen feature extractors; run everything in eval mode
98 |         self.regressor.eval()
99 |         self.deform_net.eval()
100 |         self.vit.eval()
101 |         self.resnet50.eval()
102 |         f = open(os.path.join(self.opt.checkpoints_dir, self.opt.test_file_name), 'w')
103 |         with torch.no_grad():
104 |             for data in tqdm(self.test_loader):
105 |                 d_img_org = data['d_img_org'].cuda()
106 |                 r_img_org = data['r_img_org'].cuda()
107 |                 d_img_name = data['d_img_name']
108 |                 pred = 0
109 |                 for i in range(self.opt.n_ensemble):
110 |                     b, c, h, w = r_img_org.size()
111 |                     if self.opt.n_ensemble > 9:
112 |                         # repeated random crops
113 |                         new_h = self.opt.crop_size
114 |                         new_w = self.opt.crop_size
115 |                         top = np.random.randint(0, h - new_h)
116 |                         left = np.random.randint(0, w - new_w)
117 |                         r_img = r_img_org[:, :, top: top + new_h, left: left + new_w]
118 |                         d_img = d_img_org[:, :, top: top + new_h, left: left + new_w]
119 |                     elif self.opt.n_ensemble == 1:
120 |                         r_img = r_img_org
121 |                         d_img = d_img_org
122 |                     else:
123 |                         d_img, r_img = five_point_crop(i, d_img=d_img_org, r_img=r_img_org, config=self.opt)
124 |                     d_img = d_img.cuda()
125 |                     r_img = r_img.cuda()
126 |                     _x = self.vit(d_img)
127 |                     vit_dis = get_vit_feature(self.save_output)
128 |                     self.save_output.outputs.clear()
129 | 
130 |                     _y = self.vit(r_img)
131 |                     vit_ref = get_vit_feature(self.save_output)
132 |                     self.save_output.outputs.clear()
133 |                     B, N, C = vit_ref.shape
134 |                     if self.opt.patch_size == 8:
135 |                         H, W = 28, 28
136 |                     else:
137 |                         H, W = 14, 14
138 |                     assert H * W == N
139 |                     vit_ref = vit_ref.transpose(1, 2).view(B, C, H, W)
140 |                     vit_dis = vit_dis.transpose(1, 2).view(B, C, H, W)
141 | 
142 |                     _ = self.resnet50(d_img)
143 |                     cnn_dis = get_resnet_feature(self.save_output)
144 |                     self.save_output.outputs.clear()
145 |                     cnn_dis = self.deform_net(cnn_dis, vit_ref)
146 | 
147 |                     _ = self.resnet50(r_img)
148 |                     cnn_ref = get_resnet_feature(self.save_output)
149 |                     self.save_output.outputs.clear()
150 |                     cnn_ref = self.deform_net(cnn_ref, vit_ref)
151 |                     pred += self.regressor(vit_dis, vit_ref, cnn_dis, cnn_ref)
152 | 
153 |                 pred /= self.opt.n_ensemble
154 |                 for i in range(len(d_img_name)):
155 |                     line = "%s,%f\n" % (d_img_name[i], float(pred.squeeze()[i]))
156 |                     f.write(line)
157 | 
158 |         f.close()
159 | 
160 | 
161 | if __name__ == '__main__':
162 |     config = TestOptions().parse()
163 |     config.checkpoints_dir = os.path.join(config.checkpoints_dir, config.name)
164 |     setup_seed(config.seed)
165 |     set_logging(config)
166 |     Test(config)
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py \
2 |     --test_ref_path /path/to/test \
3 |     --test_dis_path /path/to/test \
4 |     --test_list /txt/for/test
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 | from scipy.stats import spearmanr, pearsonr
7 | import timm
8 | from timm.models.vision_transformer import Block
9 | from timm.models.resnet import BasicBlock, Bottleneck
10 | import time
11 | 
12 | from torch.utils.data import DataLoader
13 | 
14 | from utils.util import setup_seed, set_logging, SaveOutput
15 | from script.extract_feature import get_resnet_feature, get_vit_feature
16 | from options.train_options import TrainOptions
17 | from model.deform_regressor import deform_fusion, Pixel_Prediction
18 | from data.pipal import PIPAL
19 | from utils.process_image import ToTensor, RandHorizontalFlip, RandCrop, crop_image, Normalize, five_point_crop
20 | from torchvision import transforms
21 | 
22 | class Train:
23 |     def __init__(self, config):
24 |         self.opt = config
25 |         self.create_model()
26 |         self.init_saveoutput()
27 |         self.init_data()
28 |         self.criterion = torch.nn.MSELoss()
29 |         self.optimizer = torch.optim.Adam([
30 |             {'params': self.regressor.parameters(), 'lr': self.opt.learning_rate, 'weight_decay': self.opt.weight_decay},
31 |             {'params': self.deform_net.parameters(), 'lr': self.opt.learning_rate, 'weight_decay': self.opt.weight_decay}
32 |         ])
33 |         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.opt.T_max, eta_min=self.opt.eta_min)
34 |         self.load_model()
35 |         self.train()
36 | 
37 |     def create_model(self):
38 |         self.resnet50 = timm.create_model('resnet50', pretrained=True).cuda()
39 |         if self.opt.patch_size == 8:
40 |             self.vit = timm.create_model('vit_base_patch8_224', pretrained=True).cuda()
41 |         else:
42 |             self.vit = timm.create_model('vit_base_patch16_224', pretrained=True).cuda()
43 |         self.deform_net = deform_fusion(self.opt).cuda()
44 |         self.regressor = Pixel_Prediction().cuda()
45 | 
46 |     def init_saveoutput(self):
47 |         self.save_output = SaveOutput()
48 |         hook_handles = []
49 |         for layer in self.resnet50.modules():
50 |             if isinstance(layer, Bottleneck):
51 |                 handle = layer.register_forward_hook(self.save_output)
52 |                 hook_handles.append(handle)
53 |         for layer in self.vit.modules():
54 |             if isinstance(layer, Block):
55 |                 handle = layer.register_forward_hook(self.save_output)
56 |                 hook_handles.append(handle)
57 | 
58 |     def init_data(self):
59 |         train_dataset = PIPAL(
60 |             ref_path=self.opt.train_ref_path,
61 |             dis_path=self.opt.train_dis_path,
62 |             txt_file_name=self.opt.train_list,
63 |             transform=transforms.Compose(
64 |                 [
65 |                     RandCrop(self.opt.crop_size, self.opt.num_crop),
66 |                     #Normalize(0.5, 0.5),
67 |                     RandHorizontalFlip(),
68 |                     ToTensor(),
69 |                 ]
70 |             ),
71 |         )
72 |         val_dataset = PIPAL(
73 |             ref_path=self.opt.val_ref_path,
74 |             dis_path=self.opt.val_dis_path,
75 |             txt_file_name=self.opt.val_list,
76 |             transform=ToTensor(),
77 |         )
78 |         logging.info('number of train scenes: {}'.format(len(train_dataset)))
79 |         logging.info('number of val scenes: {}'.format(len(val_dataset)))
80 | 
81 |         self.train_loader = DataLoader(
82 |             dataset=train_dataset,
83 |             batch_size=self.opt.batch_size,
84 |             num_workers=self.opt.num_workers,
85 |             drop_last=True,
86 |             shuffle=True
87 |         )
88 |         self.val_loader = DataLoader(
89 |             dataset=val_dataset,
90 |             batch_size=self.opt.batch_size,
91 |             num_workers=self.opt.num_workers,
92 |             drop_last=True,
93 |             shuffle=False
94 |         )
95 | 
96 |     def load_model(self):
97 |         models_dir = self.opt.checkpoints_dir
98 |         if os.path.exists(models_dir):
99 |             if self.opt.load_epoch == -1:
100 |                 load_epoch = 0
101 |                 for file in os.listdir(models_dir):
102 |                     if file.startswith("epoch_"):
103 |                         load_epoch = max(load_epoch, int(file.split('.')[0].split('_')[1]))
104 |                 self.opt.load_epoch = load_epoch
105 |             else:
106 |                 found = False
107 |                 for file in os.listdir(models_dir):
108 |                     if file.startswith("epoch_"):
109 |                         found = int(file.split('.')[0].split('_')[1]) == self.opt.load_epoch
110 |                         if found: break
111 |                 assert found, 'Model for epoch %i not found' % self.opt.load_epoch
112 |             if self.opt.load_epoch > 0:
113 |                 # resume training from the latest (or requested) checkpoint
114 |                 checkpoint = torch.load(os.path.join(models_dir, "epoch_" + str(self.opt.load_epoch) + ".pth"))
115 |                 self.regressor.load_state_dict(checkpoint['regressor_model_state_dict'])
116 |                 self.deform_net.load_state_dict(checkpoint['deform_net_model_state_dict'])
117 |                 self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
118 |                 self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
119 |                 self.start_epoch = checkpoint['epoch'] + 1
120 |         else:
121 |             assert self.opt.load_epoch < 1, 'Model for epoch %i not found' % self.opt.load_epoch
122 |             self.opt.load_epoch = 0
123 | 
124 |     def train_epoch(self, epoch):
125 |         losses = []
126 |         self.regressor.train()
127 |         self.deform_net.train()
128 |         self.vit.eval()
129 |         self.resnet50.eval()
130 |         # save data for one epoch
131 |         pred_epoch = []
132 |         labels_epoch = []
133 | 
134 |         for data in tqdm(self.train_loader):
135 |             d_img_org = data['d_img_org'].cuda()
136 |             r_img_org = data['r_img_org'].cuda()
137 |             labels = data['score']
138 |             labels = torch.squeeze(labels.type(torch.FloatTensor)).cuda()
139 | 
140 |             _x = self.vit(d_img_org)
141 |             vit_dis = get_vit_feature(self.save_output)
142 |             self.save_output.outputs.clear()
143 | 
144 |             _y = self.vit(r_img_org)
145 |             vit_ref = get_vit_feature(self.save_output)
146 |             self.save_output.outputs.clear()
147 |             B, N, C = vit_ref.shape
148 |             if self.opt.patch_size == 8:
149 |                 H, W = 28, 28
150 |             else:
151 |                 H, W = 14, 14
152 |             assert H * W == N
153 |             vit_ref = vit_ref.transpose(1, 2).view(B, C, H, W)
154 |             vit_dis = vit_dis.transpose(1, 2).view(B, C, H, W)
155 | 
156 |             _ = self.resnet50(d_img_org)
157 |             cnn_dis = get_resnet_feature(self.save_output)  # outputs 0, 1 and 2 are all [B, 256, 56, 56]
158 |             self.save_output.outputs.clear()
159 |             cnn_dis = self.deform_net(cnn_dis, vit_ref)
160 | 
161 |             _ = self.resnet50(r_img_org)
162 |             cnn_ref = get_resnet_feature(self.save_output)
163 |             self.save_output.outputs.clear()
164 |             cnn_ref = self.deform_net(cnn_ref, vit_ref)
165 | 
166 |             pred = self.regressor(vit_dis, vit_ref, cnn_dis, cnn_ref)
167 | 
168 |             self.optimizer.zero_grad()
169 |             loss = self.criterion(torch.squeeze(pred), labels)
170 |             losses.append(loss.item())
171 | 
172 |             loss.backward()
173 |             self.optimizer.step()
174 |             self.scheduler.step()
175 | 
176 |             # save results in one epoch
177 |             pred_batch_numpy = pred.data.cpu().numpy()
178 |             labels_batch_numpy = labels.data.cpu().numpy()
179 |             pred_epoch = np.append(pred_epoch, pred_batch_numpy)
180 |             labels_epoch = np.append(labels_epoch, labels_batch_numpy)
181 | 
182 |         # compute correlation coefficient
183 |         rho_s, _ = spearmanr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
184 |         rho_p, _ = pearsonr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
185 | 
186 |         ret_loss = np.mean(losses)
187 |         print('train epoch:{} / loss:{:.4} / SRCC:{:.4} / PLCC:{:.4}'.format(epoch + 1, ret_loss, rho_s, rho_p))
188 |         logging.info('train epoch:{} / loss:{:.4} / SRCC:{:.4} / PLCC:{:.4}'.format(epoch + 1, ret_loss, rho_s, rho_p))
189 | 
190 |         return ret_loss, rho_s, rho_p
191 | 
192 |     def train(self):
193 |         best_srocc = 0
194 |         best_plcc = 0
195 |         for epoch in range(self.opt.load_epoch, self.opt.n_epoch):
196 |             start_time = time.time()
197 |             logging.info('Running training epoch {}'.format(epoch + 1))
198 |             loss_val, rho_s, rho_p = self.train_epoch(epoch)
199 |             if (epoch + 1) % self.opt.val_freq == 0:
200 |                 logging.info('Starting eval...')
201 |                 logging.info('Running testing in epoch {}'.format(epoch + 1))
202 |                 loss, rho_s, rho_p = self.eval_epoch(epoch)
203 |                 logging.info('Eval done...')
204 | 
205 |                 if rho_s > best_srocc or rho_p > best_plcc:
206 |                     best_srocc = rho_s
207 |                     best_plcc = rho_p
208 |                     print('Best now')
209 |                     logging.info('Best now')
210 |                     self.save_model(epoch, "best.pth", loss, rho_s, rho_p)
211 |                 if epoch % self.opt.save_interval == 0:
212 |                     weights_file_name = "epoch_%d.pth" % (epoch + 1)
213 |                     self.save_model(epoch, weights_file_name, loss, rho_s, rho_p)
214 |             logging.info('Epoch {} done. Time: {:.2}min'.format(epoch + 1, (time.time() - start_time) / 60))
215 | 
216 |     def eval_epoch(self, epoch):
217 |         with torch.no_grad():
218 |             losses = []
219 |             self.regressor.eval()
220 |             self.deform_net.eval()
221 |             self.vit.eval()
222 |             self.resnet50.eval()
223 |             # save data for one epoch
224 |             pred_epoch = []
225 |             labels_epoch = []
226 | 
227 |             for data in tqdm(self.val_loader):
228 |                 pred = 0
229 |                 for i in range(self.opt.num_avg_val):
230 |                     d_img_org = data['d_img_org'].cuda()
231 |                     r_img_org = data['r_img_org'].cuda()
232 |                     labels = data['score']
233 |                     labels = torch.squeeze(labels.type(torch.FloatTensor)).cuda()
234 | 
235 |                     d_img_org, r_img_org = five_point_crop(i, d_img=d_img_org, r_img=r_img_org, config=self.opt)
236 | 
237 |                     _x = self.vit(d_img_org)
238 |                     vit_dis = get_vit_feature(self.save_output)
239 |                     self.save_output.outputs.clear()
240 | 
241 |                     _y = self.vit(r_img_org)
242 |                     vit_ref = get_vit_feature(self.save_output)
243 |                     self.save_output.outputs.clear()
244 |                     B, N, C = vit_ref.shape
245 |                     if self.opt.patch_size == 8:
246 |                         H, W = 28, 28
247 |                     else:
248 |                         H, W = 14, 14
249 |                     assert H * W == N
250 |                     vit_ref = vit_ref.transpose(1, 2).view(B, C, H, W)
251 |                     vit_dis = vit_dis.transpose(1, 2).view(B, C, H, W)
252 | 
253 |                     _ = self.resnet50(d_img_org)
254 |                     cnn_dis = get_resnet_feature(self.save_output)  # outputs 0, 1 and 2 are all [B, 256, 56, 56]
255 |                     self.save_output.outputs.clear()
256 |                     cnn_dis = self.deform_net(cnn_dis, vit_ref)
257 | 
258 |                     _ = self.resnet50(r_img_org)
259 |                     cnn_ref = get_resnet_feature(self.save_output)
260 |                     self.save_output.outputs.clear()
261 |                     cnn_ref = self.deform_net(cnn_ref, vit_ref)
262 | 
263 |                     pred += self.regressor(vit_dis, vit_ref, cnn_dis, cnn_ref)
264 | 
265 |                 pred /= self.opt.num_avg_val
266 |                 # compute loss
267 |                 loss = self.criterion(torch.squeeze(pred), labels)
268 |                 loss_val = loss.item()
269 |                 losses.append(loss_val)
270 | 
271 |                 # save results in one epoch
272 |                 pred_batch_numpy = pred.data.cpu().numpy()
273 |                 labels_batch_numpy = labels.data.cpu().numpy()
274 |                 pred_epoch = np.append(pred_epoch, pred_batch_numpy)
275 |                 labels_epoch = np.append(labels_epoch, labels_batch_numpy)
276 | 
277 |             # compute correlation coefficient
278 |             rho_s, _ = spearmanr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
279 |             rho_p, _ = pearsonr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
280 |             print('Epoch:{} ===== loss:{:.4} ===== SRCC:{:.4} ===== PLCC:{:.4}'.format(epoch + 1, np.mean(losses), rho_s, rho_p))
281 |             logging.info('Epoch:{} ===== loss:{:.4} ===== SRCC:{:.4} ===== PLCC:{:.4}'.format(epoch + 1, np.mean(losses), rho_s, rho_p))
282 |         return np.mean(losses), rho_s, rho_p
283 | 
284 |     def save_model(self, epoch, weights_file_name, loss, rho_s, rho_p):
285 |         print('-------------saving weights---------')
286 |         weights_file = os.path.join(self.opt.checkpoints_dir, weights_file_name)
287 |         torch.save({
288 |             'epoch': epoch,
289 |             'regressor_model_state_dict': self.regressor.state_dict(),
290 |             'deform_net_model_state_dict': self.deform_net.state_dict(),
291 |             'optimizer_state_dict': self.optimizer.state_dict(),
292 |             'scheduler_state_dict': self.scheduler.state_dict(),
293 |             'loss': loss
294 |         }, weights_file)
295 |         logging.info('Saving weights and model of epoch {}, SRCC:{}, PLCC:{}'.format(epoch, rho_s, rho_p))
296 | 
297 | if __name__ == '__main__':
298 |     config = TrainOptions().parse()
299 |     config.checkpoints_dir = os.path.join(config.checkpoints_dir, config.name)
300 |     setup_seed(config.seed)
301 |     set_logging(config)
302 |     # logging.info(config)
303 |     Train(config)
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train.py \
2 |     --train_ref_path /path/to/train \
3 |     --train_dis_path /path/to/train \
4 |     --train_list /txt/for/train \
5 |     --val_ref_path /path/to/val \
6 |     --val_dis_path /path/to/val \
7 |     --val_list /txt/for/val \
8 |     --checkpoints_dir /path/for/save/ckpt
--------------------------------------------------------------------------------
/utils/process_image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | 
5 | def crop_image(top, left, new_h, new_w, img=None):
6 |     b, c, h, w = img.shape
7 |     tmp_img = img[:, :, top: top + new_h, left: left + new_w]
8 |     return tmp_img
9 | 
10 | class RandCrop(object):
11 |     def __init__(self, patch_size, num_crop):
12 |         self.patch_size = patch_size
13 |         self.num_crop = num_crop
14 | 
15 |     def __call__(self, sample):
16 |         # r_img : C x H x W (numpy)
17 |         r_img, d_img = sample['r_img_org'], sample['d_img_org']
18 |         score = sample['score']
19 |         d_img_name = sample['d_img_name']
20 | 
21 |         c, h, w = d_img.shape
22 |         new_h = self.patch_size
23 |         new_w = self.patch_size
24 |         ret_r_img = np.zeros((c, self.patch_size, self.patch_size))
25 |         ret_d_img = np.zeros((c, self.patch_size, self.patch_size))
26 |         for _ in range(self.num_crop):
27 |             top = np.random.randint(0, h - new_h)
28 |             left = np.random.randint(0, w - new_w)
29 |             tmp_r_img = r_img[:, top: top + new_h, left: left + new_w]
30 |             tmp_d_img = d_img[:, top: top + new_h, left: left + new_w]
31 |             ret_r_img += tmp_r_img
32 |             ret_d_img += tmp_d_img
33 |         ret_r_img /= self.num_crop
34 |         ret_d_img /= self.num_crop
35 | 
36 |         sample = {
37 |             'r_img_org': ret_r_img,
38 |             'd_img_org': ret_d_img,
39 |             'score': score, 'd_img_name': d_img_name
40 |         }
41 | 
42 |         return sample
43 | 
44 | def five_point_crop(idx, d_img, r_img, config):
45 |     new_h = config.crop_size
46 |     new_w = config.crop_size
47 |     if len(d_img.shape) == 3:
48 |         c, h, w = d_img.shape
49 |     else:
50 |         b, c, h, w = d_img.shape
51 |     center_h = h // 2
52 |     center_w = w // 2
53 |     if idx == 0:
54 |         top = 0
55 |         left = 0
56 |     elif idx == 1:
57 |         top = 0
58 |         left = w - new_w
59 |     elif idx == 2:
60 |         top = h - new_h
61 |         left = 0
62 |     elif idx == 3:
63 |         top = h - new_h
64 |         left = w - new_w
65 |     elif idx == 4:
66 |         top = center_h - new_h // 2
67 |         left = center_w - new_w // 2
68 |     elif idx == 5:
69 |         left = 0
70 |         top = center_h - new_h // 2
71 |     elif idx == 6:
72 |         left = w - new_w
73 |         top = center_h - new_h // 2
74 |     elif idx == 7:
75 |         top = 0
76 |         left = center_w - new_w // 2
77 |     elif idx == 8:
78 |         top = h - new_h
79 |         left = center_w - new_w // 2
80 |     if len(d_img.shape) == 3:
81 |         d_img_org = d_img[:, top: top + new_h, left: left + new_w]
82 |         r_img_org = r_img[:, top: top + new_h, left: left + new_w]
83 |     else:
84 |         d_img_org = d_img[:, :, top: top + new_h, left: left + new_w]
85 |         r_img_org = r_img[:, :, top: top + new_h, left: left + new_w]
86 | 
87 |     return d_img_org, r_img_org
88 | 
89 | class RandCrop_fivepoints(object):
90 |     def __init__(self, patch_size, num_crop, config):
91 |         self.patch_size = patch_size
92 |         self.num_crop = num_crop
93 |         self.config = config
94 |     def __call__(self, sample):
95 |         # r_img : C x H x W (numpy)
96 |         r_img, d_img = sample['r_img_org'], sample['d_img_org']
97 |         score = sample['score']
98 |         d_img_name = sample['d_img_name']
99 |         for idx in range(5):
100 |             d_img_org, r_img_org = five_point_crop(idx, d_img, r_img, self.config)
101 |             if idx == 0:
102 |                 ret_r_img = r_img_org.unsqueeze(0)
103 |                 ret_d_img = d_img_org.unsqueeze(0)
104 |             else:
105 |                 ret_r_img = torch.cat((ret_r_img, r_img_org.unsqueeze(0)), dim=0)
106 |                 ret_d_img = torch.cat((ret_d_img, d_img_org.unsqueeze(0)), dim=0)
107 | 
108 |         sample = {
109 |             'r_img_org': ret_r_img,
110 |             'd_img_org': ret_d_img,
111 |             'score': score, 'd_img_name': d_img_name
112 |         }
113 | 
114 |         return sample
115 | 
116 | class RandCrop_points(object):
117 |     def __init__(self, patch_size, num_crop, config):
118 |         self.patch_size = patch_size
119 |         self.num_crop = num_crop
120 |         self.config = config
121 |     def __call__(self, sample):
122 |         # r_img : C x H x W (numpy)
123 |         r_img, d_img = sample['r_img_org'], sample['d_img_org']
124 |         score = sample['score']
125 |         d_img_name = sample['d_img_name']
126 | 
127 |         c, h, w = d_img.shape
128 |         new_h = self.patch_size
129 |         new_w = self.patch_size
130 | 
131 |         for idx in range(self.num_crop):
132 |             if self.num_crop == 5 or self.num_crop == 9:
133 |                 d_img_org, r_img_org = five_point_crop(idx, d_img, r_img, self.config)
134 |             else:
135 |                 top = np.random.randint(0, h - new_h)
136 |                 left = np.random.randint(0, w - new_w)
137 |                 r_img_org = r_img[:, top: top + new_h, left: left + new_w]
138 |                 d_img_org = d_img[:, top: top + new_h, left: left + new_w]
139 |             if idx == 0:
140 |                 ret_r_img = r_img_org.unsqueeze(0)
141 |                 ret_d_img = d_img_org.unsqueeze(0)
142 |             else:
143 |                 ret_r_img = torch.cat((ret_r_img, r_img_org.unsqueeze(0)), dim=0)
144 |                 ret_d_img = torch.cat((ret_d_img, d_img_org.unsqueeze(0)), dim=0)
145 | 
146 |         sample = {
147 |             'r_img_org': ret_r_img,
148 |             'd_img_org': ret_d_img,
149 |             'score': score, 'd_img_name': d_img_name
150 |         }
151 | 
152 |         return sample
153 | 
154 | class Normalize(object):
155 |     def __init__(self, mean, var):
156 |         self.mean = mean
157 |         self.var = var
158 | 
159 |     def __call__(self, sample):
160 |         # r_img: C x H x W (numpy)
161 |         r_img, d_img = sample['r_img_org'], sample['d_img_org']
162 |         score = sample['score']
163 |         d_img_name = sample['d_img_name']
164 | 
165 |         r_img = (r_img - self.mean) / self.var
166 |         d_img = (d_img - self.mean) / self.var
167 | 
168 |         sample = {'r_img_org': r_img, 'd_img_org': d_img, 'score': score, 'd_img_name': d_img_name}
169 |         return sample
170 | 
171 | 
172 | class RandHorizontalFlip(object):
173 |     def __init__(self):
174 |         pass
175 | 
176 |     def __call__(self, sample):
177 |         r_img, d_img = sample['r_img_org'], sample['d_img_org']
178 |         score = sample['score']
179 |         d_img_name = sample['d_img_name']
180 |         prob_lr = np.random.random()
181 |         # np.fliplr needs HxWxC
182 |         if prob_lr > 0.5:
183 |             d_img = np.fliplr(d_img).copy()
184 |             r_img = np.fliplr(r_img).copy()
185 | 
186 |         sample = {
187 |             'r_img_org': r_img,
188 |             'd_img_org': d_img,
189 |             'score': score, 'd_img_name': d_img_name
190 |         }
191 |         return sample
192 | 
193 | 
194 | class ToTensor(object):
195 |     def __init__(self):
196 |         pass
197 | 
198 |     def __call__(self, sample):
199 |         r_img, d_img = sample['r_img_org'], sample['d_img_org']
200 |         score = sample['score']
201 |         d_img_name = sample['d_img_name']
202 |         d_img = torch.from_numpy(d_img).type(torch.FloatTensor)
203 |         r_img = torch.from_numpy(r_img).type(torch.FloatTensor)
204 |         score = torch.from_numpy(score).type(torch.FloatTensor)
205 |         sample = {
206 |             'r_img_org': r_img,
207 |             'd_img_org': d_img,
208 |             'score': score, 'd_img_name': d_img_name
209 |         }
210 |         return sample
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import logging
6 | 
7 | def mkdirs(paths):
8 |     if isinstance(paths, list) and not isinstance(paths, str):
9 |         for path in paths:
10 |             mkdir(path)
11 |     else:
12 |         mkdir(paths)
13 | 
14 | def mkdir(path):
15 |     if not os.path.exists(path):
16 |         os.makedirs(path)
17 | 
18 | def setup_seed(seed):
19 |     random.seed(seed)
20 |     os.environ['PYTHONHASHSEED'] = str(seed)
21 |     np.random.seed(seed)
22 |     torch.manual_seed(seed)
23 |     torch.cuda.manual_seed(seed)
24 |     torch.cuda.manual_seed_all(seed)
25 |     torch.backends.cudnn.benchmark = False
26 |     torch.backends.cudnn.deterministic = True
27 | 
28 | def set_logging(config):
29 |     filename = os.path.join(config.checkpoints_dir, "log.txt")
30 |     logging.basicConfig(
31 |         level=logging.INFO,
32 |         filename=filename,
33 |         filemode='w',
34 |         format='[%(asctime)s %(levelname)-8s] %(message)s',
35 |         datefmt='%Y%m%d %H:%M:%S'
36 |     )
37 | 
38 | class SaveOutput:
39 |     def __init__(self):
40 |         self.outputs = []
41 | 
42 |     def __call__(self, module, module_in, module_out):
43 |         self.outputs.append(module_out)
44 | 
45 |     def clear(self):
46 |         self.outputs = []
--------------------------------------------------------------------------------
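
A note for readers navigating this dump: `train.py` and `test.py` never touch the backbone internals directly. They harvest intermediate features through the `SaveOutput` forward hook above, and `script/extract_feature.py` simply concatenates the recorded outputs. The standalone sketch below, which is not part of the repository, demonstrates that pattern for the ResNet-50 branch; `pretrained=False` is used only to keep the demo download-free.

```python
import timm
import torch
from timm.models.resnet import Bottleneck

class SaveOutput:
    def __init__(self):
        self.outputs = []
    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)
    def clear(self):
        self.outputs = []

save_output = SaveOutput()
resnet50 = timm.create_model('resnet50', pretrained=False)
for layer in resnet50.modules():
    if isinstance(layer, Bottleneck):           # hook every bottleneck block
        layer.register_forward_hook(save_output)

x = torch.randn(1, 3, 224, 224)
_ = resnet50(x)                                 # the forward pass fills save_output.outputs
# outputs[0..2] are the three stage-1 bottlenecks, each [B, 256, 56, 56];
# concatenating them is what get_resnet_feature() does
feat = torch.cat(save_output.outputs[:3], dim=1)
save_output.clear()
print(feat.shape)  # torch.Size([1, 768, 56, 56])
```

Clearing the buffer after every forward pass matters: the hooks fire on all sixteen bottlenecks of ResNet-50 (and on every ViT block in the other branch), so stale entries would silently shift the indices on the next call.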