├── Figures
│   └── architecture.png
├── LICENSE
├── README.md
├── data
│   └── pipal.py
├── download.sh
├── model
│   └── deform_regressor.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── requirements.txt
├── script
│   └── extract_feature.py
├── test.py
├── test.sh
├── train.py
├── train.sh
└── utils
    ├── process_image.py
    └── util.py
/Figures/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IIGROUP/AHIQ/9cc6a4eed821bb8f16a1aae00d420b24625bc07a/Figures/architecture.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ShanShan Lao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention Helps CNN See Better: Hybrid Image Quality Assessment Network
2 | [CVPRW 2022] Code for Hybrid Image Quality Assessment Network
3 |
4 | [[paper]](https://arxiv.org/abs/2204.10485) [[code]](https://github.com/IIGROUP/AHIQ)
5 |
6 | *This is the official repository for the NTIRE 2022 Perceptual Image Quality Assessment Challenge, Track 1: Full-Reference.
7 | **We won first place in the competition, and the code has now been released.***
8 |
9 | > **Abstract:** *Image quality assessment (IQA) algorithms aim to quantify the human perception of image quality. Unfortunately, there is a performance drop when assessing distorted images generated by generative adversarial networks (GANs) with seemingly realistic textures. In this work, we conjecture that this maladaptation lies in the backbone of IQA models, where patch-level prediction methods use independent image patches as input to calculate their scores separately, but lack spatial relationship modeling among image patches. Therefore, we propose an Attention-based Hybrid Image Quality Assessment Network (AHIQ) to deal with this challenge and obtain better performance on the GAN-based IQA task. First, we adopt a two-branch architecture, including a vision transformer (ViT) branch and a convolutional neural network (CNN) branch, for feature extraction. The hybrid architecture combines the interaction information among image patches captured by the ViT with the local texture details from the CNN. To make the features from the shallow CNN more focused on visually salient regions, a deformable convolution is applied with the help of semantic information from the ViT branch. Finally, we use a patch-wise score prediction module to obtain the final score. Experiments show that our model outperforms state-of-the-art methods on four standard IQA datasets, and AHIQ ranked first on the Full-Reference (FR) track of the NTIRE 2022 Perceptual Image Quality Assessment Challenge.*
10 |
11 | ## Overview
12 | ![AHIQ architecture](Figures/architecture.png)
13 |
14 | ## Getting Started
15 |
16 | ### Prerequisites
17 | - Linux
18 | - NVIDIA GPU + CUDA CuDNN
19 | - Python 3.7
20 |
21 | ### Dependencies
22 |
23 | We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/). All dependencies for defining the environment are provided in `requirements.txt`.
24 |
25 | ### Pretrained Models
26 | You may manually download the pretrained models from
27 | [Google Drive](https://drive.google.com/drive/folders/1-8LKOEDYt-RzmM9IDV_oW73uRBLqeRB6?usp=sharing) and put them into `checkpoints/ahiq_pipal/`, or simply use
28 | ```
29 | sh download.sh
30 | ```
31 |
32 | ### Instruction
33 | Use `sh train.sh` or `sh test.sh` to train or test the model. You can also adjust the options in `options/` as needed.
34 |
35 | ## Acknowledgment
36 | The code borrows heavily from the IQT implementation by [anse3832](https://github.com/anse3832/IQT), and we really appreciate it.
37 |
38 | ## Citation
39 | If you find our work or code helpful for your research, please consider citing:
40 | ```bibtex
41 | @article{lao2022attentions,
42 | title = {Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network},
43 | author = {Lao, Shanshan and Gong, Yuan and Shi, Shuwei and Yang, Sidi and Wu, Tianhe and Wang, Jiahao and Xia, Weihao and Yang, Yujiu},
44 | journal = {arXiv preprint arXiv:2204.10485},
45 | year = {2022}
46 | }
47 | ```
48 |
--------------------------------------------------------------------------------
/data/pipal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import cv2
5 |
6 |
7 | class PIPAL(torch.utils.data.Dataset):
8 | def __init__(self, ref_path, dis_path, txt_file_name, transform, resize=False, size=None, flip=False):
9 | super(PIPAL, self).__init__()
10 | self.ref_path = ref_path
11 | self.dis_path = dis_path
12 | self.txt_file_name = txt_file_name
13 | self.transform = transform
14 | self.flip = flip
15 | self.resize = resize
16 | self.size = size
17 | ref_files_data, dis_files_data, score_data = [], [], []
18 | with open(self.txt_file_name, 'r') as listFile:
19 | for line in listFile:
20 |                 dis, score = line.strip().split(',')
21 | #dis = dis[:-1]
22 | ref = dis[:5] + '.bmp'
23 | score = float(score)
24 | ref_files_data.append(ref)
25 | dis_files_data.append(dis)
26 | score_data.append(score)
27 |
28 | # reshape score_list (1xn -> nx1)
29 | score_data = np.array(score_data)
30 | score_data = self.normalization(score_data)
31 | score_data = score_data.astype('float').reshape(-1, 1)
32 |
33 | self.data_dict = {'r_img_list': ref_files_data, 'd_img_list': dis_files_data, 'score_list': score_data}
34 |
35 |     def normalization(self, data):
36 |         score_range = np.max(data) - np.min(data)
37 |         return (data - np.min(data)) / score_range
38 |
39 | def __len__(self):
40 | return len(self.data_dict['r_img_list'])
41 |
42 | def __getitem__(self, idx):
43 | # r_img: H x W x C -> C x H x W
44 | r_img_name = self.data_dict['r_img_list'][idx]
45 | r_img = cv2.imread(os.path.join(self.ref_path, r_img_name), cv2.IMREAD_COLOR)
46 | r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
47 | if self.flip:
48 | r_img = np.fliplr(r_img).copy()
49 | if self.resize:
50 | r_img = cv2.resize(r_img, self.size)
51 | r_img = np.array(r_img).astype('float32') / 255
52 | r_img = (r_img - 0.5) / 0.5
53 | r_img = np.transpose(r_img, (2, 0, 1))
54 |
55 | d_img_name = self.data_dict['d_img_list'][idx]
56 | d_img = cv2.imread(os.path.join(self.dis_path, d_img_name), cv2.IMREAD_COLOR)
57 | d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
58 | if self.flip:
59 | d_img = np.fliplr(d_img).copy()
60 | if self.resize:
61 | d_img = cv2.resize(d_img, self.size)
62 | d_img = np.array(d_img).astype('float32') / 255
63 | d_img = (d_img - 0.5) / 0.5
64 | d_img = np.transpose(d_img, (2, 0, 1))
65 |
66 | score = self.data_dict['score_list'][idx]
67 | sample = {
68 | 'r_img_org': r_img,
69 | 'd_img_org': d_img,
70 | 'score': score, 'd_img_name':d_img_name
71 | }
72 | if self.transform:
73 | sample = self.transform(sample)
74 | return sample
--------------------------------------------------------------------------------
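
A minimal usage sketch for the `PIPAL` dataset above; the paths and list file here are hypothetical placeholders. Each line of the list file is expected to read `<distorted_name>,<score>`; the reference filename is derived from the first five characters of the distorted name, scores are min-max normalized over the whole list, and images are returned as `C x H x W` float arrays scaled to `[-1, 1]`.

```python
from torch.utils.data import DataLoader

from data.pipal import PIPAL
from utils.process_image import ToTensor

# hypothetical local paths to the PIPAL reference/distorted folders
dataset = PIPAL(
    ref_path='PIPAL/Train_Ref/',
    dis_path='PIPAL/Train_Dis/',
    txt_file_name='PIPAL.txt',   # each line: "<distorted_name>,<score>"
    transform=ToTensor(),
)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
sample = next(iter(loader))
print(sample['r_img_org'].shape, sample['score'].shape)  # [4, 3, H, W] and [4, 1]
```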
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # pip install gdown
3 | save_path='checkpoints/ahiq_pipal/'
4 | mkdir -p "$save_path"
5 | # download the pretrained models
6 | gdown https://drive.google.com/uc?id=1Nk-IpjnDNXbWacoh3T69wkSYhYYWip2W -O "$save_path"
7 | gdown https://drive.google.com/uc?id=1Jr2nLnhMA0f0uPEjMG7sH-T4WfIEasXn -O "$save_path"
--------------------------------------------------------------------------------
/model/deform_regressor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.ops.deform_conv import DeformConv2d
5 |
6 | class deform_fusion(nn.Module):
7 | def __init__(self, opt, in_channels=768*5, cnn_channels=256*3, out_channels=256*3):
8 | super().__init__()
9 | #in_channels, out_channels, kernel_size, stride, padding
10 | self.d_hidn = 512
11 | if opt.patch_size == 8:
12 | stride = 1
13 | else:
14 | stride = 2
15 | self.conv_offset = nn.Conv2d(in_channels, 2*3*3, 3, 1, 1)
16 | self.deform = DeformConv2d(cnn_channels, out_channels, 3, 1, 1)
17 | self.conv1 = nn.Sequential(
18 | nn.Conv2d(in_channels=out_channels, out_channels=self.d_hidn, kernel_size=3,padding=1,stride=2),
19 | nn.ReLU(),
20 | nn.Conv2d(in_channels=self.d_hidn, out_channels=out_channels, kernel_size=3, padding=1,stride=stride)
21 | )
22 |
23 | def forward(self, cnn_feat, vit_feat):
24 |         vit_feat = F.interpolate(vit_feat, size=cnn_feat.shape[-2:], mode="nearest")  # match the CNN feature resolution
25 |         offset = self.conv_offset(vit_feat)  # semantic ViT features predict the deformable-conv offsets (2*3*3 channels)
26 | deform_feat = self.deform(cnn_feat, offset)
27 | deform_feat = self.conv1(deform_feat)
28 |
29 | return deform_feat
30 |
31 | class Pixel_Prediction(nn.Module):
32 | def __init__(self, inchannels=768*5+256*3, outchannels=256, d_hidn=1024):
33 | super().__init__()
34 | self.d_hidn = d_hidn
35 | self.down_channel = nn.Conv2d(inchannels, outchannels, kernel_size=1)
36 | self.feat_smoothing = nn.Sequential(
37 | nn.Conv2d(in_channels=256*3, out_channels=self.d_hidn, kernel_size=3,padding=1),
38 | nn.ReLU(),
39 | nn.Conv2d(in_channels=self.d_hidn, out_channels=512, kernel_size=3, padding=1)
40 | )
41 |
42 | self.conv1 = nn.Sequential(
43 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3,padding=1),
44 | nn.ReLU()
45 | )
46 | self.conv_attent = nn.Sequential(
47 | nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1),
48 | nn.Sigmoid()
49 | )
50 | self.conv = nn.Sequential(
51 | nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1),
52 | )
53 |
54 | def forward(self,f_dis, f_ref, cnn_dis, cnn_ref):
55 | f_dis = torch.cat((f_dis,cnn_dis),1)
56 | f_ref = torch.cat((f_ref,cnn_ref),1)
57 | f_dis = self.down_channel(f_dis)
58 | f_ref = self.down_channel(f_ref)
59 |
60 | f_cat = torch.cat((f_dis - f_ref, f_dis, f_ref), 1)
61 |
62 | feat_fused = self.feat_smoothing(f_cat)
63 | feat = self.conv1(feat_fused)
64 | f = self.conv(feat)
65 | w = self.conv_attent(feat)
66 |         pred = (f*w).sum(dim=2).sum(dim=2)/w.sum(dim=2).sum(dim=2)  # attention-weighted spatial pooling of patch-wise scores
67 |
68 | return pred
--------------------------------------------------------------------------------
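
A shape-checking sketch of the two modules above, assuming `patch_size=8` so that 224x224 inputs yield 28x28 ViT feature maps and 56x56 ResNet stage-1 feature maps; the `SimpleNamespace` object is a hypothetical stand-in for the parsed options.

```python
from types import SimpleNamespace

import torch

from model.deform_regressor import deform_fusion, Pixel_Prediction

opt = SimpleNamespace(patch_size=8)          # hypothetical minimal options object
fusion = deform_fusion(opt)
regressor = Pixel_Prediction()

vit_feat = torch.randn(2, 768 * 5, 28, 28)   # 5 ViT blocks concatenated along channels
cnn_feat = torch.randn(2, 256 * 3, 56, 56)   # 3 ResNet Bottleneck outputs concatenated

cnn_fused = fusion(cnn_feat, vit_feat)       # deformable fusion -> [2, 768, 28, 28]
pred = regressor(vit_feat, vit_feat, cnn_fused, cnn_fused)
print(cnn_fused.shape, pred.shape)           # per-image score has shape [2, 1]
```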
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IIGROUP/AHIQ/9cc6a4eed821bb8f16a1aae00d420b24625bc07a/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from utils import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self._parser = argparse.ArgumentParser()
9 | self._initialized = False
10 |
11 | def initialize(self):
12 | # dataset path
13 |         self._parser.add_argument('--train_ref_path', type=str, default='Train_Ref/', help='path to training reference images')
14 |         self._parser.add_argument('--train_dis_path', type=str, default='Train_Dis/', help='path to training distorted images')
15 |         self._parser.add_argument('--val_ref_path', type=str, default='NTIRE2021/Ref/', help='path to validation reference images')
16 |         self._parser.add_argument('--val_dis_path', type=str, default='NTIRE2021/Dis/', help='path to validation distorted images')
17 | self._parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
18 | self._parser.add_argument('--train_list', type=str, default='PIPAL.txt', help='training data')
19 |         self._parser.add_argument('--val_list', type=str, default='PIPAL_NTIRE_Valid_MOS.txt', help='validation data')
20 | # experiment
21 | self._parser.add_argument('--name', type=str, default='ahiq_pipal',
22 | help='name of the experiment. It decides where to store samples and models')
23 | # device
24 | self._parser.add_argument('--num_workers', type=int, default=8, help='total workers')
25 | # model
26 | self._parser.add_argument('--patch_size', type=int, default=8, help='patch size of Vision Transformer')
27 | self._parser.add_argument('--load_epoch', type=int, default=-1, help='which epoch to load? set to -1 to use latest cached model')
28 | self._parser.add_argument('--ckpt', type=str, default='./checkpoints', help='models to be loaded')
29 | self._parser.add_argument('--seed', type=int, default=1919, help='random seed')
30 | #data process
31 |         self._parser.add_argument('--crop_size', type=int, default=224, help='crop size')
32 |         self._parser.add_argument('--num_crop', type=int, default=1, help='number of random crops per image')
33 |         self._parser.add_argument('--num_avg_val', type=int, default=5, help='number of crops averaged during validation')
34 |
35 | self._parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids')
36 | self._initialized = True
37 |
38 | def parse(self):
39 | if not self._initialized:
40 | self.initialize()
41 | #self._opt = self._parser.parse_args()
42 | self._opt = self._parser.parse_known_args()[0]
43 |
44 |         # set train or test mode
45 | self._opt.is_train = self.is_train
46 |
47 | # set and check load_epoch
48 | self._set_and_check_load_epoch()
49 |
50 | # get and set gpus
51 | self._get_set_gpus()
52 |
53 | args = vars(self._opt)
54 |
55 | # print in terminal args
56 | self._print(args)
57 |
58 | # save args to file
59 | self._save(args)
60 |
61 | return self._opt
62 |
63 | def _set_and_check_load_epoch(self):
64 | models_dir = os.path.join(self._opt.checkpoints_dir, self._opt.name)
65 | if os.path.exists(models_dir):
66 | if self._opt.load_epoch == -1:
67 | load_epoch = 0
68 | for file in os.listdir(models_dir):
69 | if file.startswith("epoch"):
70 | load_epoch = max(load_epoch, int(file.split('.')[0].split('_')[1]))
71 | self._opt.load_epoch = load_epoch
72 | else:
73 | found = False
74 | for file in os.listdir(models_dir):
75 | if file.startswith("epoch"):
76 |                     found = int(file.split('.')[0].split('_')[1]) == self._opt.load_epoch
77 | if found: break
78 | assert found, 'Model for epoch %i not found' % self._opt.load_epoch
79 | else:
80 | assert self._opt.load_epoch < 1, 'Model for epoch %i not found' % self._opt.load_epoch
81 | self._opt.load_epoch = 0
82 |
83 | def _get_set_gpus(self):
84 | # get gpu ids
85 | str_ids = self._opt.gpu_ids.split(',')
86 | self._opt.gpu_ids = []
87 |         for str_id in str_ids:
88 |             gpu_id = int(str_id)
89 |             if gpu_id >= 0:
90 |                 self._opt.gpu_ids.append(gpu_id)
91 |
92 | # set gpu ids
93 | if len(self._opt.gpu_ids) > 0:
94 | torch.cuda.set_device(self._opt.gpu_ids[0])
95 |
96 | def _print(self, args):
97 | print('------------ Options -------------')
98 | for k, v in sorted(args.items()):
99 | print('%s: %s' % (str(k), str(v)))
100 | print('-------------- End ----------------')
101 |
102 | def _save(self, args):
103 | expr_dir = os.path.join(self._opt.checkpoints_dir, self._opt.name)
104 | print(expr_dir)
105 | util.mkdirs(expr_dir)
106 | file_name = os.path.join(expr_dir, 'opt_%s.txt' % ('train' if self.is_train else 'test'))
107 | with open(file_name, 'wt') as opt_file:
108 | opt_file.write('------------ Options -------------\n')
109 | for k, v in sorted(args.items()):
110 | opt_file.write('%s: %s\n' % (str(k), str(v)))
111 | opt_file.write('-------------- End ----------------\n')
--------------------------------------------------------------------------------
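
A minimal sketch of how these option classes are consumed (cf. `train.py` and `test.py`). Note that `parse()` uses `parse_known_args`, so unknown flags are silently ignored, and it calls `torch.cuda.set_device`, so this assumes a CUDA-capable machine.

```python
from options.train_options import TrainOptions

opt = TrainOptions().parse()  # reads sys.argv, resolves load_epoch from existing
                              # checkpoints, then prints and saves the options
print(opt.is_train, opt.gpu_ids, opt.load_epoch)
```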
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TestOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 |         self._parser.add_argument('--test_ref_path', type=str, default='Test_Ref/', help='path to testing reference images')
7 |         self._parser.add_argument('--test_dis_path', type=str, default='Test_Dis/', help='path to testing distorted images')
8 |         self._parser.add_argument('--test_list', type=str, default='test.txt', help='testing data')
9 |
10 | self._parser.add_argument('--batch_size', type=int, default=10, help='input batch size')
11 | self._parser.add_argument('--test_file_name', type=str, default='results.txt', help='txt path to save results')
12 |         self._parser.add_argument('--n_ensemble', type=int, default=20, help='number of crops for test-time ensembling: >9 for repeated random crops, 5 or 9 for five-/nine-point crops, 1 for the full image')
13 |         self._parser.add_argument('--flip', action='store_true', help='flip images when testing')
14 |         self._parser.add_argument('--resize', action='store_true', help='resize images when testing')
15 | self._parser.add_argument('--size', type=int, default=224, help='the resize shape')
16 | self.is_train = False
17 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self._parser.add_argument('--n_epoch', type=int, default=200, help='total epoch for training')
8 | self._parser.add_argument('--save_interval', type=int, default=5, help='interval for saving models')
9 | self._parser.add_argument('--learning_rate', type=float, default=1e-4, help='initial learning rate')
10 | self._parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay')
11 | self._parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
12 | self._parser.add_argument('--val_freq', type=int, default=1, help='validation frequency')
13 |         self._parser.add_argument('--T_max', type=int, default=50, help="cosine learning rate period (in iterations; the scheduler is stepped every batch)")
14 |         self._parser.add_argument('--eta_min', type=float, default=0, help="minimum learning rate")
15 |
16 | self.is_train = True
--------------------------------------------------------------------------------
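
For reference, a small sketch of the optimizer/scheduler pairing these flags configure (cf. `train.py`). Because `train.py` steps the scheduler once per batch, `--T_max` counts iterations rather than epochs; the dummy parameter below is purely illustrative.

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stands in for regressor/deform_net params
optimizer = torch.optim.Adam(params, lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

for step in range(3):
    optimizer.step()
    scheduler.step()  # called every batch in train.py, so T_max is in iterations
    print(step, scheduler.get_last_lr())
```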
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | opencv-python
3 | torchvision==0.11.1
4 | timm==0.5.4
5 | torch==1.10.0
6 | scipy==1.7.3
7 | numpy==1.21.2
--------------------------------------------------------------------------------
/script/extract_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def get_resnet_feature(save_output):  # concat the first three hooked Bottleneck outputs along channels
4 | feat = torch.cat(
5 | (
6 | save_output.outputs[0],
7 | save_output.outputs[1],
8 | save_output.outputs[2]
9 | ),
10 | dim=1
11 | )
12 | return feat
13 |
14 | def get_vit_feature(save_output):  # concat the first five ViT block outputs; [:, 1:, :] drops the CLS token
15 | feat = torch.cat(
16 | (
17 | save_output.outputs[0][:,1:,:],
18 | save_output.outputs[1][:,1:,:],
19 | save_output.outputs[2][:,1:,:],
20 | save_output.outputs[3][:,1:,:],
21 | save_output.outputs[4][:,1:,:],
22 | ),
23 | dim=2
24 | )
25 | return feat
26 |
27 | def get_inception_feature(save_output):  # concat selected hooked Inception block outputs along channels
28 | feat = torch.cat(
29 | (
30 | save_output.outputs[0],
31 | save_output.outputs[2],
32 | save_output.outputs[4],
33 | save_output.outputs[6],
34 | save_output.outputs[8],
35 | save_output.outputs[10]
36 | ),
37 | dim=1
38 | )
39 | return feat
40 |
41 | def get_resnet152_feature(save_output):  # concat selected hooked ResNet-152 Bottleneck outputs along channels
42 | feat = torch.cat(
43 | (
44 | save_output.outputs[3],
45 | save_output.outputs[4],
46 | save_output.outputs[6],
47 | save_output.outputs[7],
48 | save_output.outputs[8],
49 | save_output.outputs[10]
50 | ),
51 | dim=1
52 | )
53 | return feat
--------------------------------------------------------------------------------
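
A minimal sketch of how these helpers pair with forward hooks, mirroring the setup in `train.py`/`test.py` but on CPU with random weights for illustration.

```python
import timm
import torch
from timm.models.vision_transformer import Block

from script.extract_feature import get_vit_feature
from utils.util import SaveOutput

vit = timm.create_model('vit_base_patch8_224', pretrained=False).eval()
save_output = SaveOutput()
for layer in vit.modules():
    if isinstance(layer, Block):                  # hook every transformer block
        layer.register_forward_hook(save_output)

with torch.no_grad():
    _ = vit(torch.randn(1, 3, 224, 224))
feat = get_vit_feature(save_output)               # first 5 blocks, CLS token dropped
save_output.outputs.clear()
print(feat.shape)                                 # [1, 784, 3840]: 28*28 tokens, 768*5 channels
```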
/test.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 | from scipy.stats import spearmanr, pearsonr
7 | import timm
8 | from timm.models.vision_transformer import Block
9 | from timm.models.resnet import BasicBlock,Bottleneck
10 | import time
11 | from torch.utils.data import DataLoader
12 |
13 | from utils.util import setup_seed,set_logging,SaveOutput
14 | from script.extract_feature import get_resnet_feature, get_vit_feature
15 | from options.test_options import TestOptions
16 | from model.deform_regressor import deform_fusion, Pixel_Prediction
17 | from data.pipal import PIPAL
18 | from utils.process_image import ToTensor, RandHorizontalFlip, RandCrop, crop_image, Normalize, five_point_crop
19 | from torchvision import transforms
20 |
21 | class Test:
22 | def __init__(self, config):
23 | self.opt = config
24 | self.create_model()
25 | self.init_saveoutput()
26 | self.init_data()
27 | self.load_model()
28 | self.test()
29 |
30 | def create_model(self):
31 |         self.resnet50 = timm.create_model('resnet50', pretrained=True).cuda().eval()  # eval: freeze BN stats for inference
32 |         if self.opt.patch_size == 8:
33 |             self.vit = timm.create_model('vit_base_patch8_224', pretrained=True).cuda().eval()
34 |         else:
35 |             self.vit = timm.create_model('vit_base_patch16_224', pretrained=True).cuda().eval()
36 |         self.deform_net = deform_fusion(self.opt).cuda().eval()
37 |         self.regressor = Pixel_Prediction().cuda().eval()
38 |
39 | def init_saveoutput(self):
40 | self.save_output = SaveOutput()
41 | hook_handles = []
42 | for layer in self.resnet50.modules():
43 | if isinstance(layer, Bottleneck):
44 | handle = layer.register_forward_hook(self.save_output)
45 | hook_handles.append(handle)
46 | for layer in self.vit.modules():
47 | if isinstance(layer, Block):
48 | handle = layer.register_forward_hook(self.save_output)
49 | hook_handles.append(handle)
50 |
51 | def init_data(self):
52 | test_dataset = PIPAL(
53 | ref_path=self.opt.test_ref_path,
54 | dis_path=self.opt.test_dis_path,
55 | txt_file_name=self.opt.test_list,
56 | resize=self.opt.resize,
57 | size=(self.opt.size,self.opt.size),
58 | flip=self.opt.flip,
59 | transform=ToTensor(),
60 | )
61 | logging.info('number of test scenes: {}'.format(len(test_dataset)))
62 |
63 | self.test_loader = DataLoader(
64 | dataset=test_dataset,
65 | batch_size=self.opt.batch_size,
66 | num_workers=self.opt.num_workers,
67 |             drop_last=False,  # keep the tail batch so every test image gets a score
68 | shuffle=False
69 | )
70 |
71 | def load_model(self):
72 | models_dir = self.opt.checkpoints_dir
73 | if os.path.exists(models_dir):
74 | if self.opt.load_epoch == -1:
75 | load_epoch = 0
76 | for file in os.listdir(models_dir):
77 | if file.startswith("epoch"):
78 | load_epoch = max(load_epoch, int(file.split('.')[0].split('_')[1]))
79 | self.opt.load_epoch = load_epoch
80 |                 checkpoint = torch.load(os.path.join(models_dir, "epoch_" + str(self.opt.load_epoch) + ".pth"))
81 |                 # only the model weights are needed at test time; the optimizer,
82 |                 # scheduler, epoch and loss entries stored by train.py are for
83 |                 # resuming training and are ignored here
84 |                 self.regressor.load_state_dict(checkpoint['regressor_model_state_dict'])
85 |                 self.deform_net.load_state_dict(checkpoint['deform_net_model_state_dict'])
86 |
87 | else:
88 | found = False
89 | for file in os.listdir(models_dir):
90 | if file.startswith("epoch"):
91 | found = int(file.split('.')[0].split('_')[1]) == self.opt.load_epoch
92 | if found: break
93 | assert found, 'Model for epoch %i not found' % self.opt.load_epoch
94 | else:
95 | assert self.opt.load_epoch < 1, 'Model for epoch %i not found' % self.opt.load_epoch
96 | self.opt.load_epoch = 0
97 |
98 | def test(self):
99 | f = open(os.path.join(self.opt.checkpoints_dir,self.opt.test_file_name), 'w')
100 | with torch.no_grad():
101 | for data in tqdm(self.test_loader):
102 | d_img_org = data['d_img_org'].cuda()
103 | r_img_org = data['r_img_org'].cuda()
104 | d_img_name = data['d_img_name']
105 | pred = 0
106 | for i in range(self.opt.n_ensemble):
107 | b, c, h, w = r_img_org.size()
108 | if self.opt.n_ensemble > 9:
109 |                     new_h = self.opt.crop_size
110 |                     new_w = self.opt.crop_size
111 | top = np.random.randint(0, h - new_h)
112 | left = np.random.randint(0, w - new_w)
113 | r_img = r_img_org[:,:, top: top+new_h, left: left+new_w]
114 | d_img = d_img_org[:,:, top: top+new_h, left: left+new_w]
115 | elif self.opt.n_ensemble == 1:
116 | r_img = r_img_org
117 | d_img = d_img_org
118 | else:
119 | d_img, r_img = five_point_crop(i, d_img=d_img_org, r_img=r_img_org, config=self.opt)
120 | d_img = d_img.cuda()
121 | r_img = r_img.cuda()
122 | _x = self.vit(d_img)
123 | vit_dis = get_vit_feature(self.save_output)
124 | self.save_output.outputs.clear()
125 |
126 | _y = self.vit(r_img)
127 | vit_ref = get_vit_feature(self.save_output)
128 | self.save_output.outputs.clear()
129 | B, N, C = vit_ref.shape
130 | if self.opt.patch_size == 8:
131 | H,W = 28,28
132 | else:
133 | H,W = 14,14
134 | assert H*W==N
135 | vit_ref = vit_ref.transpose(1, 2).view(B, C, H, W)
136 | vit_dis = vit_dis.transpose(1, 2).view(B, C, H, W)
137 |
138 | _ = self.resnet50(d_img)
139 | cnn_dis = get_resnet_feature(self.save_output)
140 | self.save_output.outputs.clear()
141 | cnn_dis = self.deform_net(cnn_dis,vit_ref)
142 |
143 | _ = self.resnet50(r_img)
144 | cnn_ref = get_resnet_feature(self.save_output)
145 | self.save_output.outputs.clear()
146 | cnn_ref = self.deform_net(cnn_ref,vit_ref)
147 | pred += self.regressor(vit_dis, vit_ref, cnn_dis, cnn_ref)
148 |
149 | pred /= self.opt.n_ensemble
150 | for i in range(len(d_img_name)):
151 |                 line = "%s,%f\n" % (d_img_name[i], float(pred[i].item()))
152 | f.write(line)
153 |
154 | f.close()
155 |
156 |
157 |
158 | if __name__ == '__main__':
159 | config = TestOptions().parse()
160 | config.checkpoints_dir = os.path.join(config.checkpoints_dir, config.name)
161 | setup_seed(config.seed)
162 | set_logging(config)
163 | Test(config)
164 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py \
2 |     --test_ref_path /path/to/test \
3 |     --test_dis_path /path/to/test \
4 |     --test_list /txt/for/test
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 | from scipy.stats import spearmanr, pearsonr
7 | import timm
8 | from timm.models.vision_transformer import Block
9 | from timm.models.resnet import BasicBlock,Bottleneck
10 | import time
11 |
12 | from torch.utils.data import DataLoader
13 |
14 | from utils.util import setup_seed,set_logging,SaveOutput
15 | from script.extract_feature import get_resnet_feature, get_vit_feature
16 | from options.train_options import TrainOptions
17 | from model.deform_regressor import deform_fusion, Pixel_Prediction
18 | from data.pipal import PIPAL
19 | from utils.process_image import ToTensor, RandHorizontalFlip, RandCrop, crop_image, Normalize, five_point_crop
20 | from torchvision import transforms
21 |
22 | class Train:
23 | def __init__(self, config):
24 | self.opt = config
25 | self.create_model()
26 | self.init_saveoutput()
27 | self.init_data()
28 | self.criterion = torch.nn.MSELoss()
29 | self.optimizer = torch.optim.Adam([
30 | {'params': self.regressor.parameters(), 'lr': self.opt.learning_rate,'weight_decay':self.opt.weight_decay},
31 | {'params': self.deform_net.parameters(),'lr': self.opt.learning_rate,'weight_decay':self.opt.weight_decay}
32 | ])
33 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.opt.T_max, eta_min=self.opt.eta_min)
34 | self.load_model()
35 | self.train()
36 |
37 | def create_model(self):
38 | self.resnet50 = timm.create_model('resnet50',pretrained=True).cuda()
39 | if self.opt.patch_size == 8:
40 | self.vit = timm.create_model('vit_base_patch8_224',pretrained=True).cuda()
41 | else:
42 | self.vit = timm.create_model('vit_base_patch16_224',pretrained=True).cuda()
43 | self.deform_net = deform_fusion(self.opt).cuda()
44 | self.regressor = Pixel_Prediction().cuda()
45 |
46 | def init_saveoutput(self):
47 | self.save_output = SaveOutput()
48 | hook_handles = []
49 | for layer in self.resnet50.modules():
50 | if isinstance(layer, Bottleneck):
51 | handle = layer.register_forward_hook(self.save_output)
52 | hook_handles.append(handle)
53 | for layer in self.vit.modules():
54 | if isinstance(layer, Block):
55 | handle = layer.register_forward_hook(self.save_output)
56 | hook_handles.append(handle)
57 |
58 | def init_data(self):
59 | train_dataset = PIPAL(
60 | ref_path=self.opt.train_ref_path,
61 | dis_path=self.opt.train_dis_path,
62 | txt_file_name=self.opt.train_list,
63 | transform=transforms.Compose(
64 | [
65 | RandCrop(self.opt.crop_size, self.opt.num_crop),
66 | #Normalize(0.5, 0.5),
67 | RandHorizontalFlip(),
68 | ToTensor(),
69 | ]
70 | ),
71 | )
72 | val_dataset = PIPAL(
73 | ref_path=self.opt.val_ref_path,
74 | dis_path=self.opt.val_dis_path,
75 | txt_file_name=self.opt.val_list,
76 | transform=ToTensor(),
77 | )
78 | logging.info('number of train scenes: {}'.format(len(train_dataset)))
79 | logging.info('number of val scenes: {}'.format(len(val_dataset)))
80 |
81 | self.train_loader = DataLoader(
82 | dataset=train_dataset,
83 | batch_size=self.opt.batch_size,
84 | num_workers=self.opt.num_workers,
85 | drop_last=True,
86 | shuffle=True
87 | )
88 | self.val_loader = DataLoader(
89 | dataset=val_dataset,
90 | batch_size=self.opt.batch_size,
91 | num_workers=self.opt.num_workers,
92 | drop_last=True,
93 | shuffle=False
94 | )
95 |
96 | def load_model(self):
97 | models_dir = self.opt.checkpoints_dir
98 | if os.path.exists(models_dir):
99 | if self.opt.load_epoch == -1:
100 | load_epoch = 0
101 | for file in os.listdir(models_dir):
102 | if file.startswith("epoch_"):
103 | load_epoch = max(load_epoch, int(file.split('.')[0].split('_')[1]))
104 | self.opt.load_epoch = load_epoch
105 | checkpoint = torch.load(os.path.join(models_dir,"epoch_"+str(self.opt.load_epoch)+".pth"))
106 | self.regressor.load_state_dict(checkpoint['regressor_model_state_dict'])
107 | self.deform_net.load_state_dict(checkpoint['deform_net_model_state_dict'])
108 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
109 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
110 | self.start_epoch = checkpoint['epoch']+1
111 | loss = checkpoint['loss']
112 | else:
113 | found = False
114 | for file in os.listdir(models_dir):
115 | if file.startswith("epoch_"):
116 | found = int(file.split('.')[0].split('_')[1]) == self.opt.load_epoch
117 | if found: break
118 | assert found, 'Model for epoch %i not found' % self.opt.load_epoch
119 | else:
120 | assert self.opt.load_epoch < 1, 'Model for epoch %i not found' % self.opt.load_epoch
121 | self.opt.load_epoch = 0
122 |
123 | def train_epoch(self, epoch):
124 | losses = []
125 | self.regressor.train()
126 | self.deform_net.train()
127 | self.vit.eval()
128 | self.resnet50.eval()
129 | # save data for one epoch
130 | pred_epoch = []
131 | labels_epoch = []
132 |
133 | for data in tqdm(self.train_loader):
134 | d_img_org = data['d_img_org'].cuda()
135 | r_img_org = data['r_img_org'].cuda()
136 | labels = data['score']
137 | labels = torch.squeeze(labels.type(torch.FloatTensor)).cuda()
138 |
139 | _x = self.vit(d_img_org)
140 | vit_dis = get_vit_feature(self.save_output)
141 | self.save_output.outputs.clear()
142 |
143 | _y = self.vit(r_img_org)
144 | vit_ref = get_vit_feature(self.save_output)
145 | self.save_output.outputs.clear()
146 | B, N, C = vit_ref.shape
147 | if self.opt.patch_size == 8:
148 | H,W = 28,28
149 | else:
150 | H,W = 14,14
151 | assert H*W==N
152 | vit_ref = vit_ref.transpose(1, 2).view(B, C, H, W)
153 | vit_dis = vit_dis.transpose(1, 2).view(B, C, H, W)
154 |
155 | _ = self.resnet50(d_img_org)
156 |             cnn_dis = get_resnet_feature(self.save_output)  # outputs 0, 1 and 2 are all [B, 256, 56, 56]
157 | self.save_output.outputs.clear()
158 | cnn_dis = self.deform_net(cnn_dis,vit_ref)
159 |
160 | _ = self.resnet50(r_img_org)
161 | cnn_ref = get_resnet_feature(self.save_output)
162 | self.save_output.outputs.clear()
163 | cnn_ref = self.deform_net(cnn_ref,vit_ref)
164 |
165 | pred = self.regressor(vit_dis, vit_ref, cnn_dis, cnn_ref)
166 |
167 | self.optimizer.zero_grad()
168 | loss = self.criterion(torch.squeeze(pred), labels)
169 | losses.append(loss.item())
170 |
171 | loss.backward()
172 | self.optimizer.step()
173 | self.scheduler.step()
174 |
175 | # save results in one epoch
176 | pred_batch_numpy = pred.data.cpu().numpy()
177 | labels_batch_numpy = labels.data.cpu().numpy()
178 | pred_epoch = np.append(pred_epoch, pred_batch_numpy)
179 | labels_epoch = np.append(labels_epoch, labels_batch_numpy)
180 |
181 | # compute correlation coefficient
182 | rho_s, _ = spearmanr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
183 | rho_p, _ = pearsonr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
184 |
185 | ret_loss = np.mean(losses)
186 | print('train epoch:{} / loss:{:.4} / SRCC:{:.4} / PLCC:{:.4}'.format(epoch + 1, ret_loss, rho_s, rho_p))
187 | logging.info('train epoch:{} / loss:{:.4} / SRCC:{:.4} / PLCC:{:.4}'.format(epoch + 1, ret_loss, rho_s, rho_p))
188 |
189 | return ret_loss, rho_s, rho_p
190 |
191 | def train(self):
192 | best_srocc = 0
193 | best_plcc = 0
194 | for epoch in range(self.opt.load_epoch, self.opt.n_epoch):
195 | start_time = time.time()
196 | logging.info('Running training epoch {}'.format(epoch + 1))
197 |             loss, rho_s, rho_p = self.train_epoch(epoch)  # fallback metrics when no eval runs this epoch
198 | if (epoch + 1) % self.opt.val_freq == 0:
199 | logging.info('Starting eval...')
200 | logging.info('Running testing in epoch {}'.format(epoch + 1))
201 | loss, rho_s, rho_p = self.eval_epoch(epoch)
202 | logging.info('Eval done...')
203 |
204 | if rho_s > best_srocc or rho_p > best_plcc:
205 | best_srocc = rho_s
206 | best_plcc = rho_p
207 | print('Best now')
208 | logging.info('Best now')
209 | self.save_model( epoch, "best.pth", loss, rho_s, rho_p)
210 | if epoch % self.opt.save_interval == 0:
211 | weights_file_name = "epoch_%d.pth" % (epoch+1)
212 | self.save_model( epoch, weights_file_name, loss, rho_s, rho_p)
213 | logging.info('Epoch {} done. Time: {:.2}min'.format(epoch + 1, (time.time() - start_time) / 60))
214 |
215 | def eval_epoch(self, epoch):
216 | with torch.no_grad():
217 | losses = []
218 |             self.regressor.eval()
219 |             self.deform_net.eval()
220 | self.vit.eval()
221 | self.resnet50.eval()
222 | # save data for one epoch
223 | pred_epoch = []
224 | labels_epoch = []
225 |
226 | for data in tqdm(self.val_loader):
227 | pred = 0
228 | for i in range(self.opt.num_avg_val):
229 | d_img_org = data['d_img_org'].cuda()
230 | r_img_org = data['r_img_org'].cuda()
231 | labels = data['score']
232 | labels = torch.squeeze(labels.type(torch.FloatTensor)).cuda()
233 |
234 | d_img_org, r_img_org = five_point_crop(i, d_img=d_img_org, r_img=r_img_org, config=self.opt)
235 |
236 | _x = self.vit(d_img_org)
237 | vit_dis = get_vit_feature(self.save_output)
238 | self.save_output.outputs.clear()
239 |
240 | _y = self.vit(r_img_org)
241 | vit_ref = get_vit_feature(self.save_output)
242 | self.save_output.outputs.clear()
243 | B, N, C = vit_ref.shape
244 | if self.opt.patch_size == 8:
245 | H,W = 28,28
246 | else:
247 | H,W = 14,14
248 | assert H*W==N
249 | vit_ref = vit_ref.transpose(1, 2).view(B, C, H, W)
250 | vit_dis = vit_dis.transpose(1, 2).view(B, C, H, W)
251 |
252 | _ = self.resnet50(d_img_org)
253 |                     cnn_dis = get_resnet_feature(self.save_output)  # outputs 0, 1 and 2 are all [B, 256, 56, 56]
254 | self.save_output.outputs.clear()
255 | cnn_dis = self.deform_net(cnn_dis,vit_ref)
256 |
257 | _ = self.resnet50(r_img_org)
258 | cnn_ref = get_resnet_feature(self.save_output)
259 | self.save_output.outputs.clear()
260 | cnn_ref = self.deform_net(cnn_ref,vit_ref)
261 |
262 | pred += self.regressor(vit_dis, vit_ref, cnn_dis, cnn_ref)
263 |
264 | pred /= self.opt.num_avg_val
265 | # compute loss
266 | loss = self.criterion(torch.squeeze(pred), labels)
267 | loss_val = loss.item()
268 | losses.append(loss_val)
269 |
270 | # save results in one epoch
271 | pred_batch_numpy = pred.data.cpu().numpy()
272 | labels_batch_numpy = labels.data.cpu().numpy()
273 | pred_epoch = np.append(pred_epoch, pred_batch_numpy)
274 | labels_epoch = np.append(labels_epoch, labels_batch_numpy)
275 |
276 | # compute correlation coefficient
277 | rho_s, _ = spearmanr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
278 | rho_p, _ = pearsonr(np.squeeze(pred_epoch), np.squeeze(labels_epoch))
279 | print('Epoch:{} ===== loss:{:.4} ===== SRCC:{:.4} ===== PLCC:{:.4}'.format(epoch + 1, np.mean(losses), rho_s, rho_p))
280 | logging.info('Epoch:{} ===== loss:{:.4} ===== SRCC:{:.4} ===== PLCC:{:.4}'.format(epoch + 1, np.mean(losses), rho_s, rho_p))
281 | return np.mean(losses), rho_s, rho_p
282 |
283 | def save_model(self, epoch, weights_file_name, loss, rho_s, rho_p):
284 | print('-------------saving weights---------')
285 | weights_file = os.path.join(self.opt.checkpoints_dir, weights_file_name)
286 | torch.save({
287 | 'epoch': epoch,
288 | 'regressor_model_state_dict': self.regressor.state_dict(),
289 | 'deform_net_model_state_dict': self.deform_net.state_dict(),
290 | 'optimizer_state_dict': self.optimizer.state_dict(),
291 | 'scheduler_state_dict': self.scheduler.state_dict(),
292 | 'loss': loss
293 | }, weights_file)
294 | logging.info('Saving weights and model of epoch{}, SRCC:{}, PLCC:{}'.format(epoch, rho_s, rho_p))
295 |
296 | if __name__ == '__main__':
297 | config = TrainOptions().parse()
298 | config.checkpoints_dir = os.path.join(config.checkpoints_dir, config.name)
299 | setup_seed(config.seed)
300 | set_logging(config)
301 | # logging.info(config)
302 | Train(config)
303 |
--------------------------------------------------------------------------------
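
For quick reference, the SRCC/PLCC numbers logged above come straight from `scipy.stats`; a toy example with made-up values:

```python
import numpy as np
from scipy.stats import spearmanr, pearsonr

pred = np.array([0.10, 0.40, 0.35, 0.80])  # toy model predictions
mos = np.array([0.00, 0.50, 0.30, 0.90])   # toy ground-truth scores
rho_s, _ = spearmanr(pred, mos)            # SRCC: monotonic (rank) agreement
rho_p, _ = pearsonr(pred, mos)             # PLCC: linear agreement
print(rho_s, rho_p)
```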
/train.sh:
--------------------------------------------------------------------------------
1 | python train.py \
2 |     --train_ref_path /path/to/train \
3 |     --train_dis_path /path/to/train \
4 |     --train_list /txt/for/train \
5 |     --val_ref_path /path/to/val \
6 |     --val_dis_path /path/to/val \
7 |     --val_list /txt/for/val \
8 |     --checkpoints_dir /path/for/save/ckpt
--------------------------------------------------------------------------------
/utils/process_image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def crop_image(top, left, new_h, new_w, img=None):
6 | b, c, h, w = img.shape
7 | tmp_img = img[ : , : , top: top + new_h, left: left + new_w]
8 | return tmp_img
9 |
10 | class RandCrop(object):
11 | def __init__(self, patch_size, num_crop):
12 | self.patch_size = patch_size
13 | self.num_crop = num_crop
14 |
15 | def __call__(self, sample):
16 | # r_img : C x H x W (numpy)
17 | r_img, d_img = sample['r_img_org'], sample['d_img_org']
18 | score = sample['score']
19 | d_img_name = sample['d_img_name']
20 |
21 | c, h, w = d_img.shape
22 | new_h = self.patch_size
23 | new_w = self.patch_size
24 | ret_r_img = np.zeros((c, self.patch_size, self.patch_size))
25 | ret_d_img = np.zeros((c, self.patch_size, self.patch_size))
26 |         for _ in range(self.num_crop):  # with num_crop > 1, the random crops are averaged pixel-wise
27 | top = np.random.randint(0, h - new_h)
28 | left = np.random.randint(0, w - new_w)
29 | tmp_r_img = r_img[:, top: top + new_h, left: left + new_w]
30 | tmp_d_img = d_img[:, top: top + new_h, left: left + new_w]
31 | ret_r_img += tmp_r_img
32 | ret_d_img += tmp_d_img
33 | ret_r_img /= self.num_crop
34 | ret_d_img /= self.num_crop
35 |
36 | sample = {
37 | 'r_img_org': ret_r_img,
38 | 'd_img_org': ret_d_img,
39 | 'score': score, 'd_img_name':d_img_name
40 | }
41 |
42 | return sample
43 |
44 | def five_point_crop(idx, d_img, r_img, config):
45 | new_h = config.crop_size
46 | new_w = config.crop_size
47 | if len(d_img.shape) == 3:
48 | c, h, w = d_img.shape
49 | else:
50 | b, c, h, w = d_img.shape
51 | center_h = h // 2
52 | center_w = w // 2
53 | if idx == 0:
54 | top = 0
55 | left = 0
56 | elif idx == 1:
57 | top = 0
58 | left = w - new_w
59 | elif idx == 2:
60 | top = h - new_h
61 | left = 0
62 | elif idx == 3:
63 | top = h - new_h
64 | left = w - new_w
65 | elif idx == 4:
66 | top = center_h - new_h // 2
67 | left = center_w - new_w // 2
68 | elif idx == 5:
69 | left = 0
70 | top = center_h - new_h // 2
71 | elif idx == 6:
72 | left = w - new_w
73 | top = center_h - new_h // 2
74 | elif idx == 7:
75 | top = 0
76 | left = center_w - new_w // 2
77 | elif idx == 8:
78 | top = h - new_h
79 | left = center_w - new_w // 2
80 | if len(d_img.shape) == 3:
81 | d_img_org = d_img[: , top: top + new_h, left: left + new_w]
82 | r_img_org = r_img[: , top: top + new_h, left: left + new_w]
83 | else:
84 | d_img_org = d_img[ :,: , top: top + new_h, left: left + new_w]
85 | r_img_org = r_img[ :,: , top: top + new_h, left: left + new_w]
86 | return d_img_org, r_img_org
87 |
88 | class RandCrop_fivepoints(object):
89 | def __init__(self, patch_size, num_crop, config):
90 | self.patch_size = patch_size
91 | self.num_crop = num_crop
92 | self.config = config
93 | def __call__(self, sample):
94 | # r_img : C x H x W (numpy)
95 | r_img, d_img = sample['r_img_org'], sample['d_img_org']
96 | score = sample['score']
97 |
98 |         for idx in range(5):
99 |             d_img_crop, r_img_crop = five_point_crop(idx, d_img, r_img, self.config)
100 |             if idx == 0:
101 |                 ret_r_img = r_img_crop.unsqueeze(0)
102 |                 ret_d_img = d_img_crop.unsqueeze(0)
103 |             else:
104 |                 ret_r_img = torch.cat((ret_r_img, r_img_crop.unsqueeze(0)), dim=0)
105 |                 ret_d_img = torch.cat((ret_d_img, d_img_crop.unsqueeze(0)), dim=0)
106 |
107 | sample = {
108 | 'r_img_org': ret_r_img,
109 | 'd_img_org': ret_d_img,
110 | 'score': score
111 | }
112 |
113 | return sample
114 |
115 | class RandCrop_points(object):
116 | def __init__(self, patch_size, num_crop, config):
117 | self.patch_size = patch_size
118 | self.num_crop = num_crop
119 | self.config = config
120 | def __call__(self, sample):
121 | # r_img : C x H x W (numpy)
122 | r_img, d_img = sample['r_img_org'], sample['d_img_org']
123 | score = sample['score']
124 | d_img_name = sample['d_img_name']
125 |
126 | c, h, w = d_img.shape
127 | new_h = self.patch_size
128 | new_w = self.patch_size
129 |
130 | for idx in range(self.num_crop):
131 | if self.num_crop == 5 or self.num_crop == 9:
132 |                 d_img_org, r_img_org = five_point_crop(idx, d_img, r_img, self.config)
133 | else:
134 | top = np.random.randint(0, h - new_h)
135 | left = np.random.randint(0, w - new_w)
136 | r_img_org = r_img[:, top: top + new_h, left: left + new_w]
137 | d_img_org = d_img[:, top: top + new_h, left: left + new_w]
138 | if idx == 0:
139 | ret_r_img = r_img_org.unsqueeze(0)
140 | ret_d_img = d_img_org.unsqueeze(0)
141 | else:
142 | ret_r_img = torch.cat((ret_r_img, r_img_org.unsqueeze(0)),dim=0)
143 | ret_d_img = torch.cat((ret_d_img, d_img_org.unsqueeze(0)),dim=0)
144 |
145 | sample = {
146 | 'r_img_org': ret_r_img,
147 | 'd_img_org': ret_d_img,
148 | 'score': score, 'd_img_name':d_img_name
149 | }
150 |
151 | return sample
152 |
153 | class Normalize(object):
154 | def __init__(self, mean, var):
155 | self.mean = mean
156 | self.var = var
157 |
158 | def __call__(self, sample):
159 | # r_img: C x H x W (numpy)
160 | r_img, d_img = sample['r_img_org'], sample['d_img_org']
161 | score = sample['score']
162 | d_img_name = sample['d_img_name']
163 |
164 | r_img = (r_img - self.mean) / self.var
165 | d_img = (d_img - self.mean) / self.var
166 |
167 | sample = {'r_img_org': r_img, 'd_img_org': d_img, 'score': score, 'd_img_name':d_img_name}
168 | return sample
169 |
170 |
171 | class RandHorizontalFlip(object):
172 | def __init__(self):
173 | pass
174 |
175 | def __call__(self, sample):
176 | r_img, d_img = sample['r_img_org'], sample['d_img_org']
177 | score = sample['score']
178 | d_img_name = sample['d_img_name']
179 |         prob_lr = np.random.random()
180 |         # images here are C x H x W, so a horizontal flip is along the last (width) axis
181 |         if prob_lr > 0.5:
182 |             d_img = np.flip(d_img, axis=2).copy()
183 |             r_img = np.flip(r_img, axis=2).copy()
184 |
185 | sample = {
186 | 'r_img_org': r_img,
187 | 'd_img_org': d_img,
188 | 'score': score, 'd_img_name':d_img_name
189 | }
190 | return sample
191 |
192 |
193 | class ToTensor(object):
194 | def __init__(self):
195 | pass
196 |
197 | def __call__(self, sample):
198 | r_img, d_img = sample['r_img_org'], sample['d_img_org']
199 | score = sample['score']
200 | d_img_name = sample['d_img_name']
201 | d_img = torch.from_numpy(d_img).type(torch.FloatTensor)
202 | r_img = torch.from_numpy(r_img).type(torch.FloatTensor)
203 | score = torch.from_numpy(score).type(torch.FloatTensor)
204 | sample = {
205 | 'r_img_org': r_img,
206 | 'd_img_org': d_img,
207 | 'score': score, 'd_img_name':d_img_name
208 | }
209 | return sample
--------------------------------------------------------------------------------
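
A small sketch of `five_point_crop` on a dummy batch; `config` only needs a `crop_size` attribute here, so a `SimpleNamespace` stands in for the parsed options.

```python
from types import SimpleNamespace

import torch

from utils.process_image import five_point_crop

config = SimpleNamespace(crop_size=224)
d_img = torch.randn(2, 3, 288, 288)
r_img = torch.randn(2, 3, 288, 288)
for idx in range(5):  # 0-3: corners, 4: center; indices 5-8 add the edge midpoints
    d_crop, r_crop = five_point_crop(idx, d_img=d_img, r_img=r_img, config=config)
    print(idx, d_crop.shape)  # torch.Size([2, 3, 224, 224])
```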
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import logging
6 |
7 | def mkdirs(paths):
8 | if isinstance(paths, list) and not isinstance(paths, str):
9 | for path in paths:
10 | mkdir(path)
11 | else:
12 | mkdir(paths)
13 |
14 | def mkdir(path):
15 | if not os.path.exists(path):
16 | os.makedirs(path)
17 |
18 | def setup_seed(seed):
19 | random.seed(seed)
20 | os.environ['PYTHONHASHSEED'] = str(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | torch.backends.cudnn.benchmark = False
26 | torch.backends.cudnn.deterministic = True
27 |
28 | def set_logging(config):
29 | filename = os.path.join(config.checkpoints_dir, "log.txt")
30 | logging.basicConfig(
31 | level=logging.INFO,
32 | filename=filename,
33 | filemode='w',
34 | format='[%(asctime)s %(levelname)-8s] %(message)s',
35 | datefmt='%Y%m%d %H:%M:%S'
36 | )
37 |
38 | class SaveOutput:
39 | def __init__(self):
40 | self.outputs = []
41 |
42 | def __call__(self, module, module_in, module_out):
43 | self.outputs.append(module_out)
44 |
45 | def clear(self):
46 | self.outputs = []
--------------------------------------------------------------------------------
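
A minimal sketch of the utilities above; the `SimpleNamespace` config and the `checkpoints/demo` directory are hypothetical stand-ins for the parsed options.

```python
import logging
from types import SimpleNamespace

from utils.util import mkdirs, set_logging, setup_seed

config = SimpleNamespace(checkpoints_dir='checkpoints/demo')
mkdirs(config.checkpoints_dir)
setup_seed(1919)      # seeds python, numpy and torch (CPU + CUDA) for reproducibility
set_logging(config)   # writes to checkpoints/demo/log.txt
logging.info('environment seeded and logging configured')
```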