├── LICENSE
├── README.md
├── SAMPNet
│   ├── cadb_dataset.py
│   ├── config.py
│   ├── requirements.txt
│   ├── samp_module.py
│   ├── samp_net.py
│   ├── test.py
│   └── train.py
├── annotations
│   ├── composition_elements.json
│   ├── composition_scores.json
│   ├── scene_categories.json
│   └── visualize_cadb_annotation.py
└── examples
    ├── annotation_example.jpg
    ├── element_examples.jpg
    ├── example.jpg
    ├── interpretability.jpg
    └── samp_net.jpg
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 BCMI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | > # Image Composition Assessment Dataset
4 | > Welcome to the official homepage of the Image Composition Assessment DataBase (**CADB**).
5 | >
6 | > Image composition assessment aims to assess the overall composition quality of a given image, which is crucial in aesthetic assessment.
7 | > To support the research on this task, we contribute the first image composition assessment dataset. Furthermore, we propose a composition assessment network **SAMP-Net**
8 | > with a novel Saliency-Augmented Multi-pattern Pooling (**SAMP**) module, which can perform more favorably than previous aesthetic assessment approaches.
9 | > This work has been accepted by BMVC 2021 ([**paper**](https://arxiv.org/pdf/2104.03133.pdf)).
10 |
11 | **Update 2022-05-05**: We release the annotations of scene categories, composition classes as well as composition elements for more fine-grained analysis of image composition quality.
12 |
13 | **Table of Contents**
14 |
15 |
16 |
17 | - [Dataset](#dataset)
18 | - [Introduction](#introduction)
19 | - [Download](#download)
20 | - [Visualizing Annotations](#visualizing-annotations)
21 | - [Method Overview](#method-overview)
22 | - [Motivation](#motivation)
23 | - [SAMP-Net](#samp-net)
24 | - [Results](#results)
25 | - [Code Usage](#code-usage)
26 | - [Requirements](#requirements)
27 | - [Training](#training)
28 | - [Testing](#testing)
29 | - [Citation](#citation)
30 |
31 | # Dataset
32 |
33 | ## Introduction
34 |
35 | We built the CADB dataset upon the existing Aesthetics and Attributes DataBase ([AADB](https://github.com/aimerykong/deepImageAestheticsAnalysis)). The CADB dataset contains **9,497** images, each rated for overall composition quality by **5 individual raters** who specialize in fine art, on a **composition rating scale from 1 to 5**, where a larger score indicates better composition. Some example images with annotations from the CADB dataset are illustrated in the figure below, in which we show the five composition scores provided by the five raters in blue and the calculated composition mean score in red. We also show the aesthetic scores annotated by the AADB dataset on a scale from 1 to 5 in green.
36 |
37 |

38 |
39 |
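For reference, a minimal sketch (with made-up ratings) of how the five raters' scores of one image translate into the mean score and the 5-bin score distribution used as the ground-truth distribution in our code (cf. `CADBDataset.scores2dist` in `SAMPNet/cadb_dataset.py`):

```python
import numpy as np

ratings = np.array([3, 4, 4, 2, 5])   # five hypothetical scores in {1, ..., 5}
mean_score = ratings.mean()           # 3.6, i.e., the red mean score in the figure above
# normalized histogram over the five score levels, used as the target score distribution
distribution = np.array([(ratings == s).sum() for s in range(1, 6)], dtype=np.float32)
distribution /= distribution.sum()    # [0.0, 0.2, 0.2, 0.4, 0.2]
print(mean_score, distribution)
```
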
40 | To facilitate the study of image composition assessment, apart from the composition score, we also annotate scene categories, composition classes, and composition elements for each image. Specifically, we carefully select 9 frequently appearing scene categories (including ***animal, plant, human, static, architecture, landscape, cityscape, indoor, night***) plus one *other* class, which specifically refers to images without obvious meaning. As for composition classes, we categorize the common photographic composition rules into 13 classes: ***center, rule of thirds, golden ratio, triangle, horizontal, vertical, diagonal, symmetric, curved, radial, vanishing point, pattern, fill the frame***, and assign a *none* class to images without obvious composition rules; each image is annotated with one or more composition classes. Moreover, we annotate the dominant composition elements for each composition class except pattern and fill the frame, as illustrated in the figure below. We mark composition elements in yellow and add white auxiliary gridlines to some composition classes for better viewing.
41 |
42 |

43 |
44 |
45 | ## Download
46 | Download ``CADB_Dataset.zip`` (~2GB) from
47 | [[Google Drive]](https://drive.google.com/file/d/1fpZoo5exRfoarqDvdLDpQVXVOKFW63vz/view?usp=sharing) | [[Baidu Cloud]](https://pan.baidu.com/s/1o3AktNB-kmOIanJtzEx98g)(access code: *rmnb*).
48 | The annotations of scene categories, composition classes, and composition elements can be found in the ``annotations`` folder of this repository.
49 |
50 | ## Visualizing Annotations
51 | Put the JSON files from the ``annotations`` folder into the CADB dataset directory ``CADB_Dataset``, which yields the file structure below:
52 | ```
53 | CADB_Dataset
54 | ├── composition_elements.json
55 | ├── composition_scores.json
56 | ├── scene_categories.json
57 | └── images
58 |      ├── 10000.jpg
59 |      ├── 10001.jpg
60 |      ├── 10002.jpg
61 |      ├── ……
62 | ```
63 | Visualizing the annotations of composition score, scene category, and composition element:
64 | ```bash
65 | python annotations/visualize_cadb_annotation.py --data_root ./CADB_Dataset
66 | ```
67 | The visualized results will be stored in ``CADB_Dataset/Visualization``.
68 |
69 | # Method Overview
70 |
71 | ## Motivation
72 | As shown in the following Figure, each **composition pattern** divides the holistic image into multiple non-overlapping partitions, which can model
73 | human perception of composition quality. By analyzing the **visual layout** (e.g., positions and sizes of visual elements) according to a composition pattern, i.e., by comparing the
74 | visual elements in different partitions, we can quantify the aesthetics of the visual layout in terms of **visual balance** (e.g., symmetrical balance and radial balance), **composition rules** (e.g., rule of thirds, diagonals and triangles), and so on. Different composition patterns offer different perspectives to evaluate composition quality. For example, the composition pattern in the top (resp., bottom) row can help judge the composition quality in terms of symmetrical (resp., radial) balance.
75 |
76 |

77 |
78 |
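As a toy illustration of the partition-based analysis described above (not the released SAMP module itself), the sketch below splits a backbone feature map into the two partitions of a horizontal pattern and pools each partition separately, similar in spirit to `HorizontalPattern` in `SAMPNet/samp_module.py`:

```python
import torch

feat = torch.randn(2, 512, 7, 7)                    # backbone feature map (B, C, H, W)
H = feat.shape[2]
upper = feat[:, :, :H // 2, :].mean(dim=(2, 3))     # (B, C) descriptor of the upper partition
lower = feat[:, :, H // 2:, :].mean(dim=(2, 3))     # (B, C) descriptor of the lower partition
pattern_feat = torch.stack([upper, lower], dim=2)   # (B, C, 2), compared to judge visual balance
```
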
79 | ## SAMP-Net
80 | To accomplish the composition assessment task, we propose a novel network, SAMP-Net, which is named after its **Saliency-Augmented Multi-pattern Pooling (SAMP)** module.
81 | The overall pipeline of our method is illustrated in the following Figure, where we first extract the global feature map from the input image with the backbone and then yield the aggregated
82 | pattern feature through our SAMP module, which is followed by the **Attentional Attribute Feature Fusion (AAFF)** module to fuse the composition feature and the attribute feature. After that, we predict the **composition score distribution** based on the fused feature and predict the attribute scores based on the attribute feature, which are supervised by the **weighted EMD loss** and the Mean Squared Error (MSE) loss, respectively.
83 |
84 |

85 |
86 |
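For clarity, below is a condensed sketch of the weighted EMD loss between the ground-truth and predicted score distributions, following `EMDLoss` in `SAMPNet/samp_net.py` with r=2 (the small numerical clamp used in the released code is omitted here):

```python
import torch

def weighted_emd(p_target, p_estimate, weight=None, r=2):
    # EMD between two discrete distributions over the score levels 1..5 via their CDFs
    cdf_diff = torch.cumsum(p_estimate, dim=-1) - torch.cumsum(p_target, dim=-1)
    per_sample = cdf_diff.abs().pow(r).mean(dim=-1).pow(1. / r)
    if weight is not None:            # per-sample weights turn this into the weighted EMD loss
        per_sample = per_sample * weight
    return per_sample.mean()

target = torch.tensor([[0.0, 0.2, 0.2, 0.4, 0.2]])
prediction = torch.full((1, 5), 0.2)
print(weighted_emd(target, prediction))
```
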
87 | # Results
88 | We show the input image, its saliency map, its ground-truth/predicted composition mean score, and its pattern weights in the figure below. Our method predicts pattern weights that indicate the importance of different patterns for the overall composition quality. For each image, the composition pattern with the largest weight is referred to as its **dominant pattern**, and we overlay this pattern on the image. The dominant pattern helps to reveal from which perspective the input image is given a high or low score, which further provides constructive suggestions for improving the composition quality. For example, in the left figure of the third row, given the low score under the center pattern, it is suggested to move the dog to the center.
89 |
90 |

91 |
92 |
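A minimal sketch of how the dominant pattern is read off the network output: the predicted pattern weights (the first output of `SAMPNet`, one logit per entry of `cfg.pattern_list`) are softmax-normalized as in `samp_net.py`, and the pattern with the largest weight is taken as the dominant one. Random logits stand in for a real prediction here:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(1, 8)                # stand-in for the pattern logits predicted by SAMPNet
pattern_probs = F.softmax(weight, dim=1)  # normalized pattern weights, as visualized above
dominant = pattern_probs.argmax(dim=1)    # index of the dominant pattern in cfg.pattern_list
print(pattern_probs, dominant)
```
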
93 | # Code Usage
94 | ```bash
95 | # clone this repository
96 | git clone https://github.com/bcmi/Image-Composition-Assessment-with-SAMP.git
97 | cd Image-Composition-Assessment-with-SAMP/SAMPNet
98 | # download CADB data (~2GB), change the default dataset folder and gpu id in config.py.
99 | ```
100 | ## Requirements
101 | - PyTorch>=1.0
102 | - torchvision
103 | - tensorboardX
104 | - opencv-python
105 | - scipy
106 | - tqdm
107 |
108 | Or you can refer to [``requirements.txt``](./SAMPNet/requirements.txt).
109 |
110 | ## Training
111 | ```bash
112 | python train.py
113 | # track your experiments
114 | tensorboard --logdir=./experiments --bind_all
115 | ```
116 | During training, the evaluation results of each epoch are recorded in a ``csv`` file under the automatically created ``./experiments`` folder.
117 |
118 | ## Testing
119 | You can download pretrained model (~180MB) from [[Google Drive]](https://drive.google.com/file/d/1sIcYr5cQGbxm--tCGaASmN0xtE_r-QUg/view?usp=sharing) | [[Baidu Cloud]](https://pan.baidu.com/s/17EzhsbHqwA5aR8ty77fTvw)(access code: *5qgg*).
120 | ```bash
121 | # place the pretrained model in the folder ``pretrained_model`` and check the path in ``test.py``.
122 | # change the default gpu id in config.py
123 | python test.py
124 | ```
125 |
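For reference, a minimal single-image inference sketch. It assumes the pretrained checkpoint sits at `./pretrained_model/samp_net.pth`, that `./example.jpg` is a hypothetical test photo, and that `dataset_path` in `config.py` points to an existing folder (the config asserts this at import time); preprocessing mirrors `cadb_dataset.py` and the score conversion mirrors `dist2ave` in `test.py`:

```python
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from config import Config
from samp_net import SAMPNet
from cadb_dataset import detect_saliency, IMAGE_NET_MEAN, IMAGE_NET_STD

cfg = Config()
model = SAMPNet(cfg, pretrained=False)
model.load_state_dict(torch.load('./pretrained_model/samp_net.pth', map_location='cpu'))
model.eval()

src = Image.open('./example.jpg').convert('RGB')   # hypothetical input image
tf = transforms.Compose([transforms.Resize((cfg.image_size, cfg.image_size)),
                         transforms.ToTensor(),
                         transforms.Normalize(IMAGE_NET_MEAN, IMAGE_NET_STD)])
im = tf(src).unsqueeze(0)
sal = detect_saliency(np.asarray(src), target_size=(cfg.image_size, cfg.image_size))
sal = torch.from_numpy(sal.astype(np.float32))[None, None]

with torch.no_grad():
    weight, attrs, dist = model(im, sal)
score = (dist * torch.arange(1, 6, dtype=torch.float)).sum(dim=-1)  # expectation over levels 1..5
print('predicted composition score:', score.item())
```
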
126 | # Citation
127 | ```
128 | @article{zhang2021image,
129 | title={Image Composition Assessment with Saliency-augmented Multi-pattern Pooling},
130 | author={Zhang, Bo and Niu, Li and Zhang, Liqing},
131 | journal={arXiv preprint arXiv:2104.03133},
132 | year={2021}
133 | }
134 | ```
135 |
--------------------------------------------------------------------------------
/SAMPNet/cadb_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torch
3 | import torch.nn.functional as F
4 | from PIL import Image
5 | import os, json
6 | import torchvision.transforms as transforms
7 | import random
8 | import numpy as np
9 | from config import Config
10 | import cv2
11 |
12 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
13 | IMAGE_NET_STD = [0.229, 0.224, 0.225]
14 |
15 | random.seed(1)
16 | torch.manual_seed(1)
17 | cv2.setNumThreads(0)
18 |
19 | # Refer to: Saliency detection: A spectral residual approach
20 | def detect_saliency(img, scale=6, q_value=0.95, target_size=(224,224)):
21 | img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
22 |     H, W = img_gray.shape  # grayscale image shape is (height, width)
23 |     img_resize = cv2.resize(img_gray, (W // scale, H // scale), interpolation=cv2.INTER_AREA)  # cv2 dsize is (width, height)
24 |
25 | myFFT = np.fft.fft2(img_resize)
26 | myPhase = np.angle(myFFT)
27 | myLogAmplitude = np.log(np.abs(myFFT) + 0.000001)
28 | myAvg = cv2.blur(myLogAmplitude, (3, 3))
29 | mySpectralResidual = myLogAmplitude - myAvg
30 |
31 | m = np.exp(mySpectralResidual) * (np.cos(myPhase) + complex(1j) * np.sin(myPhase))
32 | saliencyMap = np.abs(np.fft.ifft2(m)) ** 2
33 | saliencyMap = cv2.GaussianBlur(saliencyMap, (9, 9), 2.5)
34 | saliencyMap = cv2.resize(saliencyMap, target_size, interpolation=cv2.INTER_LINEAR)
35 | threshold = np.quantile(saliencyMap.reshape(-1), q_value)
36 | if threshold > 0:
37 | saliencyMap[saliencyMap > threshold] = threshold
38 | saliencyMap = (saliencyMap - saliencyMap.min()) / threshold
39 | # for debugging
40 | # import matplotlib.pyplot as plt
41 | # plt.subplot(1,2,1)
42 | # plt.imshow(img)
43 | # plt.axis('off')
44 | # plt.subplot(1,2,2)
45 | # plt.imshow(saliencyMap, cmap='gray')
46 | # plt.axis('off')
47 | # plt.show()
48 | return saliencyMap
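
# Example usage (sketch): compute a saliency map for a single image; the image path
# below is hypothetical. CADBDataset.__getitem__ calls detect_saliency the same way.
#   img = np.asarray(Image.open('some_image.jpg').convert('RGB'))
#   sal = detect_saliency(img, target_size=(224, 224))  # HxW float map, roughly in [0, 1]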
49 |
50 | class CADBDataset(Dataset):
51 | def __init__(self, split, cfg):
52 | self.data_path = cfg.dataset_path
53 | self.image_path = os.path.join(self.data_path, 'images')
54 | self.score_path = os.path.join(self.data_path, 'composition_scores.json')
55 | self.split_path = os.path.join(self.data_path, 'split.json')
56 | self.attr_path = os.path.join(self.data_path, 'composition_attributes.json')
57 | self.weight_path= os.path.join(self.data_path, 'emdloss_weight.json')
58 | self.split = split
59 | self.attr_types = cfg.attribute_types
60 |
61 | self.image_list = json.load(open(self.split_path, 'r'))[split]
62 | self.comp_scores = json.load(open(self.score_path, 'r'))
63 | self.comp_attrs = json.load(open(self.attr_path, 'r'))
64 | if self.split == 'train':
65 | self.image_weight = json.load(open(self.weight_path, 'r'))
66 | else:
67 | self.image_weight = None
68 |
69 | self.image_size = cfg.image_size
70 | self.transformer = transforms.Compose([
71 | transforms.Resize((self.image_size, self.image_size)),
72 | transforms.ToTensor(),
73 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)
74 | ])
75 |
76 | def __len__(self):
77 | return len(self.image_list)
78 |
79 | def __getitem__(self, index):
80 | image_name = self.image_list[index]
81 | image_file = os.path.join(self.image_path, image_name)
82 | assert os.path.exists(image_file), image_file + ' not found'
83 | src = Image.open(image_file).convert('RGB')
84 | im = self.transformer(src)
85 |
86 | score_mean = self.comp_scores[image_name]['mean']
87 | score_mean = torch.Tensor([score_mean])
88 | score_dist = self.comp_scores[image_name]['dist']
89 | score_dist = torch.Tensor(score_dist)
90 |
91 | attrs = torch.tensor(self.get_attribute(image_name))
92 | src_im = np.asarray(src).copy()
93 | sal_map = detect_saliency(src_im, target_size=(self.image_size, self.image_size))
94 | sal_map = torch.from_numpy(sal_map.astype(np.float32)).unsqueeze(0)
95 |
96 | if self.split == 'train':
97 | emd_weight = torch.tensor(self.image_weight[image_name])
98 | return im, score_mean, score_dist, sal_map, attrs, emd_weight
99 | else:
100 | return im, score_mean, score_dist, sal_map, attrs
101 |
102 | def get_attribute(self, image_name):
103 | all_attrs = self.comp_attrs[image_name]
104 | attrs = [all_attrs[k] for k in self.attr_types]
105 | return attrs
106 |
107 | def normbboxes(self, bboxes, w, h):
108 | bboxes = bboxes.astype(np.float32)
109 | center_x = (bboxes[:,0] + bboxes[:,2]) / 2. / w
110 | center_y = (bboxes[:,1] + bboxes[:,3]) / 2. / h
111 | norm_w = (bboxes[:,2] - bboxes[:,0]) / w
112 | norm_h = (bboxes[:,3] - bboxes[:,1]) / h
113 | norm_bboxes = np.column_stack((center_x, center_y, norm_w, norm_h))
114 | norm_bboxes = np.clip(norm_bboxes, 0, 1)
115 | assert norm_bboxes.shape == bboxes.shape, '{} vs. {}'.format(bboxes.shape, norm_bboxes.shape)
116 | # print(w,h,bboxes[0],norm_bboxes[0])
117 | return norm_bboxes
118 |
119 | def scores2dist(self, scores):
120 | scores = np.array(scores)
121 | count = [(scores == i).sum() for i in range(1,6)]
122 | count = np.array(count)
123 | assert count.sum() == 5, scores
124 |         distribution = count.astype(float) / count.sum()
125 | distribution = distribution.tolist()
126 | return distribution
127 |
128 |     def collate_fn(self, batch):  # note: unused (the DataLoaders pass collate_fn=None) and relies on attributes not defined in this class
129 |         batch = list(zip(*batch))
130 |         if self.split == 'train':
131 |             assert (not self.need_image_path) and (not self.need_mask) \
132 |                 and (not self.need_proposal), 'Multi-scale training not implemented'
133 | self.image_size = random.choice(range(self.min_image_size,
134 | self.max_image_size+1,
135 | 16))
136 | batch[0] = [resize(im, self.image_size) for im in batch[0]]
137 | batch = [torch.stack(data, dim=0) for data in batch]
138 | return batch
139 |
140 |
141 | def resize(image, size):
142 | image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
143 | return image
144 |
145 | if __name__ == '__main__':
146 | cfg = Config()
147 | train_dataset = CADBDataset('train', cfg)
148 | train_loader = DataLoader(train_dataset,
149 | batch_size=2,
150 | shuffle=True,
151 | num_workers=8,
152 | drop_last=True,
153 | collate_fn=None)
154 | test_dataset = CADBDataset('test', cfg)
155 | test_loader = DataLoader(test_dataset,
156 | batch_size=2,
157 | shuffle=False,
158 | num_workers=0,
159 | drop_last=False)
160 |
161 | print("training set size {}, test set size {}".format(len(train_dataset), len(test_dataset)))
162 | for batch, data in enumerate(train_loader):
163 | im,score,dist,mask,attrs,weight = data
164 | print('train', im.shape, score.shape, dist.shape, mask.shape, attrs.shape, weight.shape)
165 |
166 |
167 | for batch, data in enumerate(test_loader):
168 | im,score,dist,mask,attrs = data
169 | print('test', im.shape, score.shape, dist.shape, mask.shape, attrs.shape)
170 | break
--------------------------------------------------------------------------------
/SAMPNet/config.py:
--------------------------------------------------------------------------------
1 | import os,time
2 |
3 | class Config:
4 | # setting for dataset and dataloader
5 | dataset_path = '/workspace/composition/CADB_Dataset'
6 |     assert os.path.exists(dataset_path), dataset_path + ' not found'
7 |
8 | batch_size = 16
9 | gpu_id = 0
10 | num_workers = 8
11 |
12 | # setting for training and optimization
13 | max_epoch = 50
14 | save_epoch = 1
15 | display_steps = 10
16 | score_level = 5
17 |
18 | use_weighted_loss = True
19 | use_attribute = True
20 | use_channel_attention = True
21 | use_saliency = True
22 | use_multipattern = True
23 | use_pattern_weight = True
24 | # AADB attributes: "Light" "Symmetry" "Object"
25 | # "score" "RuleOfThirds" "Repetition"
26 | # "BalacingElements" "ColorHarmony" "MotionBlur"
27 | # "VividColor" "DoF" "Content"
28 | attribute_types = ['RuleOfThirds', 'BalacingElements','DoF',
29 | 'Object', 'Symmetry', 'Repetition']
30 | num_attributes = len(attribute_types)
31 | attribute_weight = 0.1
32 |
33 | optimizer = 'adam' # or sgd
34 | lr = 1e-4
35 | weight_decay = 5e-5
36 | momentum = 0.9
37 | # setting for cnn
38 | image_size = 224
39 | resnet_layers = 18
40 | dropout = 0.5
41 | pool_dropout = 0.5
42 | pattern_list = [1, 2, 3, 4, 5, 6, 7, 8]
43 | pattern_fuse = 'sum'
44 |
45 | # setting for testing
46 | test_epoch = 1
47 |
48 |     # setting for experiments
49 | if len(pattern_list) == 1:
50 | exp_root = os.path.join(os.getcwd(), './experiments/single_pattern')
51 | prefix = 'pattern{}'.format(pattern_list[0])
52 | else:
53 | exp_root = os.path.join(os.getcwd(), './experiments/')
54 | prefix = 'resnet{}'.format(resnet_layers)
55 | if use_multipattern:
56 | if use_pattern_weight and use_saliency:
57 | prefix += '_samp'
58 | elif use_pattern_weight:
59 | prefix += '_weighted_mpp'
60 | elif use_saliency:
61 | prefix += '_saliency_mpp'
62 | else:
63 | prefix += '_mpp'
64 | if use_attribute:
65 | if use_channel_attention:
66 | prefix += '_aaff'
67 | else:
68 | prefix += '_attr'
69 | if use_weighted_loss:
70 | prefix += '_wemd'
71 | exp_name = prefix
72 | exp_path = os.path.join(exp_root, prefix)
73 | while os.path.exists(exp_path):
74 | index = os.path.basename(exp_path).split(prefix)[-1]
75 | try:
76 | index = int(index) + 1
77 | except:
78 | index = 1
79 | exp_name = prefix + str(index)
80 |         exp_path = os.path.join(exp_root, exp_name)
81 | print('Experiment name {} \n'.format(os.path.basename(exp_path)))
82 | checkpoint_dir = os.path.join(exp_path, 'checkpoints')
83 | log_dir = os.path.join(exp_path, 'logs')
84 |
85 | def create_path(self):
86 | print('Create experiment directory: ', self.exp_path)
87 | os.makedirs(self.exp_path)
88 | os.makedirs(self.checkpoint_dir)
89 | os.makedirs(self.log_dir)
90 |
91 | if __name__ == '__main__':
92 | cfg = Config()
--------------------------------------------------------------------------------
/SAMPNet/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.3.0
2 | numpy==1.19.1
3 | opencv_contrib_python==4.4.0.46
4 | Pillow==8.4.0
5 | scipy==1.5.2
6 | tensorboardX==2.4
7 | torch==1.9.1
8 | torchvision==0.9.0+cu111
9 | tqdm==4.51.0
10 |
--------------------------------------------------------------------------------
/SAMPNet/samp_module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import os
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 | from einops import rearrange, repeat
8 | import torch.optim as optim
9 | from torch.nn.modules.utils import _pair
10 |
11 |
12 | class TriangularPattern(nn.Module):
13 | def __init__(self, flip=False, sal_size=56):
14 | super(TriangularPattern,self).__init__()
15 | self.flip = flip
16 | target_size = int(sal_size * 3 / 4)
17 | self.downsample = nn.AdaptiveMaxPool2d(output_size=(target_size, target_size))
18 |
19 | def forward(self, x, s):
20 | if self.flip:
21 | x = torch.flip(x, dims=(3,))
22 |
23 | up_indices = torch.triu_indices(x.shape[2], x.shape[3],
24 | device=x.device,
25 | offset=1)
26 | up_feat = x[:,:,up_indices[0], up_indices[1]].mean(dim=2)
27 |
28 | lw_indices = torch.tril_indices(x.shape[2], x.shape[3],
29 | device=x.device,
30 | offset=-1)
31 | lw_feat = x[:,:,lw_indices[0], lw_indices[1]].mean(dim=2)
32 | fused = torch.stack([up_feat, lw_feat], dim=2).unsqueeze(3)
33 | if s is None:
34 | return fused
35 |
36 | if self.flip:
37 | s = torch.flip(s, dims=(3,))
38 | s = self.downsample(s)
39 | up_sal_indices = torch.triu_indices(s.shape[2],
40 | s.shape[3],
41 | device=s.device,
42 | offset=1)
43 | up_sal = s[:, :, up_sal_indices[0], up_sal_indices[1]].flatten(1)
44 |
45 | lw_sal_indices = torch.tril_indices(s.shape[2],
46 | s.shape[3],
47 | device=s.device,
48 | offset=-1)
49 | lw_sal = s[:, :, lw_sal_indices[0], lw_sal_indices[1]].flatten(1)
50 |
51 |
52 | fused_sal = torch.stack([up_sal, lw_sal], dim=2).unsqueeze(3)
53 | return fused, fused_sal
54 |
55 | class CrossPattern(nn.Module):
56 | def __init__(self):
57 | super(CrossPattern, self).__init__()
58 |
59 | def forward(self, x, s):
60 | ones_vec = torch.ones(x.size(2), x.size(3), requires_grad=False)
61 | up_flip = torch.triu(ones_vec, diagonal=0).flip(dims=(-1,))
62 | lw_flip = torch.tril(ones_vec, diagonal=0).flip(dims=(-1,))
63 | up_inds = torch.where(torch.triu(up_flip, diagonal=0) > 0)
64 | lf_inds = torch.where(torch.tril(up_flip, diagonal=0) > 0)
65 | rt_inds = torch.where(torch.triu(lw_flip, diagonal=0) > 0)
66 | lw_inds = torch.where(torch.tril(lw_flip, diagonal=0) > 0)
67 |
68 | up_feat = x[:, :, up_inds[0], up_inds[1]].mean(2)
69 | lf_feat = x[:, :, lf_inds[0], lf_inds[1]].mean(2)
70 | rt_feat = x[:, :, rt_inds[0], rt_inds[1]].mean(2)
71 | lw_feat = x[:, :, lw_inds[0], lw_inds[1]].mean(2)
72 |
73 | fused1 = torch.stack([lf_feat, up_feat], dim=2)
74 | fused2 = torch.stack([lw_feat, rt_feat], dim=2)
75 | fused = torch.stack([fused1, fused2], dim=3)
76 | if s is None:
77 | return fused
78 |
79 | assert s.shape[2] == s.shape[3], \
80 | 'saliency map should be square, but get a shape {}'.format(s.shape)
81 | ones_vec = torch.ones(s.size(2), s.size(3), requires_grad=False)
82 | up_flip = torch.triu(ones_vec, diagonal=0).flip(dims=(-1,))
83 | lw_flip = torch.tril(ones_vec, diagonal=0).flip(dims=(-1,))
84 | up_inds = torch.where(torch.triu(up_flip, diagonal=0) > 0)
85 | lf_inds = torch.where(torch.tril(up_flip, diagonal=0) > 0)
86 | rt_inds = torch.where(torch.triu(lw_flip, diagonal=0) > 0)
87 | lw_inds = torch.where(torch.tril(lw_flip, diagonal=0) > 0)
88 |
89 | up_sal = s[:, :, up_inds[0], up_inds[1]].flatten(1)
90 | lf_sal = s[:, :, lf_inds[0], lf_inds[1]].flatten(1)
91 | rt_sal = s[:, :, rt_inds[0], rt_inds[1]].flatten(1)
92 | lw_sal = s[:, :, lw_inds[0], lw_inds[1]].flatten(1)
93 |
94 | sal1 = torch.stack([lf_sal, up_sal], dim=2)
95 | sal2 = torch.stack([lw_sal, rt_sal], dim=2)
96 | fused_sal = torch.stack([sal1, sal2], dim=3)
97 |
98 | return fused, fused_sal
99 |
100 | class SurroundPattern(nn.Module):
101 | def __init__(self, crop_size=1./2):
102 | super(SurroundPattern, self).__init__()
103 | self.crop_size = crop_size
104 |
105 | def forward(self, x, s):
106 | H,W = x.shape[2:]
107 | crop_h = (int(H / 2 - self.crop_size / 2 * H), int(H / 2 + self.crop_size / 2 * H))
108 | crop_w = (int(W / 2 - self.crop_size / 2 * W), int(W / 2 + self.crop_size / 2 * W))
109 | x_mask = torch.zeros(H,W,device=x.device, dtype=torch.bool)
110 | x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True
111 |
112 | inside_indices = torch.where(x_mask)
113 | inside_part = x[:, :, inside_indices[0], inside_indices[1]]
114 | inside_feat = inside_part.mean(2)
115 |
116 | outside_indices = torch.where(~x_mask)
117 | outside_part = x[:, :, outside_indices[0], outside_indices[1]]
118 | outside_feat = outside_part.mean(2)
119 | fused = torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3)
120 | if s is None:
121 | return fused
122 |
123 | SH,SW = s.shape[2:]
124 | crop_sh = (int(SH / 2 - self.crop_size / 2 * SH), int(SH / 2 + self.crop_size / 2 * SH))
125 | crop_sw = (int(SW / 2 - self.crop_size / 2 * SW), int(SW / 2 + self.crop_size / 2 * SW))
126 | s_mask = torch.zeros(SH, SW, device=s.device, dtype=torch.bool)
127 | s_mask[crop_sh[0] : crop_sh[1], crop_sw[0] : crop_sw[1]] = True
128 |
129 | s_inside_indices = torch.where(s_mask)
130 | inside_sal = s[:, :, s_inside_indices[0], s_inside_indices[1]].flatten(1)
131 |
132 | s_outside_indices = torch.where(~s_mask)
133 | outside_sal = s[:, :, s_outside_indices[0], s_outside_indices[1]].flatten(1)
134 | if outside_sal.shape != inside_sal.shape:
135 | outside_sal = F.adaptive_max_pool1d(outside_sal.unsqueeze(1), output_size=inside_sal.shape[1])
136 | outside_sal = outside_sal.squeeze(1)
137 | fused_sal = torch.stack([inside_sal, outside_sal], dim=2).unsqueeze(3)
138 | return fused, fused_sal
139 |
140 | class HorizontalPattern(nn.Module):
141 | def __init__(self):
142 | super(HorizontalPattern, self).__init__()
143 | self.downsample = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2), padding=(0,1))
144 |
145 | def forward(self, x, s):
146 | H = x.shape[2]
147 | up_part = x[:, :, : H // 2, :]
148 | up_feat = up_part.mean(3).mean(2)
149 |
150 | lw_part = x[:, :, H // 2 :, :]
151 | lw_feat = lw_part.mean(3).mean(2)
152 | fused = torch.stack([up_feat, lw_feat], dim=2).unsqueeze(3)
153 | if s is None:
154 | return fused
155 | SH = s.shape[2]
156 | s = self.downsample(s)
157 | up_sal = s[:, :, : SH // 2, :].flatten(1)
158 | if SH % 2 == 0:
159 | lw_sal = s[:, :, SH // 2 :, :].flatten(1)
160 | else:
161 | lw_sal = s[:, :, SH // 2 + 1 :, :].flatten(1)
162 | fused_sal = torch.stack([up_sal, lw_sal], dim=2).unsqueeze(3)
163 | return fused, fused_sal
164 |
165 | class VerticalPattern(nn.Module):
166 | def __init__(self):
167 | super(VerticalPattern, self).__init__()
168 | self.downsample = nn.MaxPool2d(kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
169 |
170 | def forward(self, x, s):
171 | W = x.shape[3]
172 | left_part = x[:, :, :, : W // 2]
173 | left_feat = left_part.mean(3).mean(2)
174 |
175 | right_part = x[:, :, :, W // 2 :]
176 | right_feat = right_part.mean(3).mean(2)
177 | fused = torch.stack([left_feat, right_feat], dim=2).unsqueeze(2)
178 | if s is None:
179 | return fused
180 |
181 | SW = s.shape[3]
182 | s = self.downsample(s)
183 | left_sal = s[:, :, :, : SW // 2].flatten(1)
184 | if SW % 2 == 0:
185 | right_sal = s[:, :, :, SW // 2 :].flatten(1)
186 | else:
187 | right_sal = s[:, :, :, SW // 2 + 1 :].flatten(1)
188 | fused_sal = torch.stack([left_sal, right_sal], dim=2).unsqueeze(2)
189 | return fused, fused_sal
190 |
191 | class GlobalPattern(nn.Module):
192 | def __init__(self):
193 | super(GlobalPattern, self).__init__()
194 | self.downsample = nn.Sequential(
195 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
196 | nn.Flatten(1)
197 | )
198 | self.gap = nn.Sequential(
199 | nn.AdaptiveAvgPool2d(1)
200 | )
201 |
202 | def forward(self, x, s):
203 | gap_feat = self.gap(x)
204 | if s is None:
205 | return gap_feat
206 | sal_feat = self.downsample(s).unsqueeze(2).unsqueeze(3)
207 | return gap_feat, sal_feat
208 |
209 | class QuarterPattern(nn.Module):
210 | def __init__(self):
211 | super(QuarterPattern, self).__init__()
212 |
213 | def forward(self, x, s):
214 | feat = F.adaptive_avg_pool2d(x, output_size=(2,2))
215 | if s is None:
216 | return feat
217 |
218 | s_chunks = torch.chunk(s, 2, dim=2)
219 | up_left, up_right = torch.chunk(s_chunks[0], 2, dim=3)
220 | lw_left, lw_right = torch.chunk(s_chunks[1], 2, dim=3)
221 | up_sal = torch.stack([up_left.flatten(1), up_right.flatten(1)],dim=2)
222 | lw_sal = torch.stack([lw_left.flatten(1), lw_right.flatten(1)],dim=2)
223 | feat_sal = torch.stack([up_sal, lw_sal], dim=2)
224 | return feat, feat_sal
225 |
226 | class ThirdOfRulePattern(nn.Module):
227 | def __init__(self):
228 | super(ThirdOfRulePattern, self).__init__()
229 |
230 | def forward(self, x, s):
231 | feat = F.adaptive_avg_pool2d(x, output_size=(3,3))
232 | if s is None:
233 | return feat
234 | out_size = (s.shape[-1] // 3) * 3
235 | s = F.adaptive_max_pool2d(s, output_size=(out_size,out_size))
236 | hor_chunks = torch.chunk(s, 3, dim=3)
237 | feat_sal = []
238 | for h_chunk in hor_chunks:
239 | tmp = []
240 | for v_chunk in torch.chunk(h_chunk, 3, dim=2):
241 | tmp.append(v_chunk.flatten(1))
242 | tmp = torch.stack(tmp, dim=2)
243 | feat_sal.append(tmp)
244 | feat_sal = torch.stack(feat_sal, dim=3)
245 | return feat, feat_sal
246 |
247 | class VerticalThirdPattern(nn.Module):
248 | def __init__(self):
249 | super(VerticalThirdPattern, self).__init__()
250 |
251 | def forward(self, x, s):
252 | feat = F.adaptive_avg_pool2d(x, output_size=(3,1))
253 | if s is None:
254 | return feat
255 | out_size = (s.shape[-2] // 3) * 3
256 | s = F.adaptive_max_pool2d(s, output_size=(out_size,s.shape[-1]))
257 | ver_chunks = torch.chunk(s, 3, dim=2)
258 | feat_sal = []
259 | for ver in ver_chunks:
260 | feat_sal.append(ver.flatten(1))
261 | feat_sal = torch.stack(feat_sal, dim=2).unsqueeze(3)
262 | return feat, feat_sal
263 |
264 | class HorThirdPattern(nn.Module):
265 | def __init__(self):
266 | super(HorThirdPattern, self).__init__()
267 |
268 | def forward(self, x, s):
269 | feat = F.adaptive_avg_pool2d(x, output_size=(1,3))
270 | if s is None:
271 | return feat
272 | out_size = (s.shape[-1] // 3) * 3
273 | s = F.adaptive_max_pool2d(s, output_size=(s.shape[-2], out_size))
274 | hor_chunks = torch.chunk(s, 3, dim=3)
275 | feat_sal = []
276 | for hor in hor_chunks:
277 | feat_sal.append(hor.flatten(1))
278 | feat_sal = torch.stack(feat_sal, dim=2).unsqueeze(2)
279 | return feat, feat_sal
280 |
281 | class MultiDirectionPattern(nn.Module):
282 | def __init__(self):
283 | super(MultiDirectionPattern, self).__init__()
284 | feat_mask = self.generate_multi_direction_mask(7, 7)
285 | sal_mask = self.generate_multi_direction_mask(56, 56)
286 | self.register_buffer('feat_mask', feat_mask)
287 | self.register_buffer('sal_mask', sal_mask)
288 |
289 | def generate_multi_direction_mask(self, w, h):
290 | mask = torch.zeros(8, h, w)
291 | degree_mask = torch.zeros(h, w)
292 | if h % 2 == 0:
293 | cx, cy = float(w-1)/2, float(h-1)/2
294 | else:
295 | cx, cy = w // 2, h // 2
296 |
297 | for i in range(w):
298 | for j in range(h):
299 | degree = math.degrees(math.atan2(cy - j, cx - i))
300 | degree_mask[j, i] = (degree + 180) % 360
301 | for i in range(8):
302 | if i == 7:
303 | degree_mask[degree_mask == 0] = 360
304 | mask[i, (degree_mask >= i * 45) & (degree_mask <= (i + 1) * 45)] = 1
305 | if h % 2 != 0:
306 | mask[i, cy, cx] = 1
307 | # mask = torch.count_nonzero(mask.flatten(1))
308 | return mask
309 |
310 | def forward(self, x, s):
311 | # B, C, H, W = x.shape
312 | mask = rearrange(self.feat_mask, 'p h w -> 1 1 p h w')
313 | count = torch.count_nonzero(mask.flatten(-2), dim=-1)
314 | x = x.unsqueeze(2)
315 | part = (x * mask).sum(-2).sum(-1)
316 | feat = part / count
317 | feat = rearrange(feat, 'b c (h w) -> b c h w', h=2, w=4)
318 | if s is None:
319 | return feat
320 | sal_feat = []
321 | for i in range(self.sal_mask.shape[0]):
322 | sal_feat.append(s[:, :, self.sal_mask[i] > 0].flatten(1))
323 | sal_feat = torch.stack(sal_feat, dim=2)
324 | sal_feat = rearrange(sal_feat, 'b c (h w) -> b c h w', h=2, w=4)
325 | return feat, sal_feat
326 |
327 | class MultiRectanglePattern(nn.Module):
328 | def __init__(self):
329 | super(MultiRectanglePattern, self).__init__()
330 | feat_mask = self.generate_multi_direction_mask(7, 7)
331 | sal_mask = self.generate_multi_direction_mask(56, 56)
332 | self.register_buffer('feat_mask', feat_mask)
333 | self.register_buffer('sal_mask', sal_mask)
334 |
335 | def generate_multi_direction_mask(self, w, h):
336 | square_part = torch.zeros(4, 4, h, w)
337 | index_y = torch.split(torch.arange(h), (h + 1) // 4)
338 | index_x = torch.split(torch.arange(w), (w + 1) // 4)
339 | for i in range(4):
340 | for j in range(4):
341 | for x in index_x[i]:
342 | for y in index_y[j]:
343 | square_part[i, j, y, x] = 1
344 | mask = torch.zeros(8, h, w)
345 | group_x = [[0, 0, 1], [1], [2, 3, 3], [2], [0, 0, 1], [1], [2, 3, 3], [2]]
346 | group_y = [[1, 0, 0], [1], [0, 0, 1], [1], [2, 3, 3], [2], [3, 3, 2], [2]]
347 | for i in range(len(group_x)):
348 | mask[i] = torch.sum(square_part[group_x[i], group_y[i]], dim=0)
349 | mask = torch.clip(mask, min=0, max=1)
350 | return mask
351 |
352 | def forward(self, x, s):
353 | # B, C, H, W = x.shape
354 | mask = rearrange(self.feat_mask, 'p h w -> 1 1 p h w')
355 | count = torch.count_nonzero(mask.flatten(-2), dim=-1)
356 | x = x.unsqueeze(2)
357 | part = (x * mask).sum(-2).sum(-1)
358 | feat = part / count
359 | feat = rearrange(feat, 'b c (h w) -> b c h w', h=2, w=4)
360 | if s is None:
361 | return feat
362 | sal_count = torch.count_nonzero(self.sal_mask.flatten(1), dim=-1)
363 | target_size = sal_count.min().item()
364 | sal_feat = []
365 | for i in range(self.sal_mask.shape[0]):
366 | sal = s[:, :, self.sal_mask[i] > 0].flatten(1)
367 | if sal.shape[1] > target_size:
368 | sal = F.adaptive_max_pool1d(sal.unsqueeze(1), output_size=target_size).squeeze(1)
369 | sal_feat.append(sal)
370 | sal_feat = torch.stack(sal_feat, dim=2)
371 | sal_feat = rearrange(sal_feat, 'b c (h w) -> b c h w', h=2, w=4)
372 | return feat, sal_feat
373 |
374 | class MPPModule(nn.Module):
375 | def __init__(self, in_dim, out_dim, dropout=0.5,
376 | pattern_list=[1,2,3,4,5,6,7,8],
377 | fusion='sum'):
378 | super(MPPModule, self).__init__()
379 | self.pattern_list = pattern_list
380 | self.in_dim = in_dim
381 | self.out_dim = out_dim
382 | self.dropout = nn.Dropout(dropout)
383 | self.fusion = fusion
384 | pool_list = []
385 | conv_list = []
386 | print('Multi-Pattern Pooling pattern: {}, fusion manner: {}, dropout: {}'.\
387 | format(pattern_list, fusion, dropout))
388 | for pattern in pattern_list:
389 | p_fn = getattr(self, 'pattern{}'.format(int(pattern)))
390 | p, c = p_fn()
391 | pool_list.append(p)
392 | conv_list.append(c)
393 | self.pool_list = nn.ModuleList(pool_list)
394 | self.conv_list = nn.ModuleList(conv_list)
395 |
396 | def forward(self, x, weights):
397 | outputs = []
398 | for pool,conv in zip(self.pool_list, self.conv_list):
399 | feat = self.dropout(pool(x,s=None))
400 | feat = self.dropout(conv(feat))
401 | outputs.append(feat)
402 |
403 | if len(outputs) == 1:
404 | return outputs[0]
405 | if self.fusion == 'sum':
406 | outputs = torch.stack(outputs, dim=2)
407 | if weights is None:
408 | outputs = torch.sum(outputs, dim=2)
409 | else:
410 | weights = F.softmax(weights, dim=1)
411 | outputs = torch.sum(outputs * weights.unsqueeze(1), dim=2)
412 | elif self.fusion == 'mean':
413 | outputs = torch.stack(outputs, dim=2)
414 | outputs = torch.mean(outputs, dim=2)
415 | elif self.fusion == 'concat':
416 | outputs = torch.cat(outputs, dim=1)
417 | else:
418 |             raise ValueError('Unknown fusion type {}'.format(self.fusion))
419 | return outputs
420 |
421 | def pattern0(self):
422 | pool = GlobalPattern()
423 | conv = nn.Sequential(
424 | nn.Conv2d(self.in_dim, self.out_dim, (1,1), bias=False),
425 | nn.ReLU(True),
426 | nn.Flatten(1)
427 | )
428 | return pool, conv
429 |
430 | def pattern1(self):
431 | pool = HorizontalPattern()
432 | conv = nn.Sequential(
433 | nn.Conv2d(self.in_dim, self.out_dim, (2,1), bias=False),
434 | nn.ReLU(True),
435 | nn.Flatten(1)
436 | )
437 | return pool, conv
438 |
439 | def pattern2(self):
440 | pool = VerticalPattern()
441 | conv = nn.Sequential(
442 | nn.Conv2d(self.in_dim, self.out_dim, (1,2), bias=False),
443 | nn.ReLU(True),
444 | nn.Flatten(1)
445 | )
446 | return pool, conv
447 |
448 | def pattern3(self):
449 | pool = TriangularPattern(flip=False)
450 | conv = nn.Sequential(
451 | nn.Conv2d(self.in_dim, self.out_dim, (2,1), bias=False),
452 | nn.ReLU(True),
453 | nn.Flatten(1)
454 | )
455 | return pool, conv
456 |
457 | def pattern4(self):
458 | pool = TriangularPattern(flip=True)
459 | conv = nn.Sequential(
460 | nn.Conv2d(self.in_dim, self.out_dim, (2, 1), bias=False),
461 | nn.ReLU(True),
462 | nn.Flatten(1)
463 | )
464 | return pool, conv
465 |
466 | def pattern5(self):
467 | pool = SurroundPattern()
468 | conv = nn.Sequential(
469 | nn.Conv2d(self.in_dim, self.out_dim, (2, 1), bias=False),
470 | nn.ReLU(True),
471 | nn.Flatten(1)
472 | )
473 | return pool, conv
474 |
475 | def pattern6(self):
476 | pool = QuarterPattern()
477 | conv = nn.Sequential(
478 | nn.Conv2d(self.in_dim, self.out_dim, (2, 2), bias=False),
479 | nn.ReLU(True),
480 | nn.Flatten(1)
481 | )
482 | return pool, conv
483 |
484 | def pattern7(self):
485 | pool = CrossPattern()
486 | conv = nn.Sequential(
487 | nn.Conv2d(self.in_dim, self.out_dim, (2,2), bias=False),
488 | nn.ReLU(True),
489 | nn.Flatten(1)
490 | )
491 | return pool, conv
492 |
493 | def pattern8(self):
494 | pool = ThirdOfRulePattern()
495 | conv = nn.Sequential(
496 | nn.Conv2d(self.in_dim, self.out_dim, (3,3), bias=False),
497 | nn.ReLU(True),
498 | nn.Flatten(1)
499 | )
500 | return pool, conv
501 |
502 | def pattern9(self):
503 | pool = HorThirdPattern()
504 | conv = nn.Sequential(
505 | nn.Conv2d(self.in_dim, self.out_dim, (1,3), bias=False),
506 | nn.ReLU(True),
507 | nn.Flatten(1)
508 | )
509 | return pool, conv
510 |
511 | def pattern10(self):
512 | pool = VerticalThirdPattern()
513 | conv = nn.Sequential(
514 | nn.Conv2d(self.in_dim, self.out_dim, (3,1), bias=False),
515 | nn.ReLU(True),
516 | nn.Flatten(1)
517 | )
518 | return pool, conv
519 |
520 | def pattern11(self):
521 | pool = MultiDirectionPattern()
522 | conv = nn.Sequential(
523 | nn.Conv2d(self.in_dim, self.out_dim, (2,4), bias=False),
524 | nn.ReLU(True),
525 | nn.Flatten(1)
526 | )
527 | return pool, conv
528 |
529 | def pattern12(self):
530 | pool = MultiRectanglePattern()
531 | conv = nn.Sequential(
532 | nn.Conv2d(self.in_dim, self.out_dim, (2,4), bias=False),
533 | nn.ReLU(True),
534 | nn.Flatten(1)
535 | )
536 | return pool, conv
537 |
538 |
539 | class SAMPPModule(nn.Module):
540 | def __init__(self, in_dim, out_dim,
541 | saliency_size, dropout=0.5,
542 | pattern_list=[1,2,3,4,5,6,7,8],
543 | fusion='sum'):
544 | super(SAMPPModule, self).__init__()
545 | self.pattern_list = pattern_list
546 | self.in_dim = in_dim
547 | self.out_dim = out_dim
548 | self.dropout = nn.Dropout(dropout)
549 | self.fusion = fusion
550 | self.saliency_size = saliency_size
551 | pool_list = []
552 | conv_list = []
553 |         print('Saliency-Augmented Multi-Pattern Pooling pattern: {}, fusion manner: {}, dropout: {}'.\
554 | format(pattern_list, fusion, dropout))
555 | for pattern in pattern_list:
556 | p_fn = getattr(self, 'pattern{}'.format(int(pattern)))
557 | p,c = p_fn()
558 | pool_list.append(p)
559 | conv_list.append(c)
560 | self.pool_list = nn.ModuleList(pool_list)
561 | self.conv_list = nn.ModuleList(conv_list)
562 |
563 | def forward(self, x, s, weights):
564 | outputs = []
565 | idx = 0
566 | for pool,conv in zip(self.pool_list, self.conv_list):
567 | feat, sal = pool(x,s)
568 | feat = self.dropout(feat)
569 | # print('pattern{}, x {}, s {}, feat_dim {}, sal_dim {}'.format(
570 | # self.pattern_list[idx], x.shape, s.shape, feat.shape, sal.shape))
571 | fused = torch.cat([feat, sal], dim=1)
572 | fused = self.dropout(conv(fused))
573 | outputs.append(fused)
574 | idx += 1
575 |
576 | if len(outputs) == 1:
577 | return outputs[0]
578 | if self.fusion == 'sum':
579 | outputs = torch.stack(outputs, dim=2)
580 | if weights is None:
581 | outputs = torch.sum(outputs, dim=2)
582 | else:
583 | weights = F.softmax(weights, dim=1)
584 | outputs = torch.sum(outputs * weights.unsqueeze(1), dim=2)
585 | elif self.fusion == 'mean':
586 | outputs = torch.stack(outputs, dim=2)
587 | outputs = torch.mean(outputs, dim=2)
588 | elif self.fusion == 'concat':
589 | outputs = torch.cat(outputs, dim=1)
590 | else:
591 |             raise ValueError('Unknown fusion type {}'.format(self.fusion))
592 | return outputs
593 |
594 | def pattern0(self):
595 | sal_length = (self.saliency_size // 2) ** 2
596 | pool = GlobalPattern()
597 | conv = nn.Sequential(
598 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (1, 1), bias=False),
599 | nn.ReLU(True),
600 | nn.Flatten(1)
601 | )
602 | return pool, conv
603 |
604 |
605 | def pattern1(self):
606 | sal_length = (self.saliency_size // 2) ** 2
607 | pool = HorizontalPattern()
608 | conv = nn.Sequential(
609 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,1), bias=False),
610 | nn.ReLU(True),
611 | nn.Flatten(1)
612 | )
613 | return pool, conv
614 |
615 | def pattern2(self):
616 | sal_length = (self.saliency_size // 2) ** 2
617 | pool = VerticalPattern()
618 | conv = nn.Sequential(
619 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (1,2), bias=False),
620 | nn.ReLU(True),
621 | nn.Flatten(1)
622 | )
623 | return pool, conv
624 |
625 | def pattern3(self):
626 | s_size = int(self.saliency_size * 3 / 4)
627 | sal_length = s_size * (s_size - 1) // 2
628 | pool = TriangularPattern(flip=False, sal_size=self.saliency_size)
629 | conv = nn.Sequential(
630 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,1), bias=False),
631 | nn.ReLU(True),
632 | nn.Flatten(1)
633 | )
634 | return pool, conv
635 |
636 | def pattern4(self):
637 | s_size = int(self.saliency_size * 3 / 4)
638 | sal_length = s_size * (s_size - 1) // 2
639 | pool = TriangularPattern(flip=True, sal_size=self.saliency_size)
640 | conv = nn.Sequential(
641 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2, 1), bias=False),
642 | nn.ReLU(True),
643 | nn.Flatten(1)
644 | )
645 | return pool, conv
646 |
647 | def pattern5(self):
648 | crop_size = 1./2
649 | sal_length = int(self.saliency_size * crop_size) ** 2
650 | pool = SurroundPattern(crop_size)
651 | conv = nn.Sequential(
652 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2, 1), bias=False),
653 | nn.ReLU(True),
654 | nn.Flatten(1)
655 | )
656 | return pool, conv
657 |
658 | def pattern6(self):
659 | sal_length = (self.saliency_size // 2) ** 2
660 | pool = QuarterPattern()
661 | conv = nn.Sequential(
662 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2, 2), bias=False),
663 | nn.ReLU(True),
664 | nn.Flatten(1)
665 | )
666 | return pool, conv
667 |
668 | def pattern7(self):
669 | sal_length = 0
670 | row_len = self.saliency_size
671 | while row_len > 0:
672 | sal_length += row_len
673 | row_len -= 2
674 | pool = CrossPattern()
675 | conv = nn.Sequential(
676 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,2), bias=False),
677 | nn.ReLU(True),
678 | nn.Flatten(1)
679 | )
680 | return pool, conv
681 |
682 | def pattern8(self):
683 | sal_length = (self.saliency_size // 3)**2
684 | pool = ThirdOfRulePattern()
685 | conv = nn.Sequential(
686 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (3,3), bias=False),
687 | nn.ReLU(True),
688 | nn.Flatten(1)
689 | )
690 | return pool, conv
691 |
692 | def pattern9(self):
693 | sal_length = (self.saliency_size // 3) * self.saliency_size
694 | pool = HorThirdPattern()
695 | conv = nn.Sequential(
696 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (1,3), bias=False),
697 | nn.ReLU(True),
698 | nn.Flatten(1)
699 | )
700 | return pool, conv
701 |
702 | def pattern10(self):
703 | sal_length = (self.saliency_size // 3) * self.saliency_size
704 | pool = VerticalThirdPattern()
705 | conv = nn.Sequential(
706 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (3,1), bias=False),
707 | nn.ReLU(True),
708 | nn.Flatten(1)
709 | )
710 | return pool, conv
711 |
712 | def pattern11(self):
713 | sal_length = int((self.saliency_size // 2) * (self.saliency_size // 2 + 1) / 2)
714 | pool = MultiDirectionPattern()
715 | conv = nn.Sequential(
716 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,4), bias=False),
717 | nn.ReLU(True),
718 | nn.Flatten(1)
719 | )
720 | return pool, conv
721 |
722 | def pattern12(self):
723 | sal_length = (self.saliency_size // 4) ** 2
724 | pool = MultiRectanglePattern()
725 | conv = nn.Sequential(
726 | nn.Conv2d(self.in_dim + sal_length, self.out_dim, (2,4), bias=False),
727 | nn.ReLU(True),
728 | nn.Flatten(1)
729 | )
730 | return pool, conv
731 |
732 |
733 |
734 | if __name__ == '__main__':
735 | pattern_list = list(range(1,13))
736 | x = torch.randn(2, 512, 7, 7)
737 | s = torch.randn(2, 1, 56, 56)
738 | w = torch.randn(2, len(pattern_list))
739 | mpp = MPPModule(512, 512, 0.5, pattern_list)
740 | sampp = SAMPPModule(512, 512, 56, 0.5, pattern_list)
741 | mpp_out = mpp(x, weights=w)
742 | sa_out = sampp(x,s, weights=w)
743 | print('mpp_out', mpp_out.shape, 'sampp_out', sa_out.shape)
744 |
745 | # third_pattern = ThirdOfRulePattern()
746 | # print(third_pattern(x,s)[0].shape)
747 |
748 | # h = w = 56
749 | # square_part = torch.zeros(4, 4, h, w)
750 | # index_y = torch.split(torch.arange(h), (h+1)//4)
751 | # index_x = torch.split(torch.arange(w), (w+1)//4)
752 | # for i in range(4):
753 | # for j in range(4):
754 | # for x in index_x[i]:
755 | # for y in index_y[j]:
756 | # square_part[i,j,y,x] = 1
757 | # mask = torch.zeros(8, h, w)
758 | # group_x = [[0, 0, 1], [1], [2, 3, 3], [2], [0, 0, 1], [1], [2, 3, 3], [2]]
759 | # group_y = [[1, 0, 0], [1], [0, 0, 1], [1], [2, 3, 3], [2], [3, 3, 2], [2]]
760 | # for i in range(len(group_x)):
761 | # mask[i] = torch.sum(square_part[group_x[i], group_y[i]], dim=0)
762 | # print(mask[i], torch.count_nonzero(mask[i]))
763 |
764 |
765 |
766 |
767 |
768 |
769 |
--------------------------------------------------------------------------------
/SAMPNet/samp_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 | import warnings
6 | warnings.filterwarnings('ignore')
7 | from samp_module import MPPModule, SAMPPModule
8 | from config import Config
9 |
10 | class SAPModule(nn.Module):
11 | def __init__(self, input_channel, output_channel, saliency_size, dropout):
12 | super(SAPModule, self).__init__()
13 | self.pooling = nn.Sequential(
14 | nn.AdaptiveAvgPool2d(1),
15 | nn.Flatten(1),
16 | nn.Dropout(dropout))
17 | sal_feat_len = saliency_size * saliency_size
18 | self.feature_layer = nn.Sequential(
19 | nn.Linear(input_channel + sal_feat_len,
20 | output_channel, bias=False),
21 | nn.Dropout(dropout))
22 |
23 | def forward(self, x, s):
24 | x = self.pooling(x)
25 | s = s.flatten(1)
26 | f = self.feature_layer(torch.cat([x,s], dim=1))
27 | return f
28 |
29 | def build_resnet(layers, pretrained=False):
30 | assert layers in [18, 34, 50, 101], f'layers must be one of [18, 34, 50, 101], while layers = {layers}'
31 | if layers == 18:
32 | resnet = models.resnet18(pretrained)
33 | elif layers == 34:
34 | resnet = models.resnet34(pretrained)
35 | elif layers == 50:
36 | resnet = models.resnet50(pretrained)
37 | else:
38 | resnet = models.resnet101(pretrained)
39 | modules = list(resnet.children())[:-2]
40 | resnet = nn.Sequential(*modules)
41 | return resnet
42 |
43 | class SAMPNet(nn.Module):
44 | def __init__(self, cfg, pretrained=True):
45 | super(SAMPNet, self).__init__()
46 | score_level = cfg.score_level
47 | layers = cfg.resnet_layers
48 | dropout = cfg.dropout
49 | num_attributes = cfg.num_attributes
50 | input_channel = 512 if layers in [18,34] else 2048
51 | sal_dim = 512
52 | pool_dropout = cfg.pool_dropout
53 | pattern_list = cfg.pattern_list
54 | pattern_fuse = cfg.pattern_fuse
55 |
56 | self.use_weighted_loss = cfg.use_weighted_loss
57 | self.use_attribute = cfg.use_attribute
58 | self.use_channel_attention = cfg.use_channel_attention
59 | self.use_saliency = cfg.use_saliency
60 | self.use_multipattern = cfg.use_multipattern
61 | self.use_pattern_weight = cfg.use_pattern_weight
62 |
63 | self.backbone = build_resnet(layers, pretrained=pretrained)
64 | self.global_pool = nn.Sequential(
65 | nn.AdaptiveAvgPool2d(1),
66 | nn.Dropout(dropout),
67 | nn.Flatten(1),
68 | )
69 |
70 | output_channel = input_channel
71 | if self.use_multipattern:
72 | if self.use_saliency:
73 | self.saliency_max = nn.Sequential(
74 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
75 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
76 | # multi-pattern pooling module
77 | self.pattern_module = SAMPPModule(input_channel,
78 | input_channel + sal_dim,
79 | saliency_size=56,
80 | dropout=pool_dropout,
81 | pattern_list=pattern_list,
82 | fusion=pattern_fuse)
83 | output_channel = input_channel + sal_dim
84 | else:
85 | self.pattern_module = MPPModule(input_channel,
86 | input_channel,
87 | dropout=pool_dropout,
88 | pattern_list=pattern_list,
89 | fusion=pattern_fuse)
90 | output_channel = input_channel
91 | if self.use_pattern_weight:
92 | self.pattern_weight_layer = nn.Sequential(
93 | nn.AdaptiveAvgPool2d(1),
94 | nn.Dropout(dropout),
95 | nn.Flatten(1),
96 | nn.Linear(input_channel, len(pattern_list), bias=False)
97 | )
98 | else:
99 | if self.use_saliency:
100 |                 # downsample the 224x224 saliency map to 56x56 with two stride-2 max pools
101 |                 self.saliency_max = nn.Sequential(
102 |                     nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
103 |                     nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
104 | self.pattern_module = SAPModule(input_channel,
105 | input_channel + sal_dim,
106 | saliency_size=56,
107 | dropout=pool_dropout)
108 | output_channel = input_channel + sal_dim
109 |
110 | if self.use_attribute:
111 | # multi-task structure
112 | concat_dim = output_channel
113 | att_dim = 512 if concat_dim >= 1024 else concat_dim // 2
114 | com_dim = concat_dim - att_dim
115 | self.att_feature_layer = nn.Sequential(
116 | nn.Linear(concat_dim, att_dim, bias=False),
117 | nn.ReLU(True),
118 | nn.Dropout(dropout)
119 | )
120 | self.att_pred_layer = nn.Sequential(
121 | nn.Linear(att_dim, num_attributes, bias=False)
122 | )
123 | self.com_feature_layer = nn.Sequential(
124 | nn.Linear(concat_dim, com_dim, bias=False),
125 | nn.ReLU(True),
126 | nn.Dropout(dropout)
127 | )
128 | if self.use_channel_attention:
129 | self.alpha_predict_layer = nn.Sequential(
130 | nn.Linear(concat_dim, 2, bias=False),
131 | nn.Sigmoid()
132 | )
133 | self.com_pred_layer = nn.Sequential(
134 | nn.Linear(output_channel, output_channel, bias=False),
135 | nn.ReLU(True),
136 | nn.Dropout(dropout),
137 | nn.Linear(output_channel, input_channel, bias=False),
138 | nn.ReLU(True),
139 | nn.Linear(input_channel, score_level, bias=False),
140 | nn.Softmax(dim=1))
141 |
142 | def forward(self, x, s):
143 | feature_map = self.backbone(x)
144 | weight = None
145 | attribute = None
146 |
147 | if self.use_multipattern:
148 | if self.use_pattern_weight:
149 | weight = self.pattern_weight_layer(feature_map)
150 | if self.use_saliency:
151 | sal_map = self.saliency_max(s)
152 | pattern_feat = self.pattern_module(feature_map, sal_map, weight)
153 | else:
154 | pattern_feat = self.pattern_module(feature_map, weight)
155 | else:
156 | if self.use_saliency:
157 | sal_map = self.saliency_max(s)
158 | pattern_feat = self.pattern_module(feature_map, sal_map)
159 | else:
160 | pattern_feat = self.global_pool(feature_map)
161 | if self.use_attribute:
162 | att_feat = self.att_feature_layer(pattern_feat)
163 | com_feat = self.com_feature_layer(pattern_feat)
164 | attribute = self.att_pred_layer(att_feat)
165 | fused_feat = torch.cat([att_feat, com_feat], dim=1)
166 | if self.use_channel_attention:
167 | alpha = self.alpha_predict_layer(fused_feat)
168 | fused_feat = torch.cat([alpha[:,0:1] * att_feat, alpha[:,1:] * com_feat], dim=1)
169 | scores = self.com_pred_layer(fused_feat)
170 | else:
171 | scores = self.com_pred_layer(pattern_feat)
172 | return weight, attribute, scores
173 |
174 | class AttributeLoss(nn.Module):
175 | def __init__(self, scalar=0.1):
176 | super(AttributeLoss, self).__init__()
177 | self.scalar = scalar
178 |
179 | def forward(self, att_target, att_estimate):
180 | assert att_target.shape == att_estimate.shape, \
181 | 'target {} vs. predict {}'.format(att_target.shape, att_estimate.shape)
182 | diff = att_target - att_estimate
183 | diff[att_target == 0] = 0.
184 | diff = torch.sum(torch.pow(diff, 2), dim=0)
185 | num_samples = torch.count_nonzero(att_target, dim=0)
186 | diff_mean = torch.where(num_samples > 0, diff / num_samples, diff)
187 | loss_att = torch.sum(diff_mean) * self.scalar
188 | return loss_att
189 |
190 | class EMDLoss(nn.Module):
191 | def __init__(self, reduction='mean', r=2):
192 | super(EMDLoss, self).__init__()
193 | self.reduction = reduction
194 | self.r = r
195 |
196 | def forward(self, p_target, p_estimate, weight=None):
197 | assert p_target.shape == p_estimate.shape, \
198 | 'target {} vs. predict {}'.format(p_target.shape, p_estimate.shape)
199 | # cdf for values [1, 2, ..., 5]
200 | cdf_target = torch.cumsum(p_target, dim=-1)
201 | # cdf for values [1, 2, ..., 5]
202 | cdf_estimate = torch.cumsum(p_estimate, dim=-1)
203 | cdf_diff = cdf_estimate - cdf_target
204 | if self.r == 1:
205 | samplewise_emd = torch.mean(cdf_diff.abs(), dim=-1)
206 | else:
207 | abs_diff = torch.where(cdf_diff.abs() > 1e-15, cdf_diff.abs(), torch.ones_like(cdf_diff) * 1e-15)
208 |             squares = torch.pow(abs_diff, self.r)
209 |             average = torch.mean(squares, dim=-1)
210 | samplewise_emd = torch.pow(average, 1. / self.r)
211 | if weight is not None:
212 | samplewise_emd = samplewise_emd * weight
213 | if self.reduction == 'sum':
214 | return samplewise_emd.sum()
215 | elif self.reduction is None:
216 | return samplewise_emd
217 | else:
218 | return samplewise_emd.mean()
219 |
220 | if __name__ == '__main__':
221 | x = torch.randn(2,3,224,224)
222 | s = torch.randn(2,1,224,224)
223 | cfg = Config()
224 | model = SAMPNet(cfg)
225 | weight, attribute, score = model(x,s)
226 | if weight is not None:
227 | print('weight', weight.shape, F.softmax(weight,dim=1))
228 | if attribute is not None:
229 | print('attribute', attribute.shape, attribute)
230 | print('score', score.shape, score)
--------------------------------------------------------------------------------
/SAMPNet/test.py:
--------------------------------------------------------------------------------
1 | from samp_net import EMDLoss, SAMPNet
2 | from cadb_dataset import CADBDataset
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import scipy.stats as stats
6 | import numpy as np
7 | from tqdm import tqdm
8 | from config import Config
9 |
10 | def calculate_accuracy(predict, target, threshold=2.6):
11 |     assert target.shape == predict.shape, '{} vs. {}'.format(target.shape, predict.shape)
12 |     bin_tar = target > threshold
13 |     bin_pre = predict > threshold
14 | correct = (bin_tar == bin_pre).sum()
15 | acc = correct.float() / target.size(0)
16 | return correct,acc
17 |
18 | def calculate_lcc(target, predict):
19 | if len(target.shape) > 1:
20 | target = target.view(-1)
21 | if len(predict.shape) > 1:
22 | predict = predict.view(-1)
23 | predict = predict.cpu().numpy()
24 | target = target.cpu().numpy()
25 | lcc = np.corrcoef(predict, target)[0,1]
26 | return lcc
27 |
28 | def calculate_spearmanr(target, predict):
29 | if len(target.shape) > 1:
30 | target = target.view(-1)
31 | if len(predict.shape) > 1:
32 | predict = predict.view(-1)
33 | target_list = target.cpu().numpy().tolist()
34 | predict_list = predict.cpu().numpy().tolist()
35 | # sort_target = np.sort(target_list).tolist()
36 | # sort_predict = np.sort(predict_list).tolist()
37 | # pre_rank = []
38 | # for i in predict_list:
39 | # pre_rank.append(sort_predict.index(i))
40 | # tar_rank = []
41 | # for i in target_list:
42 | # tar_rank.append(sort_target.index(i))
43 | # rho,pval = stats.spearmanr(pre_rank, tar_rank)
44 | rho,_ = stats.spearmanr(predict_list, target_list)
45 | return rho
46 |
47 | def dist2ave(pred_dist):
48 | pred_score = torch.sum(pred_dist* torch.Tensor(range(1,6)).to(pred_dist.device), dim=-1, keepdim=True)
49 | return pred_score
50 |
51 | def evaluation_on_cadb(model, cfg):
52 | model.eval()
53 | device = next(model.parameters()).device
54 | testdataset = CADBDataset('test', cfg)
55 | testloader = DataLoader(testdataset,
56 | batch_size=cfg.batch_size,
57 | shuffle=False,
58 | num_workers=cfg.num_workers,
59 | drop_last=False)
60 | emd_r2_fn = EMDLoss(reduction='sum', r=2)
61 | emd_r1_fn = EMDLoss(reduction='sum', r=1)
62 | emd_r2_error = 0.0
63 | emd_r1_error = 0.0
64 | correct = 0.
65 | tar_scores = None
66 | pre_scores = None
67 | print()
68 | print('Evaluation beginning...')
69 | with torch.no_grad():
70 | for (im,score,dist,saliency,attributes) in tqdm(testloader):
71 | image = im.to(device)
72 | score = score.to(device)
73 | dist = dist.to(device)
74 | saliency = saliency.to(device)
75 | weight, atts, output = model(image, saliency)
76 |
77 | pred_score = dist2ave(output)
78 | emd_r1_error += emd_r1_fn(dist, output).item()
79 | emd_r2_error += emd_r2_fn(dist, output).item()
80 | correct += calculate_accuracy(pred_score, score)[0].item()
81 | if tar_scores is None:
82 | tar_scores = score
83 | pre_scores = pred_score
84 | else:
85 | tar_scores = torch.cat([tar_scores, score], dim=0)
86 | pre_scores = torch.cat([pre_scores, pred_score], dim=0)
87 | print('Evaluation result...')
88 | # print('Scores shape', pre_scores.shape, tar_scores.shape)
89 | avg_mse = torch.nn.MSELoss()(pre_scores.view(-1), tar_scores.view(-1)).item()
90 | SRCC = calculate_spearmanr(tar_scores, pre_scores)
91 | LCC = calculate_lcc(tar_scores, pre_scores)
92 | avg_r1_emd = emd_r1_error / len(testdataset)
93 | avg_r2_emd = emd_r2_error / len(testdataset)
94 | avg_acc = correct / len(testdataset)
95 | ss = "Test on {} images, Accuracy={:.2%}, EMD(r=1)={:.4f}, EMD(r=2)={:.4f},". \
96 | format(len(testdataset), avg_acc, avg_r1_emd, avg_r2_emd)
97 | ss += " MSE_loss={:.4f}, SRCC={:.4f}, LCC={:.4f}". \
98 | format(avg_mse, SRCC, LCC)
99 | print(ss)
100 | return avg_acc, avg_r1_emd, avg_r2_emd, avg_mse, SRCC, LCC
101 |
102 | if __name__ == '__main__':
103 | cfg = Config()
104 | device = torch.device('cuda:{}'.format(cfg.gpu_id))
105 | model = SAMPNet(cfg,pretrained=False).to(device)
106 | weight_file = './pretrained_model/samp_net.pth'
107 | model.load_state_dict(torch.load(weight_file))
108 | evaluation_on_cadb(model, cfg)
--------------------------------------------------------------------------------
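Note: `dist2ave` converts a predicted 5-bin distribution into its expected score over the values 1..5, and `calculate_accuracy` compares binarized scores at the default threshold of 2.6. A minimal sketch on dummy tensors (illustrative values only):

```python
import torch

# dummy predicted distributions over scores 1..5 (batch of 2, rows sum to 1)
pred_dist = torch.tensor([[0.05, 0.15, 0.40, 0.30, 0.10],
                          [0.50, 0.30, 0.10, 0.05, 0.05]])
# expected score, same formula as dist2ave()
pred_score = torch.sum(pred_dist * torch.arange(1, 6, dtype=torch.float32),
                       dim=-1, keepdim=True)
target = torch.tensor([[3.4], [1.9]])

threshold = 2.6  # default threshold used by calculate_accuracy()
correct = ((pred_score > threshold) == (target > threshold)).sum()
print(pred_score.view(-1), correct.item() / target.size(0))  # tensor([3.25, 1.85]) 1.0
```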
/SAMPNet/train.py:
--------------------------------------------------------------------------------
1 | import sys,os
2 | from torch.autograd import Variable
3 | import torch.optim as optim
4 | from tensorboardX import SummaryWriter
5 | import torch
6 | import time
7 | import shutil
8 | from torch.utils.data import DataLoader
9 | import csv
10 |
11 | from samp_net import EMDLoss, AttributeLoss, SAMPNet
12 | from config import Config
13 | from cadb_dataset import CADBDataset
14 | from test import evaluation_on_cadb
15 |
16 | def calculate_accuracy(predict, target, threshold=2.6):
17 | assert target.shape == predict.shape, '{} vs. {}'.format(target.shape, predict.shape)
18 | bin_tar = target > threshold
19 | bin_pre = predict > threshold
20 | correct = (bin_tar == bin_pre).sum()
21 | acc = correct.float() / target.size(0)
22 | return correct,acc
23 |
24 | def build_dataloader(cfg):
25 | trainset = CADBDataset('train', cfg)
26 | trainloader = DataLoader(trainset,
27 | batch_size=cfg.batch_size,
28 | shuffle=True,
29 | num_workers=cfg.num_workers,
30 | drop_last=False)
31 | return trainloader
32 |
33 | class Trainer(object):
34 | def __init__(self, model, cfg):
35 | self.cfg = cfg
36 | self.model = model
37 | self.device = torch.device('cuda:{}'.format(self.cfg.gpu_id))
38 | self.trainloader = build_dataloader(cfg)
39 | self.optimizer = self.create_optimizer()
40 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
41 | self.optimizer, mode='min', patience=5)
42 | self.epoch = 0
43 | self.iters = 0
44 |
45 | self.avg_mse = 0.
46 | self.avg_emd = 0.
47 | self.avg_acc = 0.
48 | self.avg_att = 0.
49 |
50 | self.smooth_coe = 0.4
51 | self.smooth_mse = None
52 | self.smooth_emd = None
53 | self.smooth_acc = None
54 | self.smooth_att = None
55 |
56 | self.mse_loss = torch.nn.MSELoss()
57 | self.emd_loss = EMDLoss()
58 |
59 | self.test_acc = []
60 | self.test_emd1 = []
61 | self.test_emd2 = []
62 | self.test_mse = []
63 | self.test_srcc = []
64 | self.test_lcc = []
65 |
66 | if cfg.use_attribute:
67 | self.att_loss = AttributeLoss(cfg.attribute_weight)
68 |
69 | self.least_metric = 1.
70 | self.writer = self.create_writer()
71 |
72 | def create_optimizer(self):
73 | # for param in self.model.backbone.parameters():
74 | # param.requires_grad = False
75 | bb_params = list(map(id, self.model.backbone.parameters()))
76 | lr_params = filter(lambda p:id(p) not in bb_params, self.model.parameters())
77 | params = [
78 | {'params': lr_params, 'lr': self.cfg.lr},
79 | {'params': self.model.backbone.parameters(), 'lr': self.cfg.lr * 0.01}
80 | ]
81 | if self.cfg.optimizer == 'adam':
82 | optimizer = optim.Adam(params,
83 | weight_decay=self.cfg.weight_decay)
84 | elif self.cfg.optimizer == 'sgd':
85 | optimizer = optim.SGD(params,
86 | momentum=self.cfg.momentum,
87 | weight_decay=self.cfg.weight_decay)
88 | else:
89 | raise ValueError(f"no such optimizer: {self.cfg.optimizer}")
90 | return optimizer
91 |
92 | def create_writer(self):
93 | print('Create tensorboardX writer...', self.cfg.log_dir)
94 | writer = SummaryWriter(log_dir=self.cfg.log_dir)
95 | return writer
96 |
97 | def run(self):
98 | for epoch in range(self.cfg.max_epoch):
99 | self.run_epoch()
100 | self.epoch += 1
101 | self.scheduler.step(metrics=self.least_metric)
102 | self.writer.add_scalar('Train/lr', self.optimizer.param_groups[0]['lr'], self.epoch)
103 | if self.epoch % self.cfg.save_epoch == 0:
104 | checkpoint_path = os.path.join(self.cfg.checkpoint_dir, 'model-{epoch}.pth')
105 | torch.save(self.model.state_dict(), checkpoint_path.format(epoch=self.epoch))
106 | print('Save checkpoint...')
107 | if self.epoch % self.cfg.test_epoch == 0:
108 | test_emd = self.eval_training()
109 | if test_emd < self.least_metric:
110 | self.least_metric = test_emd
111 | checkpoint_path = os.path.join(self.cfg.checkpoint_dir, 'model-best.pth')
112 | torch.save(self.model.state_dict(), checkpoint_path)
113 | print('Update best checkpoint...')
114 | self.writer.add_scalar('Test/Least EMD', self.least_metric, self.epoch)
115 |
116 |
117 | def eval_training(self):
118 | avg_acc, avg_r1_emd, avg_r2_emd, avg_mse, SRCC, LCC = \
119 | evaluation_on_cadb(self.model, self.cfg)
120 | self.writer.add_scalar('Test/Average EMD(r=2)', avg_r2_emd, self.epoch)
121 | self.writer.add_scalar('Test/Average EMD(r=1)', avg_r1_emd, self.epoch)
122 | self.writer.add_scalar('Test/Average MSE', avg_mse, self.epoch)
123 | self.writer.add_scalar('Test/Accuracy', avg_acc, self.epoch)
124 | self.writer.add_scalar('Test/SRCC', SRCC, self.epoch)
125 | self.writer.add_scalar('Test/LCC', LCC, self.epoch)
126 | error = avg_r1_emd
127 |
128 | self.test_acc.append(avg_acc)
129 | self.test_emd1.append(avg_r1_emd)
130 | self.test_emd2.append(avg_r2_emd)
131 | self.test_mse.append(avg_mse)
132 | self.test_srcc.append(SRCC)
133 | self.test_lcc.append(LCC)
134 | self.write2csv()
135 | return error
136 |
137 | def write2csv(self):
138 | csv_path = os.path.join(self.cfg.exp_path, '..', '{}.csv'.format(self.cfg.exp_name))
139 | header = ['epoch', 'Accuracy', 'EMD r=1', 'EMD r=2', 'MSE', 'SRCC', 'LCC']
140 | epochs = list(range(len(self.test_acc)))
141 | metrics = [epochs, self.test_acc, self.test_emd1, self.test_emd2,
142 | self.test_mse, self.test_srcc, self.test_lcc]
143 | rows = [header]
144 | for i in range(len(epochs)):
145 | row = [m[i] for m in metrics]
146 | rows.append(row)
147 | for name, m in zip(header, metrics):
148 | if name == 'epoch':
149 | continue
150 | index = m.index(min(m))
151 | if name in ['Accuracy', 'SRCC', 'LCC']:
152 | index = m.index(max(m))
153 | title = 'best {} (epoch-{})'.format(name, index)
154 | row = [l[index] for l in metrics]
155 | row[0] = title
156 | rows.append(row)
157 | with open(csv_path, 'w') as f:
158 | cw = csv.writer(f)
159 | cw.writerows(rows)
160 | print('Save result to ', csv_path)
161 |
162 | def dist2ave(self, pred_dist):
163 | pred_score = torch.sum(pred_dist * torch.Tensor(range(1, 6)).to(pred_dist.device), dim=-1, keepdim=True)
164 | return pred_score
165 |
166 | def run_epoch(self):
167 | self.model.train()
168 | for batch, data in enumerate(self.trainloader):
169 | self.iters += 1
170 | image = data[0].to(self.device)
171 | score = data[1].to(self.device)
172 | score_dist = data[2].to(self.device)
173 | saliency = data[3].to(self.device)
174 | attributes = data[4].to(self.device)
175 | weight = data[5].to(self.device)
176 |
177 | pred_weight, pred_atts, pred_dist = self.model(image, saliency)
178 |
179 | if self.cfg.use_weighted_loss:
180 | dist_loss = self.emd_loss(score_dist, pred_dist, weight)
181 | else:
182 | dist_loss = self.emd_loss(score_dist, pred_dist)
183 |
184 | if self.cfg.use_attribute:
185 | att_loss = self.att_loss(attributes, pred_atts)
186 | loss = dist_loss + att_loss
187 | else:
188 | loss = dist_loss
189 | self.optimizer.zero_grad()
190 | loss.backward()
191 | self.optimizer.step()
192 |
193 | self.avg_emd += dist_loss.item()
194 | if self.cfg.use_attribute: self.avg_att += att_loss.item()  # att_loss only exists when attribute supervision is enabled
195 | pred_score = self.dist2ave(pred_dist)
196 | correct, accuracy = calculate_accuracy(pred_score, score)
197 | self.avg_acc += accuracy.item()
198 | if (self.iters+1) % self.cfg.display_steps == 0:
199 | print('ground truth: average={}'.format(score.view(-1)))
200 | print('prediction: average={}'.format(pred_score.view(-1)))
201 |
202 | self.avg_emd = self.avg_emd / self.cfg.display_steps
203 | self.avg_acc = self.avg_acc / self.cfg.display_steps
204 | if self.cfg.use_attribute:
205 | self.avg_att = self.avg_att / self.cfg.display_steps
206 |
207 | if self.smooth_emd is not None:
208 | self.avg_emd = (1-self.smooth_coe) * self.avg_emd + self.smooth_coe * self.smooth_emd
209 | self.avg_acc = (1-self.smooth_coe) * self.avg_acc + self.smooth_coe * self.smooth_acc
210 | if self.cfg.use_attribute:
211 | self.avg_att = (1-self.smooth_coe) * self.avg_att + self.smooth_coe * self.smooth_att
212 | self.writer.add_scalar('Train/AttributeLoss', self.avg_att, self.iters)
213 |
214 | self.writer.add_scalar('Train/EMD_Loss', self.avg_emd, self.iters)
215 | self.writer.add_scalar('Train/Accuracy', self.avg_acc, self.iters)
216 |
217 | if self.cfg.use_attribute:
218 | print('Training Epoch:{}/{} Current Batch: {}/{} EMD_Loss:{:.4f} Attribute_Loss:{:.4f} ACC:{:.2%} lr:{:.6f} '.
219 | format(
220 | self.epoch, self.cfg.max_epoch,
221 | batch, len(self.trainloader),
222 | self.avg_emd, self.avg_att,
223 | self.avg_acc,
224 | self.optimizer.param_groups[0]['lr']))
225 | else:
226 | print(
227 | 'Training Epoch:{}/{} Current Batch: {}/{} EMD_Loss:{:.4f} ACC:{:.2%} lr:{:.6f} '.
228 | format(
229 | self.epoch, self.cfg.max_epoch,
230 | batch, len(self.trainloader),
231 | self.avg_emd, self.avg_acc,
232 | self.optimizer.param_groups[0]['lr']))
233 |
234 | self.smooth_emd = self.avg_emd
235 | self.smooth_acc = self.avg_acc
236 |
237 | self.avg_mse = 0.
238 | self.avg_emd = 0.
239 | self.avg_acc = 0.
240 | if self.cfg.use_attribute:
241 | self.smooth_att = self.avg_att
242 | self.avg_att = 0.
243 | print()
244 |
245 | if __name__ == '__main__':
246 | cfg = Config()
247 | cfg.create_path()
248 | device = torch.device('cuda:{}'.format(cfg.gpu_id))
249 | # evaluate(cfg)
250 | for file in os.listdir('./'):
251 | if file.endswith('.py'):
252 | shutil.copy(file, cfg.exp_path)
253 | print('Backup ', file)
254 |
255 | model = SAMPNet(cfg)
256 | model = model.train().to(device)
257 | trainer = Trainer(model, cfg)
258 | trainer.run()
--------------------------------------------------------------------------------
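Note: `create_optimizer` above trains the pretrained backbone at 1% of the base learning rate while the newly added layers use the full rate, and the trainer drives a `ReduceLROnPlateau` scheduler on the test EMD. A minimal sketch of the same two-group pattern on a generic stand-in backbone (the model, learning rate, and weight decay here are illustrative placeholders, not the repository's config):

```python
import torch
import torch.optim as optim
import torchvision

backbone = torchvision.models.resnet18(pretrained=True)  # stand-in pretrained backbone
head = torch.nn.Linear(1000, 5)                           # stand-in new prediction head
base_lr, wd = 1e-4, 1e-5                                  # placeholder hyper-parameters

param_groups = [
    {'params': head.parameters(), 'lr': base_lr},              # new layers: full LR
    {'params': backbone.parameters(), 'lr': base_lr * 0.01},   # pretrained backbone: 1% LR
]
optimizer = optim.Adam(param_groups, weight_decay=wd)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
```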
/annotations/visualize_cadb_annotation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import cv2
3 | import os
4 | import numpy as np
5 | import shutil
6 | import argparse
7 | from tqdm import tqdm
8 | import math
9 |
10 | score_levels = ['[1,2)', '[2,3)', '[3,4)', '[4,5]']
11 | element_categories = ["center", "rule_of_thirds", "golden_ratio", "horizontal", "vertical", "diagonal",
12 | "curved", "fill_the_frame", "radial", "vanishing_point", "symmetric", "triangle",
13 | "pattern", 'none']
14 | scene_categories = ['animal', 'plant', 'human', 'static', 'architecture',
15 | 'landscape', 'cityscape', 'indoor', 'night', 'other']
16 |
17 | def draw_auxiliary_line(image, com_type):
18 | im_h, im_w, _ = image.shape
19 | if com_type in ['rule_of_thirds','center']:
20 | x1, x2 = int(im_w / 3), int(im_w / 3 * 2)
21 | y1, y2 = int(im_h / 3), int(im_h / 3 * 2)
22 | else:
23 | x1, x2 = int(im_w * 1. / 2.618), int(im_w * 1.618 / 2.618)
24 | y1, y2 = int(im_h * 1. / 2.618), int(im_h * 1.618 / 2.618)
25 | color = (255, 255, 255)
26 | line_width = 3
27 | cv2.line(image, (0, y1), (im_w, y1), color, line_width)
28 | cv2.line(image, (0, y2), (im_w, y2), color, line_width)
29 | cv2.line(image, (x1, 0), (x1, im_h), color, line_width)
30 | cv2.line(image, (x2, 0), (x2, im_h), color, line_width)
31 | return image
32 |
33 | def draw_element_on_image(image, com_type, element):
34 | pd_color = (0, 255, 255)
35 | if len(element) > 0:
36 | if com_type in ['center', 'rule_of_thirds', 'golden_ratio']:
37 | image = draw_auxiliary_line(image.copy(), com_type)
38 | for rect in element:
39 | x1,y1,x2,y2 = map(int, rect)
40 | cv2.rectangle(image, (x1,y1), (x2,y2), pd_color, 5)
41 | elif com_type in ['horizontal', 'diagonal', 'vertical']:
42 | image = draw_line_elements(image.copy(), com_type, element)
43 | else:
44 | element = np.array(element).astype(np.int32).reshape((len(element), -1, 2))
45 | for i in range(element.shape[0]):
46 | for j in range(element[i].shape[0] - 1):
47 | cv2.line(image, (element[i, j, 0], element[i, j, 1]), (element[i, j + 1, 0], element[i, j + 1, 1]),
48 | pd_color, 5)
49 |
50 | text = '{}'.format(com_type)
51 | cv2.putText(image, text, (20, 70), cv2.FONT_HERSHEY_COMPLEX, 1.5, pd_color, 3)
52 | return image
53 |
54 | def compute_angle(lines):
55 | lines = np.array(lines).reshape((-1,4))
56 | reverse_lines = np.concatenate([lines[:,2:], lines[:,0:2]], axis=1)
57 | l2r_points = np.where(lines[:,0:1] <= lines[:,2:3], lines, reverse_lines)
58 | angle = np.rad2deg(np.arctan2(l2r_points[:,3] - l2r_points[:,1], l2r_points[:,2] - l2r_points[:,0]))
59 | return np.abs(angle)
60 |
61 | def draw_line_elements(src, comp, element, vis_angle=False):
62 | im_h, im_w, _ = src.shape
63 | angle = compute_angle(element)
64 | element = np.array(element).astype(np.int32).reshape((-1, 4))
65 | color = (0,255,255)
66 | for i in range(element.shape[0]):
67 | cv2.line(src, (element[i, 0], element[i, 1]),
68 | (element[i, 2], element[i, 3]), color, 5)
69 | if vis_angle:
70 | text = '{:.1f}'.format(angle[i])
71 | tl_point = (element[i,0], element[i,1])
72 | br_point = (element[i,2], element[i,3])
73 | if element[i,0] > element[i,2] or \
74 | (element[i,0] == element[i,2] and element[i,1] > element[i,3]):
75 | tl_point = (element[i,2], element[i,3])
76 | br_point = (element[i,0], element[i,1])
77 | if tl_point[1] <= br_point[1]:
78 | pos_x = max(min(tl_point[0] + 20, im_w - 50), 0)
79 | pos_y = max(tl_point[1] - 10, 30)
80 | else:
81 | pos_x = max(min(tl_point[0] - 100, im_w - 50), 0)
82 | pos_y = max(min(tl_point[1] + 50, im_h - 50), 0)
83 | cv2.putText(src, text, (pos_x, pos_y), cv2.FONT_HERSHEY_COMPLEX, 1.2, color, 2)
84 | return src
85 |
86 | def get_parser():
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('--data_root', default='./CADB_Dataset/',
89 | help='dataset root (should contain an images/ subfolder and the annotation json files)')
90 | parser.add_argument('--save_folder', default='./CADB_Dataset/Visualization/', type=str)
91 | return parser
92 |
93 | if __name__ == '__main__':
94 | parser = get_parser()
95 | opt, _ = parser.parse_known_args()
96 | image_dir = os.path.join(opt.data_root, 'images')
97 | assert os.path.exists(image_dir), image_dir
98 | element_file = os.path.join(opt.data_root, 'composition_elements.json')
99 | assert os.path.exists(element_file), element_file
100 | element_anno = json.load(open(element_file, 'r'))
101 |
102 | score_file = os.path.join(opt.data_root, 'composition_scores.json')
103 | assert os.path.exists(score_file), score_file
104 | score_anno = json.load(open(score_file, 'r'))
105 |
106 | scene_file = os.path.join(opt.data_root, 'scene_categories.json')
107 | assert os.path.exists(scene_file), scene_file
108 | scene_anno = json.load(open(scene_file, 'r'))
109 |
110 | score_dir = os.path.join(opt.save_folder, 'composition_scores')
111 | os.makedirs(score_dir, exist_ok=True)
112 |
113 | element_dir = os.path.join(opt.save_folder, 'composition_elements')
114 | os.makedirs(element_dir, exist_ok=True)
115 |
116 | scene_dir = os.path.join(opt.save_folder, 'scene_classification')
117 | os.makedirs(scene_dir, exist_ok=True)
118 |
119 | score_stats = {}
120 | for level in score_levels:
121 | score_stats[level] = []
122 |
123 | element_stats = {}
124 | for comp in element_categories:
125 | element_stats[comp] = []
126 |
127 | scene_stats = {}
128 | for scene in scene_categories:
129 | scene_stats[scene] = []
130 |
131 | total_num = 0
132 | for image_name, anno in tqdm(element_anno.items()):
133 | # read source image
134 | image_file = os.path.join(image_dir, image_name)
135 | assert os.path.exists(image_file), image_file
136 | src = cv2.imread(os.path.join(image_dir, image_name))
137 | im_h, im_w, _ = src.shape
138 | total_num += 1
139 | # store images to different subfolders according to mean score
140 | im_scores = score_anno[image_name]['scores']
141 | mean_score = float(sum(im_scores)) / len(im_scores)
142 | im_level = math.floor(mean_score) - 1 if mean_score < 5 else 3
143 | im_level = score_levels[im_level]
144 | score_stats[im_level].append(image_name)
145 | subfolder = os.path.join(score_dir, im_level)
146 | os.makedirs(subfolder, exist_ok=True)
147 | text = 'score:{:.1f}'.format(mean_score)
148 | dst = src.copy()
149 | cv2.putText(dst, text, (20, 70), cv2.FONT_HERSHEY_COMPLEX, 1.5, (0, 255, 255), 3)
150 | cv2.imwrite(os.path.join(subfolder, image_name), dst)
151 | # readout scene annotation
152 | per_scene = scene_anno[image_name]
153 | assert per_scene in scene_categories, per_scene
154 | os.makedirs(os.path.join(scene_dir, per_scene), exist_ok=True)
155 | cv2.imwrite(os.path.join(scene_dir, per_scene, image_name), src)
156 | scene_stats[per_scene].append(image_name)
157 | # visualize composition elements annotation
158 | for comp, element in anno.items():
159 | element_stats[comp].append(image_name)
160 | dst = draw_element_on_image(src.copy(), comp, element)
161 | subpath = os.path.join(element_dir, comp)
162 | if not os.path.exists(subpath):
163 | os.makedirs(subpath)
164 | cv2.imwrite(os.path.join(subpath, image_name), dst)
165 |
166 | # show dataset statistical information
167 | element_stats = sorted(element_stats.items(), key=lambda d: len(d[1]), reverse=True)
168 | scene_stats = sorted(scene_stats.items(), key=lambda d: len(d[1]), reverse=True)
169 |
170 | with open(os.path.join(opt.save_folder, 'statistics.txt'), 'w') as f:
171 | f.write('Composition Score\n')
172 | print('Composition score')
173 | for level, img_list in score_stats.items():
174 | ss = '{}: {} images, {:.1%}'.format(level, len(img_list), len(img_list) / total_num)
175 | f.write(ss + '\n')
176 | print(ss)
177 |
178 | print('\nComposition Element')
179 | f.write('\nComposition Element\n')
180 | for comp, img_list in element_stats:
181 | ss = '{}: {} images, {:.1%}'.format(comp, len(img_list), len(img_list) / total_num)
182 | f.write(ss + '\n')
183 | print(ss)
184 |
185 | print('\nScene Category')
186 | f.write('\nScene Category\n')
187 | for scene, img_list in scene_stats:
188 | ss = '{}: {} images, {:.1%}'.format(scene, len(img_list), len(img_list) / total_num)
189 | f.write(ss + '\n')
190 | print(ss)
191 |
192 | ss = 'Total number of images: {}'.format(total_num)
193 | print(ss + '\n')
194 | f.write(ss)
195 |
--------------------------------------------------------------------------------
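Note: the three annotation files can also be inspected directly; the per-image fields below follow what the visualization script accesses (a `scores` list in `composition_scores.json`, a scene string in `scene_categories.json`, and per-composition element lists in `composition_elements.json`). A minimal sketch, assuming the default `./CADB_Dataset/` layout used by the script's `--data_root`:

```python
import json
import os

data_root = './CADB_Dataset'  # assumed layout, matching the script's --data_root default
scores = json.load(open(os.path.join(data_root, 'composition_scores.json')))
scenes = json.load(open(os.path.join(data_root, 'scene_categories.json')))
elements = json.load(open(os.path.join(data_root, 'composition_elements.json')))

name = next(iter(scores))          # any annotated image name
raw = scores[name]['scores']       # per-rater composition scores
print(name, sum(raw) / len(raw), scenes.get(name))
for comp_type, elems in elements.get(name, {}).items():
    print(comp_type, len(elems), 'annotated element(s)')
```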
/examples/annotation_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/annotation_example.jpg
--------------------------------------------------------------------------------
/examples/element_examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/element_examples.jpg
--------------------------------------------------------------------------------
/examples/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/example.jpg
--------------------------------------------------------------------------------
/examples/interpretability.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/interpretability.jpg
--------------------------------------------------------------------------------
/examples/samp_net.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/Image-Composition-Assessment-Dataset-CADB/35c093bafdaaa98923d8ba093a73ddf0079ffbc9/examples/samp_net.jpg
--------------------------------------------------------------------------------