├── LICENSE ├── README.md ├── SAMPNet ├── cadb_dataset.py ├── config.py ├── requirements.txt ├── samp_module.py ├── samp_net.py ├── test.py └── train.py ├── annotations ├── composition_elements.json ├── composition_scores.json ├── scene_categories.json └── visualize_cadb_annotation.py └── examples ├── annotation_example.jpg ├── element_examples.jpg ├── example.jpg ├── interpretability.jpg └── samp_net.jpg /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 BCMI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | > # Image Composition Assessment Dataset 4 | > Welcomes to the offical homepage of Image Composition Assessment DataBase (**CADB**). 5 | > 6 | > Image composition assessment aims to assess the overall composition quality of a given image, which is crucial in aesthetic assessment. 7 | > To support the research on this task, we contribute the first image composition assessment dataset. Furthermore, we propose a composition assessment network **SAMP-Net** 8 | > with a novel Saliency-Augmented Multi-pattern Pooling (**SAMP**) module, which can perform more favorably than previous aesthetic assessment approaches. 9 | > This work has been accepted by BMVC 2021 ([**paper**](https://arxiv.org/pdf/2104.03133.pdf)). 10 | 11 | **Update 2022-05-05**: We release the annotations of scene categories, composition classes as well as composition elements for more fine-grained analysis of image composition quality. 12 | 13 | **Table of Contents** 14 | 15 | 16 | 17 | - [Dataset](#dataset) 18 | - [Introduction](#introduction) 19 | - [Download](#download) 20 | - [Visualizing Annotations](#visualizing-annotations) 21 | - [Method Overview](#method-overview) 22 | - [Motivation](#motivation) 23 | - [SAMP-Net](#samp-net) 24 | - [Results](#results) 25 | - [Code Usage](#code-usage) 26 | - [Requirements](#requirements) 27 | - [Training](#training) 28 | - [Testing](#testing) 29 | - [Citation](#citation) 30 | 31 | # Dataset 32 | 33 | ## Introduction 34 | 35 | We built the CADB dataset upon the existing Aesthetics and Attributes DataBase ([AADB](https://github.com/aimerykong/deepImageAestheticsAnalysis)). 
The CADB dataset contains **9,497** images, each rated for overall composition quality by **5 individual raters** who specialize in fine art, on a **composition rating scale from 1 to 5** where a larger score indicates better composition. Some example images and their annotations in the CADB dataset are illustrated in the figure below, in which we show the five composition scores provided by the five raters in blue and the calculated composition mean score in red. We also show the aesthetic scores annotated by the AADB dataset on a scale from 1 to 5 in green. 36 |
37 | ![Examples of CADB annotations](examples/annotation_example.jpg) 38 |
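For reference, each image's five rater scores are reduced to a mean score and a 5-bin score distribution over the levels 1 to 5 (see `scores2dist` in `SAMPNet/cadb_dataset.py`); the short sketch below illustrates that conversion with made-up rater scores.

```python
import numpy as np

# Hypothetical rater scores for one image (five integers in [1, 5]).
rater_scores = np.array([3, 4, 4, 2, 5])

# 5-bin distribution over the score levels 1..5, as computed by
# scores2dist() in SAMPNet/cadb_dataset.py.
counts = np.array([(rater_scores == i).sum() for i in range(1, 6)])
distribution = counts / counts.sum()   # [0., 0.2, 0.2, 0.4, 0.2]
mean_score = rater_scores.mean()       # 3.6

print(mean_score, distribution.tolist())
```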
39 | 40 | To facilitate the study of image composition assessment, apart from the composition score, we also annotate the scene category, composition classes, and composition elements of each image. Specifically, we carefully select 9 frequently appearing scenes (***animal, plant, human, static, architecture, landscape, cityscape, indoor, night***) plus one *other* class reserved for images without obvious semantic meaning. As for composition classes, we categorize the common photographic composition rules into 13 classes: ***center, rule of thirds, golden ratio, triangle, horizontal, vertical, diagonal, symmetric, curved, radial, vanishing point, pattern, fill the frame***, and assign a *none* class to images without obvious composition rules; each image is annotated with one or more composition classes. Moreover, we annotate the dominant composition elements for each composition class except *pattern* and *fill the frame*, as illustrated in the figure below. We mark composition elements in yellow and add white auxiliary gridlines to some composition classes for better viewing. 41 |
42 | ![Examples of composition elements](examples/element_examples.jpg) 43 |
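As a quick orientation, the snippet below is a minimal sketch of how the annotation files can be inspected. It only assumes that each JSON file is keyed by image name (which matches how `SAMPNet/cadb_dataset.py` reads `composition_scores.json`); the exact value format of `scene_categories.json` and `composition_elements.json` is best checked with `annotations/visualize_cadb_annotation.py`.

```python
import json, os

# Assumes the three JSON files have been copied into the dataset root,
# as described in the "Visualizing Annotations" section below.
root = './CADB_Dataset'
scores = json.load(open(os.path.join(root, 'composition_scores.json')))
scenes = json.load(open(os.path.join(root, 'scene_categories.json')))

name = next(iter(scores))                           # an image name, e.g. '10000.jpg'
print(scores[name]['mean'], scores[name]['dist'])   # mean score and 5-bin distribution
print(scenes[name])                                 # scene annotation of the same image
```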
44 | 45 | ## Download 46 | Download ``CADB_Dataset.zip`` (~2GB) from 47 | [[Google Drive]](https://drive.google.com/file/d/1fpZoo5exRfoarqDvdLDpQVXVOKFW63vz/view?usp=sharing) | [[Baidu Cloud]](https://pan.baidu.com/s/1o3AktNB-kmOIanJtzEx98g)(access code: *rmnb*). 48 | The annotations of scene categories, composition classes as well as elements can be found in ``annotations`` folder of this repository. 49 | 50 | ## Visualizing Annotations 51 | Put the json files in the ``annotations`` folder into the CADB dataset directory ``CADB_Dataset``. Then we obtain the file structure below: 52 | ``` 53 | CADB_Dataset 54 | ├── composition_elements.json 55 | ├── composition_scores.json 56 | ├── scene_categories.json 57 | └── images 58 |    ├── 10000.jpg 59 |    ├── 10001.jpg 60 |    ├── 10002.jpg 61 |    ├── …… 62 | ``` 63 | Visualizing the annotations of composition score, scene category, and composition element: 64 | ```bash 65 | python annotations/visualize_cadb_annotation.py --data_root ./CADB_Dataset 66 | ``` 67 | The visualized results will be stored in ``CADB_Dataset/Visualization``. 68 | 69 | # Method Overview 70 | 71 | ## Motivation 72 | As shown in the following Figure, each **composition pattern** divides the holistic image into multiple non-overlapping partitions, which can model 73 | human perception of composition quality. By analyzing the **visual layout** (e.g., positions and sizes of visual elements) according to composition pattern, i.e., comparing the 74 | visual elements in various partitions, we can quantify the aesthetics of visual layout in terms of **visual balance** (e.g., symmetrical balance and radial balance), **composition rules** (e.g., rule of thirds, diagonals and triangles), and so on. Different composition patterns offer different perspectives to evaluate composition quality. For example, the composition pattern in the top (resp., bottom) row can help judge the composition quality in terms of symmetrical (resp., radial) balance. 75 |
76 | ![Composition pattern examples](examples/example.jpg) 77 |
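To make the idea concrete, a single composition pattern can be implemented as partition-wise average pooling over the backbone feature map. The sketch below is a simplified stand-in for the pattern modules in `SAMPNet/samp_module.py`: it pools a toy feature map into upper and lower partitions, analogous to the horizontal pattern.

```python
import torch

# Toy feature map: batch of 2, 512 channels, 7x7 spatial grid
# (the shape produced by a ResNet-18 backbone on a 224x224 input).
x = torch.randn(2, 512, 7, 7)

# Horizontal pattern: split the grid into upper and lower partitions,
# average-pool each partition, then stack the partition descriptors.
H = x.shape[2]
upper = x[:, :, : H // 2, :].mean(dim=(2, 3))       # (2, 512)
lower = x[:, :, H // 2 :, :].mean(dim=(2, 3))       # (2, 512)
pattern_feat = torch.stack([upper, lower], dim=2)   # (2, 512, 2)

print(pattern_feat.shape)
```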
78 | 79 | ## SAMP-Net 80 | To accomplish the composition assessment task, we propose a novel network SAMP-Net, named after its **Saliency-Augmented Multi-pattern Pooling (SAMP)** module. 81 | The overall pipeline of our method is illustrated in the figure below: we first extract a global feature map from the input image with a backbone network and then obtain an aggregated 82 | pattern feature through our SAMP module, followed by an **Attentional Attribute Feature Fusion (AAFF)** module that fuses the composition feature and the attribute feature. After that, we predict the **composition score distribution** from the fused feature and the attribute scores from the attribute feature, supervised by a **weighted EMD loss** and a Mean Squared Error (MSE) loss respectively. 83 |
84 | ![SAMP-Net architecture](examples/samp_net.jpg) 85 |
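For reference, the EMD loss compares the cumulative distribution functions of the predicted and ground-truth score distributions. The sketch below is a simplified version of the `EMDLoss` class in `SAMPNet/samp_net.py` (default r=2), including the optional per-sample weights used by the weighted variant.

```python
import torch

def emd_loss(p_target, p_estimate, weight=None, r=2):
    # CDFs over the 5 score levels (1..5).
    cdf_target = torch.cumsum(p_target, dim=-1)
    cdf_estimate = torch.cumsum(p_estimate, dim=-1)
    # Sample-wise EMD: r-norm of the CDF difference, averaged over bins.
    emd = torch.pow(torch.pow((cdf_estimate - cdf_target).abs(), r).mean(dim=-1), 1. / r)
    if weight is not None:   # weighted EMD loss
        emd = emd * weight
    return emd.mean()

target = torch.tensor([[0.0, 0.2, 0.4, 0.4, 0.0]])   # ground-truth score distribution
pred   = torch.tensor([[0.1, 0.2, 0.3, 0.3, 0.1]])   # predicted score distribution
print(emd_loss(target, pred))
```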
86 | 87 | # Results 88 | In the figure below, we show the input image, its saliency map, its ground-truth/predicted composition mean score, and its pattern weights. Our method predicts pattern weights that indicate the importance of different patterns for the overall composition quality. For each image, the composition pattern with the largest weight is referred to as its **dominant pattern**, and we overlay this pattern on the image. The dominant pattern helps to reveal from which perspective the input image receives a high or low score, which further provides constructive suggestions for improving the composition quality. For example, in the left figure of the third row, given the low score under the center pattern, the dog is suggested to be moved to the center. 89 |
90 | ![Qualitative results with dominant patterns](examples/interpretability.jpg) 91 |
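The pattern weights and predicted score shown above can be read directly from the model outputs. The following sketch mirrors `test.py` (it assumes a working `Config`, and random tensors stand in for a preprocessed image and its saliency map as produced by `cadb_dataset.py`) to derive the dominant pattern and the predicted mean score.

```python
import torch
import torch.nn.functional as F
from samp_net import SAMPNet
from config import Config

cfg = Config()
model = SAMPNet(cfg, pretrained=False).eval()
# model.load_state_dict(torch.load('./pretrained_model/samp_net.pth'))  # optional

image = torch.randn(1, 3, 224, 224)      # stand-in for a preprocessed image
saliency = torch.randn(1, 1, 224, 224)   # stand-in for its saliency map

with torch.no_grad():
    weight, attributes, score_dist = model(image, saliency)

pattern_weights = F.softmax(weight, dim=1)                # importance of each pattern
dominant_pattern = cfg.pattern_list[pattern_weights.argmax(dim=1).item()]
pred_score = (score_dist * torch.arange(1, 6).float()).sum(dim=1)  # distribution -> mean score
print(dominant_pattern, pred_score.item())
```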
92 | 93 | # Code Usage 94 | ```bash 95 | # clone this repository 96 | git clone https://github.com/bcmi/Image-Composition-Assessment-with-SAMP.git 97 | cd Image-Composition-Assessment-with-SAMP/SAMPNet 98 | # download CADB data (~2GB), change the default dataset folder and gpu id in config.py. 99 | ``` 100 | ## Requirements 101 | - PyTorch>=1.0 102 | - torchvision 103 | - tensorboardX 104 | - opencv-python 105 | - scipy 106 | - tqdm 107 | 108 | Or you can refer to [``requirement.txt``](./SAMPNet/requirements.txt). 109 | 110 | ## Training 111 | ```bash 112 | python train.py 113 | # track your experiments 114 | tensorboard --logdir=./experiments --bind_all 115 | ``` 116 | During training, the evaluation results of each epoch are recorded in a ``csv`` format file under the produced folder ``./experiments``. 117 | 118 | ## Testing 119 | You can download pretrained model (~180MB) from [[Google Drive]](https://drive.google.com/file/d/1sIcYr5cQGbxm--tCGaASmN0xtE_r-QUg/view?usp=sharing) | [[Baidu Cloud]](https://pan.baidu.com/s/17EzhsbHqwA5aR8ty77fTvw)(access code: *5qgg*). 120 | ```bash 121 | # place the pretrianed model in the folder ``pretrained_model`` and check the path in ``test.py``. 122 | # change the default gpu id in config.py 123 | python test.py 124 | ``` 125 | 126 | # Citation 127 | ``` 128 | @article{zhang2021image, 129 | title={Image Composition Assessment with Saliency-augmented Multi-pattern Pooling}, 130 | author={Zhang, Bo and Niu, Li and Zhang, Liqing}, 131 | journal={arXiv preprint arXiv:2104.03133}, 132 | year={2021} 133 | } 134 | ``` 135 | -------------------------------------------------------------------------------- /SAMPNet/cadb_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torch 3 | import torch.nn.functional as F 4 | from PIL import Image 5 | import os, json 6 | import torchvision.transforms as transforms 7 | import random 8 | import numpy as np 9 | from config import Config 10 | import cv2 11 | 12 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406] 13 | IMAGE_NET_STD = [0.229, 0.224, 0.225] 14 | 15 | random.seed(1) 16 | torch.manual_seed(1) 17 | cv2.setNumThreads(0) 18 | 19 | # Refer to: Saliency detection: A spectral residual approach 20 | def detect_saliency(img, scale=6, q_value=0.95, target_size=(224,224)): 21 | img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 22 | W, H = img_gray.shape 23 | img_resize = cv2.resize(img_gray, (H // scale, W // scale), interpolation=cv2.INTER_AREA) 24 | 25 | myFFT = np.fft.fft2(img_resize) 26 | myPhase = np.angle(myFFT) 27 | myLogAmplitude = np.log(np.abs(myFFT) + 0.000001) 28 | myAvg = cv2.blur(myLogAmplitude, (3, 3)) 29 | mySpectralResidual = myLogAmplitude - myAvg 30 | 31 | m = np.exp(mySpectralResidual) * (np.cos(myPhase) + complex(1j) * np.sin(myPhase)) 32 | saliencyMap = np.abs(np.fft.ifft2(m)) ** 2 33 | saliencyMap = cv2.GaussianBlur(saliencyMap, (9, 9), 2.5) 34 | saliencyMap = cv2.resize(saliencyMap, target_size, interpolation=cv2.INTER_LINEAR) 35 | threshold = np.quantile(saliencyMap.reshape(-1), q_value) 36 | if threshold > 0: 37 | saliencyMap[saliencyMap > threshold] = threshold 38 | saliencyMap = (saliencyMap - saliencyMap.min()) / threshold 39 | # for debugging 40 | # import matplotlib.pyplot as plt 41 | # plt.subplot(1,2,1) 42 | # plt.imshow(img) 43 | # plt.axis('off') 44 | # plt.subplot(1,2,2) 45 | # plt.imshow(saliencyMap, cmap='gray') 46 | # plt.axis('off') 47 | # plt.show() 48 | return saliencyMap 49 | 50 | class 
CADBDataset(Dataset): 51 | def __init__(self, split, cfg): 52 | self.data_path = cfg.dataset_path 53 | self.image_path = os.path.join(self.data_path, 'images') 54 | self.score_path = os.path.join(self.data_path, 'composition_scores.json') 55 | self.split_path = os.path.join(self.data_path, 'split.json') 56 | self.attr_path = os.path.join(self.data_path, 'composition_attributes.json') 57 | self.weight_path= os.path.join(self.data_path, 'emdloss_weight.json') 58 | self.split = split 59 | self.attr_types = cfg.attribute_types 60 | 61 | self.image_list = json.load(open(self.split_path, 'r'))[split] 62 | self.comp_scores = json.load(open(self.score_path, 'r')) 63 | self.comp_attrs = json.load(open(self.attr_path, 'r')) 64 | if self.split == 'train': 65 | self.image_weight = json.load(open(self.weight_path, 'r')) 66 | else: 67 | self.image_weight = None 68 | 69 | self.image_size = cfg.image_size 70 | self.transformer = transforms.Compose([ 71 | transforms.Resize((self.image_size, self.image_size)), 72 | transforms.ToTensor(), 73 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD) 74 | ]) 75 | 76 | def __len__(self): 77 | return len(self.image_list) 78 | 79 | def __getitem__(self, index): 80 | image_name = self.image_list[index] 81 | image_file = os.path.join(self.image_path, image_name) 82 | assert os.path.exists(image_file), image_file + ' not found' 83 | src = Image.open(image_file).convert('RGB') 84 | im = self.transformer(src) 85 | 86 | score_mean = self.comp_scores[image_name]['mean'] 87 | score_mean = torch.Tensor([score_mean]) 88 | score_dist = self.comp_scores[image_name]['dist'] 89 | score_dist = torch.Tensor(score_dist) 90 | 91 | attrs = torch.tensor(self.get_attribute(image_name)) 92 | src_im = np.asarray(src).copy() 93 | sal_map = detect_saliency(src_im, target_size=(self.image_size, self.image_size)) 94 | sal_map = torch.from_numpy(sal_map.astype(np.float32)).unsqueeze(0) 95 | 96 | if self.split == 'train': 97 | emd_weight = torch.tensor(self.image_weight[image_name]) 98 | return im, score_mean, score_dist, sal_map, attrs, emd_weight 99 | else: 100 | return im, score_mean, score_dist, sal_map, attrs 101 | 102 | def get_attribute(self, image_name): 103 | all_attrs = self.comp_attrs[image_name] 104 | attrs = [all_attrs[k] for k in self.attr_types] 105 | return attrs 106 | 107 | def normbboxes(self, bboxes, w, h): 108 | bboxes = bboxes.astype(np.float32) 109 | center_x = (bboxes[:,0] + bboxes[:,2]) / 2. / w 110 | center_y = (bboxes[:,1] + bboxes[:,3]) / 2. / h 111 | norm_w = (bboxes[:,2] - bboxes[:,0]) / w 112 | norm_h = (bboxes[:,3] - bboxes[:,1]) / h 113 | norm_bboxes = np.column_stack((center_x, center_y, norm_w, norm_h)) 114 | norm_bboxes = np.clip(norm_bboxes, 0, 1) 115 | assert norm_bboxes.shape == bboxes.shape, '{} vs. 
{}'.format(bboxes.shape, norm_bboxes.shape) 116 | # print(w,h,bboxes[0],norm_bboxes[0]) 117 | return norm_bboxes 118 | 119 | def scores2dist(self, scores): 120 | scores = np.array(scores) 121 | count = [(scores == i).sum() for i in range(1,6)] 122 | count = np.array(count) 123 | assert count.sum() == 5, scores 124 | distribution = count.astype(np.float) / count.sum() 125 | distribution = distribution.tolist() 126 | return distribution 127 | 128 | def collate_fn(self, batch): 129 | batch = list(zip(*batch)) 130 | if self.split == 'train': 131 | assert (not self.need_image_path) and (not self.need_mask) \ 132 | and (not self.need_proposal), 'Multi-scale training not implement' 133 | self.image_size = random.choice(range(self.min_image_size, 134 | self.max_image_size+1, 135 | 16)) 136 | batch[0] = [resize(im, self.image_size) for im in batch[0]] 137 | batch = [torch.stack(data, dim=0) for data in batch] 138 | return batch 139 | 140 | 141 | def resize(image, size): 142 | image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0) 143 | return image 144 | 145 | if __name__ == '__main__': 146 | cfg = Config() 147 | train_dataset = CADBDataset('train', cfg) 148 | train_loader = DataLoader(train_dataset, 149 | batch_size=2, 150 | shuffle=True, 151 | num_workers=8, 152 | drop_last=True, 153 | collate_fn=None) 154 | test_dataset = CADBDataset('test', cfg) 155 | test_loader = DataLoader(test_dataset, 156 | batch_size=2, 157 | shuffle=False, 158 | num_workers=0, 159 | drop_last=False) 160 | 161 | print("training set size {}, test set size {}".format(len(train_dataset), len(test_dataset))) 162 | for batch, data in enumerate(train_loader): 163 | im,score,dist,mask,attrs,weight = data 164 | print('train', im.shape, score.shape, dist.shape, mask.shape, attrs.shape, weight.shape) 165 | 166 | 167 | for batch, data in enumerate(test_loader): 168 | im,score,dist,mask,attrs = data 169 | print('test', im.shape, score.shape, dist.shape, mask.shape, attrs.shape) 170 | break -------------------------------------------------------------------------------- /SAMPNet/config.py: -------------------------------------------------------------------------------- 1 | import os,time 2 | 3 | class Config: 4 | # setting for dataset and dataloader 5 | dataset_path = '/workspace/composition/CADB_Dataset' 6 | assert os.path.exists(dataset_path), dataset_path + 'not found' 7 | 8 | batch_size = 16 9 | gpu_id = 0 10 | num_workers = 8 11 | 12 | # setting for training and optimization 13 | max_epoch = 50 14 | save_epoch = 1 15 | display_steps = 10 16 | score_level = 5 17 | 18 | use_weighted_loss = True 19 | use_attribute = True 20 | use_channel_attention = True 21 | use_saliency = True 22 | use_multipattern = True 23 | use_pattern_weight = True 24 | # AADB attributes: "Light" "Symmetry" "Object" 25 | # "score" "RuleOfThirds" "Repetition" 26 | # "BalacingElements" "ColorHarmony" "MotionBlur" 27 | # "VividColor" "DoF" "Content" 28 | attribute_types = ['RuleOfThirds', 'BalacingElements','DoF', 29 | 'Object', 'Symmetry', 'Repetition'] 30 | num_attributes = len(attribute_types) 31 | attribute_weight = 0.1 32 | 33 | optimizer = 'adam' # or sgd 34 | lr = 1e-4 35 | weight_decay = 5e-5 36 | momentum = 0.9 37 | # setting for cnn 38 | image_size = 224 39 | resnet_layers = 18 40 | dropout = 0.5 41 | pool_dropout = 0.5 42 | pattern_list = [1, 2, 3, 4, 5, 6, 7, 8] 43 | pattern_fuse = 'sum' 44 | 45 | # setting for testing 46 | test_epoch = 1 47 | 48 | # setting for expriments 49 | if len(pattern_list) == 1: 50 | exp_root = 
os.path.join(os.getcwd(), './experiments/single_pattern') 51 | prefix = 'pattern{}'.format(pattern_list[0]) 52 | else: 53 | exp_root = os.path.join(os.getcwd(), './experiments/') 54 | prefix = 'resnet{}'.format(resnet_layers) 55 | if use_multipattern: 56 | if use_pattern_weight and use_saliency: 57 | prefix += '_samp' 58 | elif use_pattern_weight: 59 | prefix += '_weighted_mpp' 60 | elif use_saliency: 61 | prefix += '_saliency_mpp' 62 | else: 63 | prefix += '_mpp' 64 | if use_attribute: 65 | if use_channel_attention: 66 | prefix += '_aaff' 67 | else: 68 | prefix += '_attr' 69 | if use_weighted_loss: 70 | prefix += '_wemd' 71 | exp_name = prefix 72 | exp_path = os.path.join(exp_root, prefix) 73 | while os.path.exists(exp_path): 74 | index = os.path.basename(exp_path).split(prefix)[-1] 75 | try: 76 | index = int(index) + 1 77 | except: 78 | index = 1 79 | exp_name = prefix + str(index) 80 | exp_path = os.path.join (exp_root, exp_name) 81 | print('Experiment name {} \n'.format(os.path.basename(exp_path))) 82 | checkpoint_dir = os.path.join(exp_path, 'checkpoints') 83 | log_dir = os.path.join(exp_path, 'logs') 84 | 85 | def create_path(self): 86 | print('Create experiment directory: ', self.exp_path) 87 | os.makedirs(self.exp_path) 88 | os.makedirs(self.checkpoint_dir) 89 | os.makedirs(self.log_dir) 90 | 91 | if __name__ == '__main__': 92 | cfg = Config() -------------------------------------------------------------------------------- /SAMPNet/requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.3.0 2 | numpy==1.19.1 3 | opencv_contrib_python==4.4.0.46 4 | Pillow==8.4.0 5 | scipy==1.5.2 6 | tensorboardX==2.4 7 | torch==1.9.1 8 | torchvision==0.9.0+cu111 9 | tqdm==4.51.0 10 | -------------------------------------------------------------------------------- /SAMPNet/samp_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | from einops import rearrange, repeat 8 | import torch.optim as optim 9 | from torch.nn.modules.utils import _pair 10 | 11 | 12 | class TriangularPattern(nn.Module): 13 | def __init__(self, flip=False, sal_size=56): 14 | super(TriangularPattern,self).__init__() 15 | self.flip = flip 16 | target_size = int(sal_size * 3 / 4) 17 | self.downsample = nn.AdaptiveMaxPool2d(output_size=(target_size, target_size)) 18 | 19 | def forward(self, x, s): 20 | if self.flip: 21 | x = torch.flip(x, dims=(3,)) 22 | 23 | up_indices = torch.triu_indices(x.shape[2], x.shape[3], 24 | device=x.device, 25 | offset=1) 26 | up_feat = x[:,:,up_indices[0], up_indices[1]].mean(dim=2) 27 | 28 | lw_indices = torch.tril_indices(x.shape[2], x.shape[3], 29 | device=x.device, 30 | offset=-1) 31 | lw_feat = x[:,:,lw_indices[0], lw_indices[1]].mean(dim=2) 32 | fused = torch.stack([up_feat, lw_feat], dim=2).unsqueeze(3) 33 | if s is None: 34 | return fused 35 | 36 | if self.flip: 37 | s = torch.flip(s, dims=(3,)) 38 | s = self.downsample(s) 39 | up_sal_indices = torch.triu_indices(s.shape[2], 40 | s.shape[3], 41 | device=s.device, 42 | offset=1) 43 | up_sal = s[:, :, up_sal_indices[0], up_sal_indices[1]].flatten(1) 44 | 45 | lw_sal_indices = torch.tril_indices(s.shape[2], 46 | s.shape[3], 47 | device=s.device, 48 | offset=-1) 49 | lw_sal = s[:, :, lw_sal_indices[0], lw_sal_indices[1]].flatten(1) 50 | 51 | 52 | fused_sal = torch.stack([up_sal, lw_sal], 
dim=2).unsqueeze(3) 53 | return fused, fused_sal 54 | 55 | class CrossPattern(nn.Module): 56 | def __init__(self): 57 | super(CrossPattern, self).__init__() 58 | 59 | def forward(self, x, s): 60 | ones_vec = torch.ones(x.size(2), x.size(3), requires_grad=False) 61 | up_flip = torch.triu(ones_vec, diagonal=0).flip(dims=(-1,)) 62 | lw_flip = torch.tril(ones_vec, diagonal=0).flip(dims=(-1,)) 63 | up_inds = torch.where(torch.triu(up_flip, diagonal=0) > 0) 64 | lf_inds = torch.where(torch.tril(up_flip, diagonal=0) > 0) 65 | rt_inds = torch.where(torch.triu(lw_flip, diagonal=0) > 0) 66 | lw_inds = torch.where(torch.tril(lw_flip, diagonal=0) > 0) 67 | 68 | up_feat = x[:, :, up_inds[0], up_inds[1]].mean(2) 69 | lf_feat = x[:, :, lf_inds[0], lf_inds[1]].mean(2) 70 | rt_feat = x[:, :, rt_inds[0], rt_inds[1]].mean(2) 71 | lw_feat = x[:, :, lw_inds[0], lw_inds[1]].mean(2) 72 | 73 | fused1 = torch.stack([lf_feat, up_feat], dim=2) 74 | fused2 = torch.stack([lw_feat, rt_feat], dim=2) 75 | fused = torch.stack([fused1, fused2], dim=3) 76 | if s is None: 77 | return fused 78 | 79 | assert s.shape[2] == s.shape[3], \ 80 | 'saliency map should be square, but get a shape {}'.format(s.shape) 81 | ones_vec = torch.ones(s.size(2), s.size(3), requires_grad=False) 82 | up_flip = torch.triu(ones_vec, diagonal=0).flip(dims=(-1,)) 83 | lw_flip = torch.tril(ones_vec, diagonal=0).flip(dims=(-1,)) 84 | up_inds = torch.where(torch.triu(up_flip, diagonal=0) > 0) 85 | lf_inds = torch.where(torch.tril(up_flip, diagonal=0) > 0) 86 | rt_inds = torch.where(torch.triu(lw_flip, diagonal=0) > 0) 87 | lw_inds = torch.where(torch.tril(lw_flip, diagonal=0) > 0) 88 | 89 | up_sal = s[:, :, up_inds[0], up_inds[1]].flatten(1) 90 | lf_sal = s[:, :, lf_inds[0], lf_inds[1]].flatten(1) 91 | rt_sal = s[:, :, rt_inds[0], rt_inds[1]].flatten(1) 92 | lw_sal = s[:, :, lw_inds[0], lw_inds[1]].flatten(1) 93 | 94 | sal1 = torch.stack([lf_sal, up_sal], dim=2) 95 | sal2 = torch.stack([lw_sal, rt_sal], dim=2) 96 | fused_sal = torch.stack([sal1, sal2], dim=3) 97 | 98 | return fused, fused_sal 99 | 100 | class SurroundPattern(nn.Module): 101 | def __init__(self, crop_size=1./2): 102 | super(SurroundPattern, self).__init__() 103 | self.crop_size = crop_size 104 | 105 | def forward(self, x, s): 106 | H,W = x.shape[2:] 107 | crop_h = (int(H / 2 - self.crop_size / 2 * H), int(H / 2 + self.crop_size / 2 * H)) 108 | crop_w = (int(W / 2 - self.crop_size / 2 * W), int(W / 2 + self.crop_size / 2 * W)) 109 | x_mask = torch.zeros(H,W,device=x.device, dtype=torch.bool) 110 | x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True 111 | 112 | inside_indices = torch.where(x_mask) 113 | inside_part = x[:, :, inside_indices[0], inside_indices[1]] 114 | inside_feat = inside_part.mean(2) 115 | 116 | outside_indices = torch.where(~x_mask) 117 | outside_part = x[:, :, outside_indices[0], outside_indices[1]] 118 | outside_feat = outside_part.mean(2) 119 | fused = torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3) 120 | if s is None: 121 | return fused 122 | 123 | SH,SW = s.shape[2:] 124 | crop_sh = (int(SH / 2 - self.crop_size / 2 * SH), int(SH / 2 + self.crop_size / 2 * SH)) 125 | crop_sw = (int(SW / 2 - self.crop_size / 2 * SW), int(SW / 2 + self.crop_size / 2 * SW)) 126 | s_mask = torch.zeros(SH, SW, device=s.device, dtype=torch.bool) 127 | s_mask[crop_sh[0] : crop_sh[1], crop_sw[0] : crop_sw[1]] = True 128 | 129 | s_inside_indices = torch.where(s_mask) 130 | inside_sal = s[:, :, s_inside_indices[0], s_inside_indices[1]].flatten(1) 131 | 132 | 
s_outside_indices = torch.where(~s_mask) 133 | outside_sal = s[:, :, s_outside_indices[0], s_outside_indices[1]].flatten(1) 134 | if outside_sal.shape != inside_sal.shape: 135 | outside_sal = F.adaptive_max_pool1d(outside_sal.unsqueeze(1), output_size=inside_sal.shape[1]) 136 | outside_sal = outside_sal.squeeze(1) 137 | fused_sal = torch.stack([inside_sal, outside_sal], dim=2).unsqueeze(3) 138 | return fused, fused_sal 139 | 140 | class HorizontalPattern(nn.Module): 141 | def __init__(self): 142 | super(HorizontalPattern, self).__init__() 143 | self.downsample = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2), padding=(0,1)) 144 | 145 | def forward(self, x, s): 146 | H = x.shape[2] 147 | up_part = x[:, :, : H // 2, :] 148 | up_feat = up_part.mean(3).mean(2) 149 | 150 | lw_part = x[:, :, H // 2 :, :] 151 | lw_feat = lw_part.mean(3).mean(2) 152 | fused = torch.stack([up_feat, lw_feat], dim=2).unsqueeze(3) 153 | if s is None: 154 | return fused 155 | SH = s.shape[2] 156 | s = self.downsample(s) 157 | up_sal = s[:, :, : SH // 2, :].flatten(1) 158 | if SH % 2 == 0: 159 | lw_sal = s[:, :, SH // 2 :, :].flatten(1) 160 | else: 161 | lw_sal = s[:, :, SH // 2 + 1 :, :].flatten(1) 162 | fused_sal = torch.stack([up_sal, lw_sal], dim=2).unsqueeze(3) 163 | return fused, fused_sal 164 | 165 | class VerticalPattern(nn.Module): 166 | def __init__(self): 167 | super(VerticalPattern, self).__init__() 168 | self.downsample = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)) 169 | 170 | def forward(self, x, s): 171 | W = x.shape[3] 172 | left_part = x[:, :, :, : W // 2] 173 | left_feat = left_part.mean(3).mean(2) 174 | 175 | right_part = x[:, :, :, W // 2 :] 176 | right_feat = right_part.mean(3).mean(2) 177 | fused = torch.stack([left_feat, right_feat], dim=2).unsqueeze(2) 178 | if s is None: 179 | return fused 180 | 181 | SW = s.shape[3] 182 | s = self.downsample(s) 183 | left_sal = s[:, :, :, : SW // 2].flatten(1) 184 | if SW % 2 == 0: 185 | right_sal = s[:, :, :, SW // 2 :].flatten(1) 186 | else: 187 | right_sal = s[:, :, :, SW // 2 + 1 :].flatten(1) 188 | fused_sal = torch.stack([left_sal, right_sal], dim=2).unsqueeze(2) 189 | return fused, fused_sal 190 | 191 | class GlobalPattern(nn.Module): 192 | def __init__(self): 193 | super(GlobalPattern, self).__init__() 194 | self.downsample = nn.Sequential( 195 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 196 | nn.Flatten(1) 197 | ) 198 | self.gap = nn.Sequential( 199 | nn.AdaptiveAvgPool2d(1) 200 | ) 201 | 202 | def forward(self, x, s): 203 | gap_feat = self.gap(x) 204 | if s is None: 205 | return gap_feat 206 | sal_feat = self.downsample(s).unsqueeze(2).unsqueeze(3) 207 | return gap_feat, sal_feat 208 | 209 | class QuarterPattern(nn.Module): 210 | def __init__(self): 211 | super(QuarterPattern, self).__init__() 212 | 213 | def forward(self, x, s): 214 | feat = F.adaptive_avg_pool2d(x, output_size=(2,2)) 215 | if s is None: 216 | return feat 217 | 218 | s_chunks = torch.chunk(s, 2, dim=2) 219 | up_left, up_right = torch.chunk(s_chunks[0], 2, dim=3) 220 | lw_left, lw_right = torch.chunk(s_chunks[1], 2, dim=3) 221 | up_sal = torch.stack([up_left.flatten(1), up_right.flatten(1)],dim=2) 222 | lw_sal = torch.stack([lw_left.flatten(1), lw_right.flatten(1)],dim=2) 223 | feat_sal = torch.stack([up_sal, lw_sal], dim=2) 224 | return feat, feat_sal 225 | 226 | class ThirdOfRulePattern(nn.Module): 227 | def __init__(self): 228 | super(ThirdOfRulePattern, self).__init__() 229 | 230 | def forward(self, x, s): 231 | feat = F.adaptive_avg_pool2d(x, 
output_size=(3,3)) 232 | if s is None: 233 | return feat 234 | out_size = (s.shape[-1] // 3) * 3 235 | s = F.adaptive_max_pool2d(s, output_size=(out_size,out_size)) 236 | hor_chunks = torch.chunk(s, 3, dim=3) 237 | feat_sal = [] 238 | for h_chunk in hor_chunks: 239 | tmp = [] 240 | for v_chunk in torch.chunk(h_chunk, 3, dim=2): 241 | tmp.append(v_chunk.flatten(1)) 242 | tmp = torch.stack(tmp, dim=2) 243 | feat_sal.append(tmp) 244 | feat_sal = torch.stack(feat_sal, dim=3) 245 | return feat, feat_sal 246 | 247 | class VerticalThirdPattern(nn.Module): 248 | def __init__(self): 249 | super(VerticalThirdPattern, self).__init__() 250 | 251 | def forward(self, x, s): 252 | feat = F.adaptive_avg_pool2d(x, output_size=(3,1)) 253 | if s is None: 254 | return feat 255 | out_size = (s.shape[-2] // 3) * 3 256 | s = F.adaptive_max_pool2d(s, output_size=(out_size,s.shape[-1])) 257 | ver_chunks = torch.chunk(s, 3, dim=2) 258 | feat_sal = [] 259 | for ver in ver_chunks: 260 | feat_sal.append(ver.flatten(1)) 261 | feat_sal = torch.stack(feat_sal, dim=2).unsqueeze(3) 262 | return feat, feat_sal 263 | 264 | class HorThirdPattern(nn.Module): 265 | def __init__(self): 266 | super(HorThirdPattern, self).__init__() 267 | 268 | def forward(self, x, s): 269 | feat = F.adaptive_avg_pool2d(x, output_size=(1,3)) 270 | if s is None: 271 | return feat 272 | out_size = (s.shape[-1] // 3) * 3 273 | s = F.adaptive_max_pool2d(s, output_size=(s.shape[-2], out_size)) 274 | hor_chunks = torch.chunk(s, 3, dim=3) 275 | feat_sal = [] 276 | for hor in hor_chunks: 277 | feat_sal.append(hor.flatten(1)) 278 | feat_sal = torch.stack(feat_sal, dim=2).unsqueeze(2) 279 | return feat, feat_sal 280 | 281 | class MultiDirectionPattern(nn.Module): 282 | def __init__(self): 283 | super(MultiDirectionPattern, self).__init__() 284 | feat_mask = self.generate_multi_direction_mask(7, 7) 285 | sal_mask = self.generate_multi_direction_mask(56, 56) 286 | self.register_buffer('feat_mask', feat_mask) 287 | self.register_buffer('sal_mask', sal_mask) 288 | 289 | def generate_multi_direction_mask(self, w, h): 290 | mask = torch.zeros(8, h, w) 291 | degree_mask = torch.zeros(h, w) 292 | if h % 2 == 0: 293 | cx, cy = float(w-1)/2, float(h-1)/2 294 | else: 295 | cx, cy = w // 2, h // 2 296 | 297 | for i in range(w): 298 | for j in range(h): 299 | degree = math.degrees(math.atan2(cy - j, cx - i)) 300 | degree_mask[j, i] = (degree + 180) % 360 301 | for i in range(8): 302 | if i == 7: 303 | degree_mask[degree_mask == 0] = 360 304 | mask[i, (degree_mask >= i * 45) & (degree_mask <= (i + 1) * 45)] = 1 305 | if h % 2 != 0: 306 | mask[i, cy, cx] = 1 307 | # mask = torch.count_nonzero(mask.flatten(1)) 308 | return mask 309 | 310 | def forward(self, x, s): 311 | # B, C, H, W = x.shape 312 | mask = rearrange(self.feat_mask, 'p h w -> 1 1 p h w') 313 | count = torch.count_nonzero(mask.flatten(-2), dim=-1) 314 | x = x.unsqueeze(2) 315 | part = (x * mask).sum(-2).sum(-1) 316 | feat = part / count 317 | feat = rearrange(feat, 'b c (h w) -> b c h w', h=2, w=4) 318 | if s is None: 319 | return feat 320 | sal_feat = [] 321 | for i in range(self.sal_mask.shape[0]): 322 | sal_feat.append(s[:, :, self.sal_mask[i] > 0].flatten(1)) 323 | sal_feat = torch.stack(sal_feat, dim=2) 324 | sal_feat = rearrange(sal_feat, 'b c (h w) -> b c h w', h=2, w=4) 325 | return feat, sal_feat 326 | 327 | class MultiRectanglePattern(nn.Module): 328 | def __init__(self): 329 | super(MultiRectanglePattern, self).__init__() 330 | feat_mask = self.generate_multi_direction_mask(7, 7) 331 | sal_mask = 
self.generate_multi_direction_mask(56, 56) 332 | self.register_buffer('feat_mask', feat_mask) 333 | self.register_buffer('sal_mask', sal_mask) 334 | 335 | def generate_multi_direction_mask(self, w, h): 336 | square_part = torch.zeros(4, 4, h, w) 337 | index_y = torch.split(torch.arange(h), (h + 1) // 4) 338 | index_x = torch.split(torch.arange(w), (w + 1) // 4) 339 | for i in range(4): 340 | for j in range(4): 341 | for x in index_x[i]: 342 | for y in index_y[j]: 343 | square_part[i, j, y, x] = 1 344 | mask = torch.zeros(8, h, w) 345 | group_x = [[0, 0, 1], [1], [2, 3, 3], [2], [0, 0, 1], [1], [2, 3, 3], [2]] 346 | group_y = [[1, 0, 0], [1], [0, 0, 1], [1], [2, 3, 3], [2], [3, 3, 2], [2]] 347 | for i in range(len(group_x)): 348 | mask[i] = torch.sum(square_part[group_x[i], group_y[i]], dim=0) 349 | mask = torch.clip(mask, min=0, max=1) 350 | return mask 351 | 352 | def forward(self, x, s): 353 | # B, C, H, W = x.shape 354 | mask = rearrange(self.feat_mask, 'p h w -> 1 1 p h w') 355 | count = torch.count_nonzero(mask.flatten(-2), dim=-1) 356 | x = x.unsqueeze(2) 357 | part = (x * mask).sum(-2).sum(-1) 358 | feat = part / count 359 | feat = rearrange(feat, 'b c (h w) -> b c h w', h=2, w=4) 360 | if s is None: 361 | return feat 362 | sal_count = torch.count_nonzero(self.sal_mask.flatten(1), dim=-1) 363 | target_size = sal_count.min().item() 364 | sal_feat = [] 365 | for i in range(self.sal_mask.shape[0]): 366 | sal = s[:, :, self.sal_mask[i] > 0].flatten(1) 367 | if sal.shape[1] > target_size: 368 | sal = F.adaptive_max_pool1d(sal.unsqueeze(1), output_size=target_size).squeeze(1) 369 | sal_feat.append(sal) 370 | sal_feat = torch.stack(sal_feat, dim=2) 371 | sal_feat = rearrange(sal_feat, 'b c (h w) -> b c h w', h=2, w=4) 372 | return feat, sal_feat 373 | 374 | class MPPModule(nn.Module): 375 | def __init__(self, in_dim, out_dim, dropout=0.5, 376 | pattern_list=[1,2,3,4,5,6,7,8], 377 | fusion='sum'): 378 | super(MPPModule, self).__init__() 379 | self.pattern_list = pattern_list 380 | self.in_dim = in_dim 381 | self.out_dim = out_dim 382 | self.dropout = nn.Dropout(dropout) 383 | self.fusion = fusion 384 | pool_list = [] 385 | conv_list = [] 386 | print('Multi-Pattern Pooling pattern: {}, fusion manner: {}, dropout: {}'.\ 387 | format(pattern_list, fusion, dropout)) 388 | for pattern in pattern_list: 389 | p_fn = getattr(self, 'pattern{}'.format(int(pattern))) 390 | p, c = p_fn() 391 | pool_list.append(p) 392 | conv_list.append(c) 393 | self.pool_list = nn.ModuleList(pool_list) 394 | self.conv_list = nn.ModuleList(conv_list) 395 | 396 | def forward(self, x, weights): 397 | outputs = [] 398 | for pool,conv in zip(self.pool_list, self.conv_list): 399 | feat = self.dropout(pool(x,s=None)) 400 | feat = self.dropout(conv(feat)) 401 | outputs.append(feat) 402 | 403 | if len(outputs) == 1: 404 | return outputs[0] 405 | if self.fusion == 'sum': 406 | outputs = torch.stack(outputs, dim=2) 407 | if weights is None: 408 | outputs = torch.sum(outputs, dim=2) 409 | else: 410 | weights = F.softmax(weights, dim=1) 411 | outputs = torch.sum(outputs * weights.unsqueeze(1), dim=2) 412 | elif self.fusion == 'mean': 413 | outputs = torch.stack(outputs, dim=2) 414 | outputs = torch.mean(outputs, dim=2) 415 | elif self.fusion == 'concat': 416 | outputs = torch.cat(outputs, dim=1) 417 | else: 418 | raise ValueError('Unkown fusion type {}'.format(self.fusion)) 419 | return outputs 420 | 421 | def pattern0(self): 422 | pool = GlobalPattern() 423 | conv = nn.Sequential( 424 | nn.Conv2d(self.in_dim, self.out_dim, (1,1), 
bias=False), 425 | nn.ReLU(True), 426 | nn.Flatten(1) 427 | ) 428 | return pool, conv 429 | 430 | def pattern1(self): 431 | pool = HorizontalPattern() 432 | conv = nn.Sequential( 433 | nn.Conv2d(self.in_dim, self.out_dim, (2,1), bias=False), 434 | nn.ReLU(True), 435 | nn.Flatten(1) 436 | ) 437 | return pool, conv 438 | 439 | def pattern2(self): 440 | pool = VerticalPattern() 441 | conv = nn.Sequential( 442 | nn.Conv2d(self.in_dim, self.out_dim, (1,2), bias=False), 443 | nn.ReLU(True), 444 | nn.Flatten(1) 445 | ) 446 | return pool, conv 447 | 448 | def pattern3(self): 449 | pool = TriangularPattern(flip=False) 450 | conv = nn.Sequential( 451 | nn.Conv2d(self.in_dim, self.out_dim, (2,1), bias=False), 452 | nn.ReLU(True), 453 | nn.Flatten(1) 454 | ) 455 | return pool, conv 456 | 457 | def pattern4(self): 458 | pool = TriangularPattern(flip=True) 459 | conv = nn.Sequential( 460 | nn.Conv2d(self.in_dim, self.out_dim, (2, 1), bias=False), 461 | nn.ReLU(True), 462 | nn.Flatten(1) 463 | ) 464 | return pool, conv 465 | 466 | def pattern5(self): 467 | pool = SurroundPattern() 468 | conv = nn.Sequential( 469 | nn.Conv2d(self.in_dim, self.out_dim, (2, 1), bias=False), 470 | nn.ReLU(True), 471 | nn.Flatten(1) 472 | ) 473 | return pool, conv 474 | 475 | def pattern6(self): 476 | pool = QuarterPattern() 477 | conv = nn.Sequential( 478 | nn.Conv2d(self.in_dim, self.out_dim, (2, 2), bias=False), 479 | nn.ReLU(True), 480 | nn.Flatten(1) 481 | ) 482 | return pool, conv 483 | 484 | def pattern7(self): 485 | pool = CrossPattern() 486 | conv = nn.Sequential( 487 | nn.Conv2d(self.in_dim, self.out_dim, (2,2), bias=False), 488 | nn.ReLU(True), 489 | nn.Flatten(1) 490 | ) 491 | return pool, conv 492 | 493 | def pattern8(self): 494 | pool = ThirdOfRulePattern() 495 | conv = nn.Sequential( 496 | nn.Conv2d(self.in_dim, self.out_dim, (3,3), bias=False), 497 | nn.ReLU(True), 498 | nn.Flatten(1) 499 | ) 500 | return pool, conv 501 | 502 | def pattern9(self): 503 | pool = HorThirdPattern() 504 | conv = nn.Sequential( 505 | nn.Conv2d(self.in_dim, self.out_dim, (1,3), bias=False), 506 | nn.ReLU(True), 507 | nn.Flatten(1) 508 | ) 509 | return pool, conv 510 | 511 | def pattern10(self): 512 | pool = VerticalThirdPattern() 513 | conv = nn.Sequential( 514 | nn.Conv2d(self.in_dim, self.out_dim, (3,1), bias=False), 515 | nn.ReLU(True), 516 | nn.Flatten(1) 517 | ) 518 | return pool, conv 519 | 520 | def pattern11(self): 521 | pool = MultiDirectionPattern() 522 | conv = nn.Sequential( 523 | nn.Conv2d(self.in_dim, self.out_dim, (2,4), bias=False), 524 | nn.ReLU(True), 525 | nn.Flatten(1) 526 | ) 527 | return pool, conv 528 | 529 | def pattern12(self): 530 | pool = MultiRectanglePattern() 531 | conv = nn.Sequential( 532 | nn.Conv2d(self.in_dim, self.out_dim, (2,4), bias=False), 533 | nn.ReLU(True), 534 | nn.Flatten(1) 535 | ) 536 | return pool, conv 537 | 538 | 539 | class SAMPPModule(nn.Module): 540 | def __init__(self, in_dim, out_dim, 541 | saliency_size, dropout=0.5, 542 | pattern_list=[1,2,3,4,5,6,7,8], 543 | fusion='sum'): 544 | super(SAMPPModule, self).__init__() 545 | self.pattern_list = pattern_list 546 | self.in_dim = in_dim 547 | self.out_dim = out_dim 548 | self.dropout = nn.Dropout(dropout) 549 | self.fusion = fusion 550 | self.saliency_size = saliency_size 551 | pool_list = [] 552 | conv_list = [] 553 | print('Saliency-aware Multi-Pattern Pooling pattern: {}, fusion manner: {}, dropout: {}'.\ 554 | format(pattern_list, fusion, dropout)) 555 | for pattern in pattern_list: 556 | p_fn = getattr(self, 
'pattern{}'.format(int(pattern))) 557 | p,c = p_fn() 558 | pool_list.append(p) 559 | conv_list.append(c) 560 | self.pool_list = nn.ModuleList(pool_list) 561 | self.conv_list = nn.ModuleList(conv_list) 562 | 563 | def forward(self, x, s, weights): 564 | outputs = [] 565 | idx = 0 566 | for pool,conv in zip(self.pool_list, self.conv_list): 567 | feat, sal = pool(x,s) 568 | feat = self.dropout(feat) 569 | # print('pattern{}, x {}, s {}, feat_dim {}, sal_dim {}'.format( 570 | # self.pattern_list[idx], x.shape, s.shape, feat.shape, sal.shape)) 571 | fused = torch.cat([feat, sal], dim=1) 572 | fused = self.dropout(conv(fused)) 573 | outputs.append(fused) 574 | idx += 1 575 | 576 | if len(outputs) == 1: 577 | return outputs[0] 578 | if self.fusion == 'sum': 579 | outputs = torch.stack(outputs, dim=2) 580 | if weights is None: 581 | outputs = torch.sum(outputs, dim=2) 582 | else: 583 | weights = F.softmax(weights, dim=1) 584 | outputs = torch.sum(outputs * weights.unsqueeze(1), dim=2) 585 | elif self.fusion == 'mean': 586 | outputs = torch.stack(outputs, dim=2) 587 | outputs = torch.mean(outputs, dim=2) 588 | elif self.fusion == 'concat': 589 | outputs = torch.cat(outputs, dim=1) 590 | else: 591 | raise ValueError('Unkown fusion type {}'.format(self.fusion)) 592 | return outputs 593 | 594 | def pattern0(self): 595 | sal_length = (self.saliency_size // 2) ** 2 596 | pool = GlobalPattern() 597 | conv = nn.Sequential( 598 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (1, 1), bias=False), 599 | nn.ReLU(True), 600 | nn.Flatten(1) 601 | ) 602 | return pool, conv 603 | 604 | 605 | def pattern1(self): 606 | sal_length = (self.saliency_size // 2) ** 2 607 | pool = HorizontalPattern() 608 | conv = nn.Sequential( 609 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,1), bias=False), 610 | nn.ReLU(True), 611 | nn.Flatten(1) 612 | ) 613 | return pool, conv 614 | 615 | def pattern2(self): 616 | sal_length = (self.saliency_size // 2) ** 2 617 | pool = VerticalPattern() 618 | conv = nn.Sequential( 619 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (1,2), bias=False), 620 | nn.ReLU(True), 621 | nn.Flatten(1) 622 | ) 623 | return pool, conv 624 | 625 | def pattern3(self): 626 | s_size = int(self.saliency_size * 3 / 4) 627 | sal_length = s_size * (s_size - 1) // 2 628 | pool = TriangularPattern(flip=False, sal_size=self.saliency_size) 629 | conv = nn.Sequential( 630 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,1), bias=False), 631 | nn.ReLU(True), 632 | nn.Flatten(1) 633 | ) 634 | return pool, conv 635 | 636 | def pattern4(self): 637 | s_size = int(self.saliency_size * 3 / 4) 638 | sal_length = s_size * (s_size - 1) // 2 639 | pool = TriangularPattern(flip=True, sal_size=self.saliency_size) 640 | conv = nn.Sequential( 641 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2, 1), bias=False), 642 | nn.ReLU(True), 643 | nn.Flatten(1) 644 | ) 645 | return pool, conv 646 | 647 | def pattern5(self): 648 | crop_size = 1./2 649 | sal_length = int(self.saliency_size * crop_size) ** 2 650 | pool = SurroundPattern(crop_size) 651 | conv = nn.Sequential( 652 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2, 1), bias=False), 653 | nn.ReLU(True), 654 | nn.Flatten(1) 655 | ) 656 | return pool, conv 657 | 658 | def pattern6(self): 659 | sal_length = (self.saliency_size // 2) ** 2 660 | pool = QuarterPattern() 661 | conv = nn.Sequential( 662 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2, 2), bias=False), 663 | nn.ReLU(True), 664 | nn.Flatten(1) 665 | ) 666 | return pool, conv 667 | 668 | 
def pattern7(self): 669 | sal_length = 0 670 | row_len = self.saliency_size 671 | while row_len > 0: 672 | sal_length += row_len 673 | row_len -= 2 674 | pool = CrossPattern() 675 | conv = nn.Sequential( 676 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,2), bias=False), 677 | nn.ReLU(True), 678 | nn.Flatten(1) 679 | ) 680 | return pool, conv 681 | 682 | def pattern8(self): 683 | sal_length = (self.saliency_size // 3)**2 684 | pool = ThirdOfRulePattern() 685 | conv = nn.Sequential( 686 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (3,3), bias=False), 687 | nn.ReLU(True), 688 | nn.Flatten(1) 689 | ) 690 | return pool, conv 691 | 692 | def pattern9(self): 693 | sal_length = (self.saliency_size // 3) * self.saliency_size 694 | pool = HorThirdPattern() 695 | conv = nn.Sequential( 696 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (1,3), bias=False), 697 | nn.ReLU(True), 698 | nn.Flatten(1) 699 | ) 700 | return pool, conv 701 | 702 | def pattern10(self): 703 | sal_length = (self.saliency_size // 3) * self.saliency_size 704 | pool = VerticalThirdPattern() 705 | conv = nn.Sequential( 706 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (3,1), bias=False), 707 | nn.ReLU(True), 708 | nn.Flatten(1) 709 | ) 710 | return pool, conv 711 | 712 | def pattern11(self): 713 | sal_length = int((self.saliency_size // 2) * (self.saliency_size // 2 + 1) / 2) 714 | pool = MultiDirectionPattern() 715 | conv = nn.Sequential( 716 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,4), bias=False), 717 | nn.ReLU(True), 718 | nn.Flatten(1) 719 | ) 720 | return pool, conv 721 | 722 | def pattern12(self): 723 | sal_length = (self.saliency_size // 4) ** 2 724 | pool = MultiRectanglePattern() 725 | conv = nn.Sequential( 726 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,4), bias=False), 727 | nn.ReLU(True), 728 | nn.Flatten(1) 729 | ) 730 | return pool, conv 731 | 732 | 733 | 734 | if __name__ == '__main__': 735 | pattern_list = list(range(1,13)) 736 | x = torch.randn(2, 512, 7, 7) 737 | s = torch.randn(2, 1, 56, 56) 738 | w = torch.randn(2, len(pattern_list)) 739 | mpp = MPPModule(512, 512, 0.5, pattern_list) 740 | sampp = SAMPPModule(512, 512, 56, 0.5, pattern_list) 741 | mpp_out = mpp(x, weights=w) 742 | sa_out = sampp(x,s, weights=w) 743 | print('mpp_out', mpp_out.shape, 'sampp_out', sa_out.shape) 744 | 745 | # third_pattern = ThirdOfRulePattern() 746 | # print(third_pattern(x,s)[0].shape) 747 | 748 | # h = w = 56 749 | # square_part = torch.zeros(4, 4, h, w) 750 | # index_y = torch.split(torch.arange(h), (h+1)//4) 751 | # index_x = torch.split(torch.arange(w), (w+1)//4) 752 | # for i in range(4): 753 | # for j in range(4): 754 | # for x in index_x[i]: 755 | # for y in index_y[j]: 756 | # square_part[i,j,y,x] = 1 757 | # mask = torch.zeros(8, h, w) 758 | # group_x = [[0, 0, 1], [1], [2, 3, 3], [2], [0, 0, 1], [1], [2, 3, 3], [2]] 759 | # group_y = [[1, 0, 0], [1], [0, 0, 1], [1], [2, 3, 3], [2], [3, 3, 2], [2]] 760 | # for i in range(len(group_x)): 761 | # mask[i] = torch.sum(square_part[group_x[i], group_y[i]], dim=0) 762 | # print(mask[i], torch.count_nonzero(mask[i])) 763 | 764 | 765 | 766 | 767 | 768 | 769 | -------------------------------------------------------------------------------- /SAMPNet/samp_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | import warnings 6 | warnings.filterwarnings('ignore') 7 | from 
samp_module import MPPModule, SAMPPModule 8 | from config import Config 9 | 10 | class SAPModule(nn.Module): 11 | def __init__(self, input_channel, output_channel, saliency_size, dropout): 12 | super(SAPModule, self).__init__() 13 | self.pooling = nn.Sequential( 14 | nn.AdaptiveAvgPool2d(1), 15 | nn.Flatten(1), 16 | nn.Dropout(dropout)) 17 | sal_feat_len = saliency_size * saliency_size 18 | self.feature_layer = nn.Sequential( 19 | nn.Linear(input_channel + sal_feat_len, 20 | output_channel, bias=False), 21 | nn.Dropout(dropout)) 22 | 23 | def forward(self, x, s): 24 | x = self.pooling(x) 25 | s = s.flatten(1) 26 | f = self.feature_layer(torch.cat([x,s], dim=1)) 27 | return f 28 | 29 | def build_resnet(layers, pretrained=False): 30 | assert layers in [18, 34, 50, 101], f'layers must be one of [18, 34, 50, 101], while layers = {layers}' 31 | if layers == 18: 32 | resnet = models.resnet18(pretrained) 33 | elif layers == 34: 34 | resnet = models.resnet34(pretrained) 35 | elif layers == 50: 36 | resnet = models.resnet50(pretrained) 37 | else: 38 | resnet = models.resnet101(pretrained) 39 | modules = list(resnet.children())[:-2] 40 | resnet = nn.Sequential(*modules) 41 | return resnet 42 | 43 | class SAMPNet(nn.Module): 44 | def __init__(self, cfg, pretrained=True): 45 | super(SAMPNet, self).__init__() 46 | score_level = cfg.score_level 47 | layers = cfg.resnet_layers 48 | dropout = cfg.dropout 49 | num_attributes = cfg.num_attributes 50 | input_channel = 512 if layers in [18,34] else 2048 51 | sal_dim = 512 52 | pool_dropout = cfg.pool_dropout 53 | pattern_list = cfg.pattern_list 54 | pattern_fuse = cfg.pattern_fuse 55 | 56 | self.use_weighted_loss = cfg.use_weighted_loss 57 | self.use_attribute = cfg.use_attribute 58 | self.use_channel_attention = cfg.use_channel_attention 59 | self.use_saliency = cfg.use_saliency 60 | self.use_multipattern = cfg.use_multipattern 61 | self.use_pattern_weight = cfg.use_pattern_weight 62 | 63 | self.backbone = build_resnet(layers, pretrained=pretrained) 64 | self.global_pool = nn.Sequential( 65 | nn.AdaptiveAvgPool2d(1), 66 | nn.Dropout(dropout), 67 | nn.Flatten(1), 68 | ) 69 | 70 | output_channel = input_channel 71 | if self.use_multipattern: 72 | if self.use_saliency: 73 | self.saliency_max = nn.Sequential( 74 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 75 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 76 | # multi-pattern pooling module 77 | self.pattern_module = SAMPPModule(input_channel, 78 | input_channel + sal_dim, 79 | saliency_size=56, 80 | dropout=pool_dropout, 81 | pattern_list=pattern_list, 82 | fusion=pattern_fuse) 83 | output_channel = input_channel + sal_dim 84 | else: 85 | self.pattern_module = MPPModule(input_channel, 86 | input_channel, 87 | dropout=pool_dropout, 88 | pattern_list=pattern_list, 89 | fusion=pattern_fuse) 90 | output_channel = input_channel 91 | if self.use_pattern_weight: 92 | self.pattern_weight_layer = nn.Sequential( 93 | nn.AdaptiveAvgPool2d(1), 94 | nn.Dropout(dropout), 95 | nn.Flatten(1), 96 | nn.Linear(input_channel, len(pattern_list), bias=False) 97 | ) 98 | else: 99 | if self.use_saliency: 100 | if self.use_saliency: 101 | self.saliency_max = nn.Sequential( 102 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 103 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 104 | self.pattern_module = SAPModule(input_channel, 105 | input_channel + sal_dim, 106 | saliency_size=56, 107 | dropout=pool_dropout) 108 | output_channel = input_channel + sal_dim 109 | 110 | if self.use_attribute: 111 | # multi-task 
structure 112 | concat_dim = output_channel 113 | att_dim = 512 if concat_dim >= 1024 else concat_dim // 2 114 | com_dim = concat_dim - att_dim 115 | self.att_feature_layer = nn.Sequential( 116 | nn.Linear(concat_dim, att_dim, bias=False), 117 | nn.ReLU(True), 118 | nn.Dropout(dropout) 119 | ) 120 | self.att_pred_layer = nn.Sequential( 121 | nn.Linear(att_dim, num_attributes, bias=False) 122 | ) 123 | self.com_feature_layer = nn.Sequential( 124 | nn.Linear(concat_dim, com_dim, bias=False), 125 | nn.ReLU(True), 126 | nn.Dropout(dropout) 127 | ) 128 | if self.use_channel_attention: 129 | self.alpha_predict_layer = nn.Sequential( 130 | nn.Linear(concat_dim, 2, bias=False), 131 | nn.Sigmoid() 132 | ) 133 | self.com_pred_layer = nn.Sequential( 134 | nn.Linear(output_channel, output_channel, bias=False), 135 | nn.ReLU(True), 136 | nn.Dropout(dropout), 137 | nn.Linear(output_channel, input_channel, bias=False), 138 | nn.ReLU(True), 139 | nn.Linear(input_channel, score_level, bias=False), 140 | nn.Softmax(dim=1)) 141 | 142 | def forward(self, x, s): 143 | feature_map = self.backbone(x) 144 | weight = None 145 | attribute = None 146 | 147 | if self.use_multipattern: 148 | if self.use_pattern_weight: 149 | weight = self.pattern_weight_layer(feature_map) 150 | if self.use_saliency: 151 | sal_map = self.saliency_max(s) 152 | pattern_feat = self.pattern_module(feature_map, sal_map, weight) 153 | else: 154 | pattern_feat = self.pattern_module(feature_map, weight) 155 | else: 156 | if self.use_saliency: 157 | sal_map = self.saliency_max(s) 158 | pattern_feat = self.pattern_module(feature_map, sal_map) 159 | else: 160 | pattern_feat = self.global_pool(feature_map) 161 | if self.use_attribute: 162 | att_feat = self.att_feature_layer(pattern_feat) 163 | com_feat = self.com_feature_layer(pattern_feat) 164 | attribute = self.att_pred_layer(att_feat) 165 | fused_feat = torch.cat([att_feat, com_feat], dim=1) 166 | if self.use_channel_attention: 167 | alpha = self.alpha_predict_layer(fused_feat) 168 | fused_feat = torch.cat([alpha[:,0:1] * att_feat, alpha[:,1:] * com_feat], dim=1) 169 | scores = self.com_pred_layer(fused_feat) 170 | else: 171 | scores = self.com_pred_layer(pattern_feat) 172 | return weight, attribute, scores 173 | 174 | class AttributeLoss(nn.Module): 175 | def __init__(self, scalar=0.1): 176 | super(AttributeLoss, self).__init__() 177 | self.scalar = scalar 178 | 179 | def forward(self, att_target, att_estimate): 180 | assert att_target.shape == att_estimate.shape, \ 181 | 'target {} vs. predict {}'.format(att_target.shape, att_estimate.shape) 182 | diff = att_target - att_estimate 183 | diff[att_target == 0] = 0. 184 | diff = torch.sum(torch.pow(diff, 2), dim=0) 185 | num_samples = torch.count_nonzero(att_target, dim=0) 186 | diff_mean = torch.where(num_samples > 0, diff / num_samples, diff) 187 | loss_att = torch.sum(diff_mean) * self.scalar 188 | return loss_att 189 | 190 | class EMDLoss(nn.Module): 191 | def __init__(self, reduction='mean', r=2): 192 | super(EMDLoss, self).__init__() 193 | self.reduction = reduction 194 | self.r = r 195 | 196 | def forward(self, p_target, p_estimate, weight=None): 197 | assert p_target.shape == p_estimate.shape, \ 198 | 'target {} vs. 
predict {}'.format(p_target.shape, p_estimate.shape) 199 | # cdf for values [1, 2, ..., 5] 200 | cdf_target = torch.cumsum(p_target, dim=-1) 201 | # cdf for values [1, 2, ..., 5] 202 | cdf_estimate = torch.cumsum(p_estimate, dim=-1) 203 | cdf_diff = cdf_estimate - cdf_target 204 | if self.r == 1: 205 | samplewise_emd = torch.mean(cdf_diff.abs(), dim=-1) 206 | else: 207 | abs_diff = torch.where(cdf_diff.abs() > 1e-15, cdf_diff.abs(), torch.ones_like(cdf_diff) * 1e-15) 208 | sqaures = torch.pow(abs_diff, self.r) 209 | average = torch.mean(sqaures, dim=-1) 210 | samplewise_emd = torch.pow(average, 1. / self.r) 211 | if weight is not None: 212 | samplewise_emd = samplewise_emd * weight 213 | if self.reduction == 'sum': 214 | return samplewise_emd.sum() 215 | elif self.reduction is None: 216 | return samplewise_emd 217 | else: 218 | return samplewise_emd.mean() 219 | 220 | if __name__ == '__main__': 221 | x = torch.randn(2,3,224,224) 222 | s = torch.randn(2,1,224,224) 223 | cfg = Config() 224 | model = SAMPNet(cfg) 225 | weight, attribute, score = model(x,s) 226 | if weight is not None: 227 | print('weight', weight.shape, F.softmax(weight,dim=1)) 228 | if attribute is not None: 229 | print('attribute', attribute.shape, attribute) 230 | print('score', score.shape, score) -------------------------------------------------------------------------------- /SAMPNet/test.py: -------------------------------------------------------------------------------- 1 | from samp_net import EMDLoss, SAMPNet 2 | from cadb_dataset import CADBDataset 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import scipy.stats as stats 6 | import numpy as np 7 | from tqdm import tqdm 8 | from config import Config 9 | 10 | def calculate_accuracy(predict, target, threhold=2.6): 11 | assert target.shape == predict.shape, '{} vs. 
{}'.format(target.shape, predict.shape) 12 | bin_tar = target > threhold 13 | bin_pre = predict > threhold 14 | correct = (bin_tar == bin_pre).sum() 15 | acc = correct.float() / target.size(0) 16 | return correct,acc 17 | 18 | def calculate_lcc(target, predict): 19 | if len(target.shape) > 1: 20 | target = target.view(-1) 21 | if len(predict.shape) > 1: 22 | predict = predict.view(-1) 23 | predict = predict.cpu().numpy() 24 | target = target.cpu().numpy() 25 | lcc = np.corrcoef(predict, target)[0,1] 26 | return lcc 27 | 28 | def calculate_spearmanr(target, predict): 29 | if len(target.shape) > 1: 30 | target = target.view(-1) 31 | if len(predict.shape) > 1: 32 | predict = predict.view(-1) 33 | target_list = target.cpu().numpy().tolist() 34 | predict_list = predict.cpu().numpy().tolist() 35 | # sort_target = np.sort(target_list).tolist() 36 | # sort_predict = np.sort(predict_list).tolist() 37 | # pre_rank = [] 38 | # for i in predict_list: 39 | # pre_rank.append(sort_predict.index(i)) 40 | # tar_rank = [] 41 | # for i in target_list: 42 | # tar_rank.append(sort_target.index(i)) 43 | # rho,pval = stats.spearmanr(pre_rank, tar_rank) 44 | rho,_ = stats.spearmanr(predict_list, target_list) 45 | return rho 46 | 47 | def dist2ave(pred_dist): 48 | pred_score = torch.sum(pred_dist* torch.Tensor(range(1,6)).to(pred_dist.device), dim=-1, keepdim=True) 49 | return pred_score 50 | 51 | def evaluation_on_cadb(model, cfg): 52 | model.eval() 53 | device = next(model.parameters()).device 54 | testdataset = CADBDataset('test', cfg) 55 | testloader = DataLoader(testdataset, 56 | batch_size=cfg.batch_size, 57 | shuffle=False, 58 | num_workers=cfg.num_workers, 59 | drop_last=False) 60 | emd_r2_fn = EMDLoss(reduction='sum', r=2) 61 | emd_r1_fn = EMDLoss(reduction='sum', r=1) 62 | emd_r2_error = 0.0 63 | emd_r1_error = 0.0 64 | correct = 0. 65 | tar_scores = None 66 | pre_scores = None 67 | print() 68 | print('Evaluation begining...') 69 | with torch.no_grad(): 70 | for (im,score,dist,saliency,attributes) in tqdm(testloader): 71 | image = im.to(device) 72 | score = score.to(device) 73 | dist = dist.to(device) 74 | saliency = saliency.to(device) 75 | weight, atts, output = model(image, saliency) 76 | 77 | pred_score = dist2ave(output) 78 | emd_r1_error += emd_r1_fn(dist, output).item() 79 | emd_r2_error += emd_r2_fn(dist, output).item() 80 | correct += calculate_accuracy(pred_score, score)[0].item() 81 | if tar_scores is None: 82 | tar_scores = score 83 | pre_scores = pred_score 84 | else: 85 | tar_scores = torch.cat([tar_scores, score], dim=0) 86 | pre_scores = torch.cat([pre_scores, pred_score], dim=0) 87 | print('Evaluation result...') 88 | # print('Scores shape', pre_scores.shape, tar_scores.shape) 89 | avg_mse = torch.nn.MSELoss()(pre_scores.view(-1), tar_scores.view(-1)).item() 90 | SRCC = calculate_spearmanr(tar_scores, pre_scores) 91 | LCC = calculate_lcc(tar_scores, pre_scores) 92 | avg_r1_emd = emd_r1_error / len(testdataset) 93 | avg_r2_emd = emd_r2_error / len(testdataset) 94 | avg_acc = correct / len(testdataset) 95 | ss = "Test on {} images, Accuracy={:.2%}, EMD(r=1)={:.4f}, EMD(r=2)={:.4f},". \ 96 | format(len(testdataset), avg_acc, avg_r1_emd, avg_r2_emd) 97 | ss += " MSE_loss={:.4f}, SRCC={:.4f}, LCC={:.4f}". 
\ 98 | format(avg_mse, SRCC, LCC) 99 | print(ss) 100 | return avg_acc, avg_r1_emd, avg_r2_emd, avg_mse, SRCC, LCC 101 | 102 | if __name__ == '__main__': 103 | cfg = Config() 104 | device = torch.device('cuda:{}'.format(cfg.gpu_id)) 105 | model = SAMPNet(cfg,pretrained=False).to(device) 106 | weight_file = './pretrained_model/samp_net.pth' 107 | model.load_state_dict(torch.load(weight_file)) 108 | evaluation_on_cadb(model, cfg) -------------------------------------------------------------------------------- /SAMPNet/train.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | from torch.autograd import Variable 3 | import torch.optim as optim 4 | from tensorboardX import SummaryWriter 5 | import torch 6 | import time 7 | import shutil 8 | from torch.utils.data import DataLoader 9 | import csv 10 | 11 | from samp_net import EMDLoss, AttributeLoss, SAMPNet 12 | from config import Config 13 | from cadb_dataset import CADBDataset 14 | from test import evaluation_on_cadb 15 | 16 | def calculate_accuracy(predict, target, threhold=2.6): 17 | assert target.shape == predict.shape, '{} vs. {}'.format(target.shape, predict.shape) 18 | bin_tar = target > threhold 19 | bin_pre = predict > threhold 20 | correct = (bin_tar == bin_pre).sum() 21 | acc = correct.float() / target.size(0) 22 | return correct,acc 23 | 24 | def build_dataloader(cfg): 25 | trainset = CADBDataset('train', cfg) 26 | trainloader = DataLoader(trainset, 27 | batch_size=cfg.batch_size, 28 | shuffle=True, 29 | num_workers=cfg.num_workers, 30 | drop_last=False) 31 | return trainloader 32 | 33 | class Trainer(object): 34 | def __init__(self, model, cfg): 35 | self.cfg = cfg 36 | self.model = model 37 | self.device = torch.device('cuda:{}'.format(self.cfg.gpu_id)) 38 | self.trainloader = build_dataloader(cfg) 39 | self.optimizer = self.create_optimizer() 40 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( 41 | self.optimizer, mode='min', patience=5) 42 | self.epoch = 0 43 | self.iters = 0 44 | 45 | self.avg_mse = 0. 46 | self.avg_emd = 0. 47 | self.avg_acc = 0. 48 | self.avg_att = 0. 49 | 50 | self.smooth_coe = 0.4 51 | self.smooth_mse = None 52 | self.smooth_emd = None 53 | self.smooth_acc = None 54 | self.smooth_att = None 55 | 56 | self.mse_loss = torch.nn.MSELoss() 57 | self.emd_loss = EMDLoss() 58 | 59 | self.test_acc = [] 60 | self.test_emd1 = [] 61 | self.test_emd2 = [] 62 | self.test_mse = [] 63 | self.test_srcc = [] 64 | self.test_lcc = [] 65 | 66 | if cfg.use_attribute: 67 | self.att_loss = AttributeLoss(cfg.attribute_weight) 68 | 69 | self.least_metric = 1. 
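# `least_metric` holds the best (lowest) test EMD(r=1) seen so far: run() feeds it to the
# ReduceLROnPlateau scheduler and saves model-best.pth whenever the test EMD(r=1) returned by
# eval_training() improves on it. create_optimizer() below trains the backbone parameters at
# cfg.lr * 0.01 and every other SAMP-Net parameter at the full cfg.lr.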
70 | self.writer = self.create_writer() 71 | 72 | def create_optimizer(self): 73 | # for param in self.model.backbone.parameters(): 74 | # param.requires_grad = False 75 | bb_params = list(map(id, self.model.backbone.parameters())) 76 | lr_params = filter(lambda p:id(p) not in bb_params, self.model.parameters()) 77 | params = [ 78 | {'params': lr_params, 'lr': self.cfg.lr}, 79 | {'params': self.model.backbone.parameters(), 'lr': self.cfg.lr * 0.01} 80 | ] 81 | if self.cfg.optimizer == 'adam': 82 | optimizer = optim.Adam(params, 83 | weight_decay=self.cfg.weight_decay) 84 | elif self.cfg.optimizer == 'sgd': 85 | optimizer = optim.SGD(params, 86 | momentum=self.cfg.momentum, 87 | weight_decay=self.cfg.weight_decay) 88 | else: 89 | raise ValueError(f"not such optimizer {self.cfg.optimizer}") 90 | return optimizer 91 | 92 | def create_writer(self): 93 | print('Create tensorboardX writer...', self.cfg.log_dir) 94 | writer = SummaryWriter(log_dir=self.cfg.log_dir) 95 | return writer 96 | 97 | def run(self): 98 | for epoch in range(self.cfg.max_epoch): 99 | self.run_epoch() 100 | self.epoch += 1 101 | self.scheduler.step(metrics=self.least_metric) 102 | self.writer.add_scalar('Train/lr', self.optimizer.param_groups[0]['lr'], self.epoch) 103 | if self.epoch % self.cfg.save_epoch == 0: 104 | checkpoint_path = os.path.join(self.cfg.checkpoint_dir, 'model-{epoch}.pth') 105 | torch.save(self.model.state_dict(), checkpoint_path.format(epoch=self.epoch)) 106 | print('Save checkpoint...') 107 | if self.epoch % self.cfg.test_epoch == 0: 108 | test_emd = self.eval_training() 109 | if test_emd < self.least_metric: 110 | self.least_metric = test_emd 111 | checkpoint_path = os.path.join(self.cfg.checkpoint_dir, 'model-best.pth') 112 | torch.save(self.model.state_dict(), checkpoint_path) 113 | print('Update best checkpoint...') 114 | self.writer.add_scalar('Test/Least EMD', self.least_metric, self.epoch) 115 | 116 | 117 | def eval_training(self): 118 | avg_acc, avg_r1_emd, avg_r2_emd, avg_mse, SRCC, LCC = \ 119 | evaluation_on_cadb(self.model, self.cfg) 120 | self.writer.add_scalar('Test/Average EMD(r=2)', avg_r2_emd, self.epoch) 121 | self.writer.add_scalar('Test/Average EMD(r=1)', avg_r1_emd, self.epoch) 122 | self.writer.add_scalar('Test/Average MSE', avg_mse, self.epoch) 123 | self.writer.add_scalar('Test/Accuracy', avg_acc, self.epoch) 124 | self.writer.add_scalar('Test/SRCC', SRCC, self.epoch) 125 | self.writer.add_scalar('Test/LCC', LCC, self.epoch) 126 | error = avg_r1_emd 127 | 128 | self.test_acc.append(avg_acc) 129 | self.test_emd1.append(avg_r1_emd) 130 | self.test_emd2.append(avg_r2_emd) 131 | self.test_mse.append(avg_mse) 132 | self.test_srcc.append(SRCC) 133 | self.test_lcc.append(LCC) 134 | self.write2csv() 135 | return error 136 | 137 | def write2csv(self): 138 | csv_path = os.path.join(self.cfg.exp_path, '..', '{}.csv'.format(self.cfg.exp_name)) 139 | header = ['epoch', 'Accuracy', 'EMD r=1', 'EMD r=2', 'MSE', 'SRCC', 'LCC'] 140 | epoches = list(range(len(self.test_acc))) 141 | metrics = [epoches, self.test_acc, self.test_emd1, self.test_emd2, 142 | self.test_mse, self.test_srcc, self.test_lcc] 143 | rows = [header] 144 | for i in range(len(epoches)): 145 | row = [m[i] for m in metrics] 146 | rows.append(row) 147 | for name, m in zip(header, metrics): 148 | if name == 'epoch': 149 | continue 150 | index = m.index(min(m)) 151 | if name in ['Accuracy', 'SRCC', 'LCC']: 152 | index = m.index(max(m)) 153 | title = 'best {} (epoch-{})'.format(name, index) 154 | row = [l[index] for l in metrics] 155 
| row[0] = title 156 | rows.append(row) 157 | with open(csv_path, 'w') as f: 158 | cw = csv.writer(f) 159 | cw.writerows(rows) 160 | print('Save result to ', csv_path) 161 | 162 | def dist2ave(self, pred_dist): 163 | pred_score = torch.sum(pred_dist* torch.Tensor(range(1,6)).to(pred_dist.device), dim=-1, keepdim=True) 164 | return pred_score 165 | 166 | def run_epoch(self): 167 | self.model.train() 168 | for batch, data in enumerate(self.trainloader): 169 | self.iters += 1 170 | image = data[0].to(self.device) 171 | score = data[1].to(self.device) 172 | score_dist = data[2].to(self.device) 173 | saliency = data[3].to(self.device) 174 | attributes = data[4].to(self.device) 175 | weight = data[5].to(self.device) 176 | 177 | pred_weight, pred_atts, pred_dist = self.model(image, saliency) 178 | 179 | if self.cfg.use_weighted_loss: 180 | dist_loss = self.emd_loss(score_dist, pred_dist, weight) 181 | else: 182 | dist_loss = self.emd_loss(score_dist, pred_dist) 183 | 184 | if self.cfg.use_attribute: 185 | att_loss = self.att_loss(attributes, pred_atts) 186 | loss = dist_loss + att_loss 187 | else: 188 | loss = dist_loss 189 | self.optimizer.zero_grad() 190 | loss.backward() 191 | self.optimizer.step() 192 | 193 | self.avg_emd += dist_loss.item() 194 | self.avg_att += att_loss.item() 195 | pred_score = self.dist2ave(pred_dist) 196 | correct, accuracy = calculate_accuracy(pred_score, score) 197 | self.avg_acc += accuracy.item() 198 | if (self.iters+1) % self.cfg.display_steps == 0: 199 | print('ground truth: average={}'.format(score.view(-1))) 200 | print('prediction: average={}'.format(pred_score.view(-1))) 201 | 202 | self.avg_emd = self.avg_emd / self.cfg.display_steps 203 | self.avg_acc = self.avg_acc / self.cfg.display_steps 204 | if self.cfg.use_attribute: 205 | self.avg_att = self.avg_att / self.cfg.display_steps 206 | 207 | if self.smooth_emd != None: 208 | self.avg_emd = (1-self.smooth_coe) * self.avg_emd + self.smooth_coe * self.smooth_emd 209 | self.avg_acc = (1-self.smooth_coe) * self.avg_acc + self.smooth_coe * self.smooth_acc 210 | if self.cfg.use_attribute: 211 | self.avg_att = (1-self.smooth_coe) * self.avg_att + self.smooth_coe * self.smooth_att 212 | self.writer.add_scalar('Train/AttributeLoss', self.avg_att, self.iters) 213 | 214 | self.writer.add_scalar('Train/EMD_Loss', self.avg_emd, self.iters) 215 | self.writer.add_scalar('Train/Accuracy', self.avg_acc, self.iters) 216 | 217 | if self.cfg.use_attribute: 218 | print('Traning Epoch:{}/{} Current Batch: {}/{} EMD_Loss:{:.4f} Attribute_Loss:{:.4f} ACC:{:.2%} lr:{:.6f} '. 219 | format( 220 | self.epoch, self.cfg.max_epoch, 221 | batch, len(self.trainloader), 222 | self.avg_emd, self.avg_att, 223 | self.avg_acc, 224 | self.optimizer.param_groups[0]['lr'])) 225 | else: 226 | print( 227 | 'Traning Epoch:{}/{} Current Batch: {}/{} EMD_Loss:{:.4f} ACC:{:.2%} lr:{:.6f} '. 228 | format( 229 | self.epoch, self.cfg.max_epoch, 230 | batch, len(self.trainloader), 231 | self.avg_emd, self.avg_acc, 232 | self.optimizer.param_groups[0]['lr'])) 233 | 234 | self.smooth_emd = self.avg_emd 235 | self.smooth_acc = self.avg_acc 236 | 237 | self.avg_mse = 0. 238 | self.avg_emd = 0. 239 | self.avg_acc = 0. 240 | if self.cfg.use_attribute: 241 | self.smooth_att = self.avg_att 242 | self.avg_att = 0. 
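# Logging above reports running averages over the last cfg.display_steps batches, blended with
# the previously logged values (weight smooth_coe = 0.4 on the old value) before being written
# to TensorBoard, so the curves are exponentially smoothed rather than raw per-interval means.
# The attribute-loss bookkeeping (avg_att / smooth_att) is only meaningful when
# cfg.use_attribute is enabled, since att_loss is computed inside that branch.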
243 | print() 244 | 245 | if __name__ == '__main__': 246 | cfg = Config() 247 | cfg.create_path() 248 | device = torch.device('cuda:{}'.format(cfg.gpu_id)) 249 | # evaluate(cfg) 250 | for file in os.listdir('./'): 251 | if file.endswith('.py'): 252 | shutil.copy(file, cfg.exp_path) 253 | print('Backup ', file) 254 | 255 | model = SAMPNet(cfg) 256 | model = model.train().to(device) 257 | trainer = Trainer(model, cfg) 258 | trainer.run() -------------------------------------------------------------------------------- /annotations/visualize_cadb_annotation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import os 4 | import numpy as np 5 | import shutil 6 | import argparse 7 | from tqdm import tqdm 8 | import math 9 | 10 | score_levels = ['[1,2)', '[2,3)', '[3,4)', '[4,5]'] 11 | element_categories = ["center", "rule_of_thirds", "golden_ratio", "horizontal", "vertical", "diagonal", 12 | "curved", "fill_the_frame", "radial", "vanishing_point", "symmetric", "triangle", 13 | "pattern", 'none'] 14 | scene_categories = ['animal', 'plant', 'human', 'static', 'architecture', 15 | 'landscape', 'cityscape', 'indoor', 'night', 'other'] 16 | 17 | def draw_auxiliary_line(image, com_type): 18 | im_h, im_w, _ = image.shape 19 | if com_type in ['rule_of_thirds','center']: 20 | x1, x2 = int(im_w / 3), int(im_w / 3 * 2) 21 | y1, y2 = int(im_h / 3), int(im_h / 3 * 2) 22 | else: 23 | x1, x2 = int(im_w * 1. / 2.618), int(im_w * 1.618 / 2.618) 24 | y1, y2 = int(im_h * 1. / 2.618), int(im_h * 1.618 / 2.618) 25 | color = (255, 255, 255) 26 | line_width = 3 27 | cv2.line(image, (0, y1), (im_w, y1), color, line_width) 28 | cv2.line(image, (0, y2), (im_w, y2), color, line_width) 29 | cv2.line(image, (x1, 0), (x1, im_h), color, line_width) 30 | cv2.line(image, (x2, 0), (x2, im_h), color, line_width) 31 | return image 32 | 33 | def draw_element_on_image(image, com_type, element): 34 | pd_color = (0, 255, 255) 35 | if len(element) > 0: 36 | if com_type in ['center', 'rule_of_thirds', 'golden_ratio']: 37 | image = draw_auxiliary_line(image.copy(), com_type) 38 | for rect in element: 39 | x1,y1,x2,y2 = map(int, rect) 40 | cv2.rectangle(image, (x1,y1), (x2,y2), pd_color, 5) 41 | elif com_type in ['horizontal', 'diagonal', 'vertical']: 42 | image = draw_line_elements(image.copy(), com_type, element) 43 | else: 44 | element = np.array(element).astype(np.int32).reshape((len(element), -1, 2)) 45 | for i in range(element.shape[0]): 46 | for j in range(element[i].shape[0] - 1): 47 | cv2.line(image, (element[i, j, 0], element[i, j, 1]), (element[i, j + 1, 0], element[i, j + 1, 1]), 48 | pd_color, 5) 49 | 50 | text = '{}'.format(com_type) 51 | cv2.putText(image, text, (20, 70), cv2.FONT_HERSHEY_COMPLEX, 1.5, pd_color, 3) 52 | return image 53 | 54 | def compute_angle(lines): 55 | lines = np.array(lines).reshape((-1,4)) 56 | reverse_lines = np.concatenate([lines[:,2:], lines[:,0:2]], axis=1) 57 | l2r_points = np.where(lines[:,0:1] <= lines[:,2:3], lines, reverse_lines) 58 | angle = np.rad2deg(np.arctan2(l2r_points[:,3] - l2r_points[:,1], l2r_points[:,2] - l2r_points[:,0])) 59 | return np.abs(angle) 60 | 61 | def draw_line_elements(src, comp, element, vis_angle=False): 62 | im_h, im_w, _ = src.shape 63 | angle = compute_angle(element) 64 | element = np.array(element).astype(np.int32).reshape((-1, 4)) 65 | color = (0,255,255) 66 | for i in range(element.shape[0]): 67 | cv2.line(src, (element[i, 0], element[i, 1]), 68 | (element[i, 2], element[i, 3]), color, 5) 69 | 
if vis_angle: 70 | text = '{:.1f}'.format(angle[i]) 71 | tl_point = (element[i,0], element[i,1]) 72 | br_point = (element[i,2], element[i,3]) 73 | if element[i,0] > element[i,2] or \ 74 | (element[i,0] == element[i,2] and element[i,1] > element[i,3]): 75 | tl_point = (element[i,2], element[i,3]) 76 | br_point = (element[i,0], element[i,1]) 77 | if tl_point[1] <= br_point[1]: 78 | pos_x = max(min(tl_point[0] + 20, im_w - 50), 0) 79 | pos_y = max(tl_point[1] - 10, 30) 80 | else: 81 | pos_x = max(min(tl_point[0] - 100, im_w - 50), 0) 82 | pos_y = max(min(tl_point[1] + 50, im_h - 50), 0) 83 | cv2.putText(src, text, (pos_x, pos_y), cv2.FONT_HERSHEY_COMPLEX, 1.2, color, 2) 84 | return src 85 | 86 | def get_parser(): 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--data_root',default='./CADB_Dataset/', 89 | help='path to images (should have subfolder images/)') 90 | parser.add_argument('--save_folder', default='./CADB_Dataset/Visualization/',type=str) 91 | return parser 92 | 93 | if __name__ == '__main__': 94 | parser = get_parser() 95 | opt, _ = parser.parse_known_args() 96 | image_dir = os.path.join(opt.data_root, 'images') 97 | assert os.path.exists(image_dir), image_dir 98 | element_file = os.path.join(opt.data_root, 'composition_elements.json') 99 | assert os.path.exists(element_file), element_file 100 | element_anno = json.load(open(element_file, 'r')) 101 | 102 | score_file = os.path.join(opt.data_root, 'composition_scores.json') 103 | assert os.path.exists(score_file), score_file 104 | score_anno = json.load(open(score_file, 'r')) 105 | 106 | scene_file = os.path.join(opt.data_root, 'scene_categories.json') 107 | assert os.path.exists(scene_file), scene_file 108 | scene_anno = json.load(open(scene_file, 'r')) 109 | 110 | score_dir = os.path.join(opt.save_folder, 'composition_scores') 111 | os.makedirs(score_dir, exist_ok=True) 112 | 113 | element_dir = os.path.join(opt.save_folder, 'composition_elements') 114 | os.makedirs(element_dir, exist_ok=True) 115 | 116 | scene_dir = os.path.join(opt.save_folder, 'scene_classification') 117 | os.makedirs(scene_dir, exist_ok=True) 118 | 119 | score_stats = {} 120 | for level in score_levels: 121 | score_stats[level] = [] 122 | 123 | element_stats = {} 124 | for comp in element_categories: 125 | element_stats[comp] = [] 126 | 127 | scene_stats = {} 128 | for scene in scene_categories: 129 | scene_stats[scene] = [] 130 | 131 | total_num = 0 132 | for image_name, anno in tqdm(element_anno.items()): 133 | # read source image 134 | image_file = os.path.join(image_dir, image_name) 135 | assert os.path.exists(image_file), image_file 136 | src = cv2.imread(os.path.join(image_dir, image_name)) 137 | im_h, im_w, _ = src.shape 138 | total_num += 1 139 | # store images to different subfolders according to mean score 140 | im_scores = score_anno[image_name]['scores'] 141 | mean_score = float(sum(im_scores)) / len(im_scores) 142 | im_level = math.floor(mean_score) - 1 if mean_score < 5 else 3 143 | im_level = score_levels[im_level] 144 | score_stats[im_level].append(image_name) 145 | subfolder = os.path.join(score_dir, im_level) 146 | os.makedirs(subfolder, exist_ok=True) 147 | text = 'score:{:.1f}'.format(mean_score) 148 | dst = src.copy() 149 | cv2.putText(dst, text, (20, 70), cv2.FONT_HERSHEY_COMPLEX, 1.5, (0, 255, 255), 3) 150 | cv2.imwrite(os.path.join(subfolder, image_name), dst) 151 | # readout scene annotation 152 | per_scene = scene_anno[image_name] 153 | assert per_scene in scene_categories, per_scene 154 | 
os.makedirs(os.path.join(scene_dir, per_scene), exist_ok=True) 155 | cv2.imwrite(os.path.join(scene_dir, per_scene, image_name), src) 156 | scene_stats[per_scene].append(image_name) 157 | # visualize composition elements annotation 158 | for comp, element in anno.items(): 159 | element_stats[comp].append(image_name) 160 | dst = draw_element_on_image(src.copy(), comp, element) 161 | subpath = os.path.join(element_dir, comp) 162 | if not os.path.exists(subpath): 163 | os.makedirs(subpath) 164 | cv2.imwrite(os.path.join(subpath, image_name), dst) 165 | 166 | # show dataset statistical information 167 | element_stats = sorted(element_stats.items(), key=lambda d: len(d[1]), reverse=True) 168 | scene_stats = sorted(scene_stats.items(), key=lambda d: len(d[1]), reverse=True) 169 | 170 | with open(os.path.join(opt.save_folder, 'statistics.txt'), 'w') as f: 171 | f.write('Composition Score\n') 172 | print('Composition score') 173 | for level, img_list in score_stats.items(): 174 | ss = '{}: {} images, {:.1%}'.format(level, len(img_list), len(img_list) / total_num) 175 | f.write(ss + '\n') 176 | print(ss) 177 | 178 | print('\nComposition Element') 179 | f.write('\nComposition Element\n') 180 | for comp, img_list in element_stats: 181 | ss = '{}: {} images, {:.1%}'.format(comp, len(img_list), len(img_list) / total_num) 182 | f.write(ss + '\n') 183 | print(ss) 184 | 185 | print('\nScene Category') 186 | f.write('\nScene Category\n') 187 | for scene, img_list in scene_stats: 188 | ss = '{}: {} images, {:.1%}'.format(scene, len(img_list), len(img_list) / total_num) 189 | f.write(ss + '\n') 190 | print(ss) 191 | 192 | ss = 'Total number of images: {}'.format(total_num) 193 | print(ss + '\n') 194 | f.write(ss) 195 | -------------------------------------------------------------------------------- /examples/annotation_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/annotation_example.jpg -------------------------------------------------------------------------------- /examples/element_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/element_examples.jpg -------------------------------------------------------------------------------- /examples/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/example.jpg -------------------------------------------------------------------------------- /examples/interpretability.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/interpretability.jpg -------------------------------------------------------------------------------- /examples/samp_net.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/samp_net.jpg --------------------------------------------------------------------------------
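For quick programmatic access to the annotations visualized by `visualize_cadb_annotation.py`, the minimal sketch below is one possible starting point. It is not part of the released code: it assumes the same `./CADB_Dataset/` layout with `composition_scores.json` and `scene_categories.json`, and the helper names `load_mean_scores` / `group_by_scene` are illustrative only.

```python
import json
import os
from collections import defaultdict

# Hypothetical helpers, not part of the released code.

def load_mean_scores(data_root='./CADB_Dataset/'):
    # composition_scores.json maps image name -> {'scores': [five rater scores], ...}
    with open(os.path.join(data_root, 'composition_scores.json'), 'r') as f:
        score_anno = json.load(f)
    return {name: float(sum(anno['scores'])) / len(anno['scores'])
            for name, anno in score_anno.items()}

def group_by_scene(data_root='./CADB_Dataset/'):
    # scene_categories.json maps image name -> one of the ten scene labels
    with open(os.path.join(data_root, 'scene_categories.json'), 'r') as f:
        scene_anno = json.load(f)
    groups = defaultdict(list)
    for name, scene in scene_anno.items():
        groups[scene].append(name)
    return groups

if __name__ == '__main__':
    mean_scores = load_mean_scores()
    for scene, names in group_by_scene().items():
        # assumes every image listed in scene_categories.json also has composition scores
        scores = [mean_scores[n] for n in names if n in mean_scores]
        print('{}: {} images, average mean score {:.2f}'.format(
            scene, len(names), sum(scores) / max(len(scores), 1)))
```

This reads the same JSON files that the visualization script consumes, without copying or drawing any images.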