├── LICENSE
├── README.md
├── config.py
├── croppingModel.py
├── cropping_dataset.py
├── make_all.sh
├── requirements.txt
├── rod_align
│   ├── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── rod_align.py
│   ├── make.sh
│   ├── make_python2.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── rod_align.py
│   ├── setup.py
│   └── src
│       ├── rod_align.cpp
│       ├── rod_align.h
│       ├── rod_align_cuda.cpp
│       ├── rod_align_cuda.h
│       ├── rod_align_kernel.cu
│       └── rod_align_kernel.h
├── roi_align
│   ├── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── roi_align.py
│   ├── make.sh
│   ├── make_python2.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── roi_align.py
│   ├── setup.py
│   └── src
│       ├── roi_align.cpp
│       ├── roi_align.h
│       ├── roi_align_cuda.cpp
│       ├── roi_align_cuda.h
│       ├── roi_align_kernel.cu
│       └── roi_align_kernel.h
├── test.py
└── train.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 bo-zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CGS-Pytorch 2 | This is an unofficial PyTorch implementation of [Composing Good Shots by Exploiting Mutual Relations](https://openaccess.thecvf.com/content_CVPR_2020/html/Li_Composing_Good_Shots_by_Exploiting_Mutual_Relations_CVPR_2020_paper.html). 3 | 4 | # Results 5 | 6 | ## GAICD 7 | | #Metric | SRCC↑ | Acc5↑ | Acc10↑ | 8 | |:--:|:--:|:--:|:--:| 9 | | Paper | 0.795 | 59.7 | 77.8 | 10 | | This code (best SRCC) | 0.790 | 57.8 | 74.6 | 11 | | This code (best Acc) | 0.779 | 59.5 | 77.3 | 12 | 13 | During training, I set the probability of mixing graphs to 0.3 and scale the elements of the adjacency matrix by the number of crops (dividing by N/64 for N candidate crops) to produce more stable score predictions.
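The adjacency scaling mentioned above is implemented in `CroppingGraph.forward` in `croppingModel.py`. As a rough plain-Python restatement, each off-diagonal entry is exp(-||Wm x_i - Wn x_j|| / 2) divided by N/64, and every diagonal entry is set to 1. The function name `crop_adjacency` and its `ref_crops` parameter are hypothetical. The inputs are assumed to be the already-projected features `Wm(x)` and `Wn(x)`, given as plain lists.

```python
import math


def crop_adjacency(xm, xn, ref_crops=64.0):
    """Hypothetical plain-Python restatement of the adjacency in CroppingGraph.forward.

    xm, xn: lists of N equal-length vectors, standing in for Wm(x) and Wn(x).
    ref_crops: reference crop count; off-diagonal entries are divided by N / ref_crops.
    """
    n = len(xm)
    scale = n / ref_crops  # croppingModel.py divides by x.shape[0] / 64.
    adj = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                adj[i][j] = 1.0  # the diagonal is reset to 1 (the eye_t term)
                continue
            # Euclidean distance between projected features, halved
            dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(xm[i], xn[j]))) / 2
            adj[i][j] = math.exp(-dist) / scale
    return adj


feats = [[0.0, 0.0], [3.0, 4.0]]
adj = crop_adjacency(feats, feats, ref_crops=2.0)  # N / ref_crops = 1, so no rescaling
print(adj[0][1])  # equals math.exp(-2.5)
```

`Wm` and `Wn` are separate projections in the model, so the real matrix is generally not symmetric. The toy call above passes the same vectors twice only to keep the arithmetic easy to check.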
14 | 15 | ## HCDB 16 | | #Metric | IoU↑ | BDE↓ | 17 | |:--:|:--:|:--:| 18 | | Paper | 0.836 | 0.039 | 19 | | This code | 0.811 | 0.044 | 20 | 21 | # Datasets Preparation 22 | + GAICD [[link]](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping) 23 | + HCDB (FLMS) [[Download Images]](http://fangchen.org/proj_page/FLMS_mm14/data/radomir500_image/image.tar) [[Download Annotation]](http://fangchen.org/proj_page/FLMS_mm14/data/radomir500_gt/release_data.tar) 24 | 25 | Download and unzip these datasets, then place them as follows (``cropping_dataset.py`` expects the GAICD images in ``images/train`` and ``images/test``): 26 | 27 | DATASET_FOLDER 28 | ├── GAICD 29 | │ ├── images 30 | │ │ ├── train/image1.jpg 31 | │ │ └── test/image2.jpg 32 | │ └── annotations 33 | │ ├── image1.txt 34 | │ └── image2.txt 35 | └── FLMS 36 | ├── image 37 | │ ├── image1.jpg 38 | │ └── image2.jpg 39 | └── 500_image_dataset.mat 40 | 41 | # Requirements 42 | - PyTorch>=1.0 43 | 44 | You can install the required packages with pip using [``requirements.txt``](./requirements.txt): 45 | 46 | ```bash 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | # Usage 51 | ```bash 52 | # clone this repository 53 | git clone https://github.com/bo-zhang-cs/CGS-Pytorch.git 54 | cd CGS-Pytorch 55 | ``` 56 | Change the default dataset folder (``data_root``) in ``config.py``, then check the paths by running ``python cropping_dataset.py``. 57 | 58 | ## Install RoIAlign and RoDAlign 59 | 60 | The RoIAlign and RoDAlign source code is taken from [[here]](https://github.com/lld533/Grid-Anchor-based-Image-Cropping-Pytorch) and is compatible with PyTorch 1.0 or later. 61 | If you use PyTorch 0.4.1, please refer to the [[official implementation]](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch). 62 | 63 | 1. Change **CUDA_HOME** and **-arch=sm_86** in ``roi_align/make.sh`` and ``rod_align/make.sh`` to match your environment (CUDA installation path and GPU compute capability). 64 | 2. 
If you run this code on Linux, make sure these bash files (``make_all.sh, roi_align/make.sh, rod_align/make.sh``) use Unix line endings, e.g. by running ``:set ff=unix`` in Vim. 65 | 3. Run
``cd CGS-Pytorch && sudo bash make_all.sh`` to build and install the packages. 66 | 67 | ## Test 68 | 69 | Download the pretrained models (~75MB, ZIP file) from [[Google Drive]](https://drive.google.com/file/d/1CmMBcQdOc22Qnyle0xJlbO6BC8tdViWN/view?usp=sharing) and unzip them into the folder ``CGS-Pytorch/pretrained_model``. 70 | ``` 71 | python test.py 72 | ``` 73 | This creates a ``results`` folder containing the predicted best crops. 74 | 75 | ## Train 76 | ``` 77 | python train.py 78 | ``` 79 | Track the training process: 80 | ``` 81 | tensorboard --logdir=./experiments --bind_all 82 | ``` 83 | The model performance for each epoch is also recorded in a *.csv* file under the generated *./experiments* folder. 84 | 85 | # Citation 86 | ``` 87 | @inproceedings{li2020composing, 88 | title={Composing good shots by exploiting mutual relations}, 89 | author={Li, Debang and Zhang, Junge and Huang, Kaiqi and Yang, Ming-Hsuan}, 90 | booktitle={CVPR}, 91 | year={2020} 92 | } 93 | @inproceedings{zeng2019reliable, 94 | title={Reliable and efficient image cropping: A grid anchor based approach}, 95 | author={Zeng, Hui and Li, Lida and Cao, Zisheng and Zhang, Lei}, 96 | booktitle={CVPR}, 97 | year={2019} 98 | } 99 | ``` 100 | 101 | ## More references about image cropping 102 | [Awesome Image Aesthetic Assessment and Cropping](https://github.com/bcmi/Awesome-Aesthetic-Evaluation-and-Cropping) 103 | 104 | ## Acknowledgments 105 | Thanks to [[GAIC]](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch) and [[GAIC-Pytorch1.0+]](https://github.com/lld533/Grid-Anchor-based-Image-Cropping-Pytorch).
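Before running `train.py`, it can help to confirm that the dataset folder matches what the loaders in `cropping_dataset.py` assert on. Under `cfg.data_root` they expect `GAICD/images/train`, `GAICD/images/test`, `GAICD/annotations`, `FLMS/image` and `FLMS/500_image_dataset.mat`. The helper below is a hypothetical stdlib-only sketch (`missing_dataset_paths` is not part of the repository). It lists every missing path at once, instead of stopping at the first failing `assert`.

```python
import os


def missing_dataset_paths(data_root):
    """Return expected dataset paths that do not exist under data_root (hypothetical helper)."""
    expected = [
        os.path.join(data_root, 'GAICD', 'images', 'train'),
        os.path.join(data_root, 'GAICD', 'images', 'test'),
        os.path.join(data_root, 'GAICD', 'annotations'),
        os.path.join(data_root, 'FLMS', 'image'),
        os.path.join(data_root, 'FLMS', '500_image_dataset.mat'),
    ]
    return [path for path in expected if not os.path.exists(path)]


if __name__ == '__main__':
    # Placeholder path: replace with the data_root value set in config.py.
    for path in missing_dataset_paths('/path/to/DATASET_FOLDER'):
        print('missing:', path)
```

This only checks that the folders and the `.mat` file exist. Pairing each image with its annotation file is still left to the `assert`s inside `GAICDataset.parse_annotations`.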
106 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Config: 4 | data_root = '/workspace/aesthetic_cropping/dataset/' 5 | predefined_pkl = os.path.join(data_root, 'pdefined_anchors.pkl') # download from https://github.com/luwr1022/listwise-view-ranking/blob/master/pdefined_anchors.pkl 6 | FLMS_folder = os.path.join(data_root, 'FLMS') 7 | GAIC_folder = os.path.join(data_root, 'GAICD') 8 | 9 | image_size = (256,256) 10 | backbone = 'vgg16' 11 | 12 | # training 13 | gpu_id = 0 14 | num_workers = 4 15 | batch_size = 1 16 | keep_aspect_ratio = True 17 | data_augmentation = True 18 | 19 | max_epoch = 50 20 | lr = 1e-4 21 | lr_decay = 0.1 22 | lr_decay_epoch = [max_epoch + 1] 23 | weight_decay = 1e-4 24 | eval_freq = 1 25 | save_freq = max_epoch+1 26 | display_freq = 100 27 | 28 | prefix = 'CGS' 29 | exp_root = os.path.join(os.getcwd(), './experiments/') 30 | exp_name = prefix 31 | exp_path = os.path.join(exp_root, prefix) 32 | while os.path.exists(exp_path): 33 | index = os.path.basename(exp_path).split(prefix)[-1].split('repeat')[-1] 34 | try: 35 | index = int(index) + 1 36 | except: 37 | index = 1 38 | exp_name = prefix + ('_repeat{}'.format(index)) 39 | exp_path = os.path.join(exp_root, exp_name) 40 | # print('Experiment name {} \n'.format(os.path.basename(exp_path))) 41 | checkpoint_dir = os.path.join(exp_path, 'checkpoints') 42 | log_dir = os.path.join(exp_path, 'logs') 43 | 44 | def create_path(self): 45 | print('Create experiment directory: ', self.exp_path) 46 | os.makedirs(self.exp_path) 47 | os.makedirs(self.checkpoint_dir) 48 | os.makedirs(self.log_dir) 49 | 50 | cfg = Config() 51 | 52 | if __name__ == '__main__': 53 | cfg = Config() -------------------------------------------------------------------------------- /croppingModel.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | from roi_align.modules.roi_align import RoIAlignAvg 6 | from rod_align.modules.rod_align import RoDAlignAvg 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | 10 | class vgg_base(nn.Module): 11 | def __init__(self, loadweights=True): 12 | super(vgg_base, self).__init__() 13 | vgg = models.vgg16(pretrained=loadweights) 14 | self.feature3 = nn.Sequential(vgg.features[:23]) 15 | self.feature4 = nn.Sequential(vgg.features[23:30]) 16 | self.feature5 = nn.Sequential(vgg.features[30:]) 17 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256)) 18 | 19 | def forward(self, x): 20 | f3 = self.feature3(x) 21 | f4 = self.feature4(f3) 22 | f5 = self.feature5(f4) 23 | return f3, f4, f5 24 | 25 | class RegionFeatureExtractor(nn.Module): 26 | def __init__(self, loadweight = True): 27 | super(RegionFeatureExtractor, self).__init__() 28 | alignsize = 9 29 | reddim = 32 30 | downsample = 4 31 | dim_in = 512 32 | 33 | self.Feat_ext = vgg_base(loadweight) 34 | self.DimRed = nn.Conv2d(1536, reddim, kernel_size=1, padding=0) 35 | self.downsample2 = nn.UpsamplingBilinear2d(scale_factor=1.0/2.0) 36 | self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2.0) 37 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0/2**downsample) 38 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0/2**downsample) 39 | self.FC_region = nn.Sequential( 40 | nn.Conv2d(reddim*2, 1024, kernel_size=alignsize, padding=0), 41 | nn.ReLU(True), 42 | nn.Conv2d(1024, dim_in, kernel_size=1), 43 | nn.ReLU(True), 44 | nn.Flatten(1)) 45 | self.FC_region.apply(weights_init) 46 | 47 | def forward(self, im_data, crops): 48 | # print(im_data.shape, im_data.dtype, im_data.device, crops.shape, crops.dtype, crops.device) 49 | B, N, _ = crops.shape 50 | if crops.shape[-1] == 4: 51 | index = 
torch.arange(B).view(-1, 1).repeat(1, N).reshape(B, N, 1).to(crops.device) 52 | crops = torch.cat((index, crops),dim=-1).contiguous() 53 | if crops.dim() == 3: 54 | crops = crops.flatten(0,1) 55 | 56 | f3,f4,f5 = self.Feat_ext(im_data) 57 | f3 = F.interpolate(f3, size=f4.shape[2:], mode='bilinear', align_corners=True) 58 | f5 = F.interpolate(f5, size=f4.shape[2:], mode='bilinear', align_corners=True) 59 | cat_feat = torch.cat((f3,f4,0.5*f5),1) 60 | red_feat = self.DimRed(cat_feat) 61 | 62 | RoI_feat = self.RoIAlign(red_feat, crops) 63 | RoD_feat = self.RoDAlign(red_feat, crops) 64 | fuse_feat = torch.cat((RoI_feat, RoD_feat), 1) 65 | region_feature = self.FC_region(fuse_feat) 66 | return region_feature 67 | 68 | 69 | class CroppingGraph(nn.Module): 70 | def __init__(self): 71 | super(CroppingGraph, self).__init__() 72 | dim_in = 512 73 | dim_out = 256 74 | self.Wm = nn.Linear(dim_in, dim_out, bias=False) 75 | self.Wn = nn.Linear(dim_in, dim_out, bias=False) 76 | self.Wr = nn.Linear(dim_in, dim_out, bias=False) 77 | self.feature_trans = nn.Linear(dim_in, dim_out, bias=False) 78 | self.feature_rg = nn.Linear(dim_out, dim_out) 79 | self.feature_lg = nn.Linear(dim_out, dim_out) 80 | self.prediction = nn.Linear(dim_out, 1) 81 | 82 | def forward(self, x): 83 | if x.dim() > 2: 84 | x = x.squeeze() 85 | assert x.dim() == 2, x.dim() 86 | xm = self.Wm(x) 87 | xn = self.Wn(x) 88 | # n,n,d 89 | diff = xm[:,None,:] - xn[None,:,:] 90 | diff = torch.pow(diff, 2) 91 | # n,n 92 | dist = torch.sqrt(torch.sum(diff, dim=-1)) / 2 93 | exps = torch.exp(-dist) 94 | eye_t = torch.eye(dist.shape[0]).to(dist.device) 95 | one_t = torch.ones_like(dist) 96 | exps = exps / (x.shape[0] / 64.) 
97 | adj = exps * (one_t - eye_t) + eye_t 98 | 99 | # n,d 100 | xr = self.Wr(x) 101 | xr = torch.mm(adj, xr) 102 | xl = self.feature_trans(x) 103 | # fuse relation feature and local feature 104 | weight = torch.sigmoid(self.feature_rg(xr) + self.feature_lg(xl)) 105 | feat = (1 - weight) * xr + weight * xl 106 | score = self.prediction(feat) 107 | return adj,score 108 | 109 | def xavier(param): 110 | torch.nn.init.xavier_uniform_(param) 111 | 112 | def weights_init(m): 113 | if isinstance(m, nn.Conv2d): 114 | xavier(m.weight.data) 115 | m.bias.data.zero_() 116 | 117 | def cropping_regression_loss(pre_score, gt_score, score_mean): 118 | if pre_score.dim() > 1: 119 | pre_score = pre_score.reshape(-1) 120 | if gt_score.dim() > 1: 121 | gt_score = gt_score.reshape(-1) 122 | assert pre_score.shape == gt_score.shape, '{} vs. {}'.format(pre_score.shape, gt_score.shape) 123 | l1_loss = F.smooth_l1_loss(pre_score, gt_score, reduction='none') 124 | weight = torch.exp((gt_score - score_mean).clip(min=0,max=100)) 125 | reg_loss= torch.mean(weight * l1_loss) 126 | # reg_loss = F.smooth_l1_loss(pre_score, gt_score, reduction='mean') 127 | return reg_loss 128 | 129 | def cropping_rank_loss(pre_score, gt_score): 130 | ''' 131 | :param pre_score: 132 | :param gt_score: 133 | :return: 134 | ''' 135 | if pre_score.dim() > 1: 136 | pre_score = pre_score.reshape(-1) 137 | if gt_score.dim() > 1: 138 | gt_score = gt_score.reshape(-1) 139 | assert pre_score.shape == gt_score.shape, '{} vs. 
{}'.format(pre_score.shape, gt_score.shape) 140 | N = pre_score.shape[0] 141 | pair_num = N * (N-1) / 2 142 | pre_diff = pre_score[:,None] - pre_score[None,:] 143 | gt_diff = gt_score[:,None] - gt_score[None,:] 144 | indicat = -1 * torch.sign(gt_diff) * (pre_diff - gt_diff) 145 | diff = torch.maximum(indicat, torch.zeros_like(indicat)) 146 | rank_loss= torch.sum(diff) / pair_num 147 | return rank_loss 148 | 149 | def score_feature_correlation(gt_score, feat_adj): 150 | ''' 151 | :param gt_score: n 152 | :param feat_adj: n,n 153 | :return: 154 | ''' 155 | if gt_score.dim() > 1: 156 | gt_score = gt_score.reshape(-1) 157 | 158 | score_diff = torch.pow(gt_score[:,None] - gt_score[None,:],2) 159 | # n,n 160 | score_adj = torch.exp(-score_diff / 2) 161 | score_adj = score_adj - score_adj.mean() 162 | feat_adj = feat_adj - feat_adj.mean() 163 | corr_numer = torch.sum(score_adj * feat_adj) 164 | corr_demon = torch.pow(score_adj, 2).sum() * torch.pow(feat_adj, 2).sum() 165 | corr_demon = torch.sqrt(corr_demon + 1e-12) 166 | corr = corr_numer / corr_demon 167 | return corr 168 | 169 | if __name__ == '__main__': 170 | net = RegionFeatureExtractor(loadweight=False) 171 | net = net.eval().cuda() 172 | roi = torch.randint(0, 224, (1,64,4)).float().cuda() 173 | img = torch.randn((1, 3, 256, 256)).cuda() 174 | print(roi.shape, img.shape) 175 | out = net(img, roi) 176 | print(out.shape, out) 177 | # print(out.shape) 178 | # gnn = CroppingGraph().cuda() 179 | # adj,score = gnn(out) 180 | # print(adj.shape,adj) 181 | # print(score.shape, score) 182 | 183 | # gt_score = torch.tensor([1.,2.]).cuda() 184 | # pr_score = torch.randn(2,1).cuda() 185 | # print('rank loss', cropping_rank_loss(pr_score, gt_score)) 186 | # print('reg loss', cropping_regression_loss(pr_score, gt_score, 3)) 187 | # print('corr', score_feature_correlation(pr_score, adj)) 188 | 189 | -------------------------------------------------------------------------------- /cropping_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image, ImageOps 4 | from torch.utils.data import DataLoader, Dataset 5 | import torchvision.transforms as transforms 6 | import random 7 | from config import cfg 8 | 9 | MOS_MEAN = 2.95 10 | MOS_STD = 0.8 11 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406] 12 | IMAGE_NET_STD = [0.229, 0.224, 0.225] 13 | 14 | def rescale_crops(boxes, ratio_w, ratio_h): 15 | boxes = np.array(boxes).reshape(-1, 4) 16 | boxes[:, 0] = np.floor(boxes[:, 0] * ratio_w) 17 | boxes[:, 1] = np.floor(boxes[:, 1] * ratio_h) 18 | boxes[:, 2] = np.ceil(boxes[:, 2] * ratio_w) 19 | boxes[:, 3] = np.ceil(boxes[:, 3] * ratio_h) 20 | return boxes.astype(np.float32) 21 | 22 | def is_number(s): 23 | if not isinstance(s, str): 24 | return False 25 | if s.isdigit(): 26 | return True 27 | else: 28 | try: 29 | float(s) 30 | return True 31 | except: 32 | return False 33 | 34 | class FLMSDataset(Dataset): 35 | def __init__(self): 36 | self.keep_aspect = cfg.keep_aspect_ratio 37 | self.data_dir = cfg.FLMS_folder 38 | assert os.path.exists(self.data_dir), self.data_dir 39 | self.image_dir = os.path.join(self.data_dir, 'image') 40 | assert os.path.exists(self.image_dir), self.image_dir 41 | self.annos = self.parse_annotations() 42 | self.image_list = list(self.annos.keys()) 43 | self.image_transformer = transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)]) 46 | 47 | def parse_annotations(self): 48 | image_crops_file = os.path.join(self.data_dir, '500_image_dataset.mat') 49 | assert os.path.exists(image_crops_file), image_crops_file 50 | import scipy.io as scio 51 | image_crops = dict() 52 | anno = scio.loadmat(image_crops_file) 53 | for i in range(anno['img_gt'].shape[0]): 54 | image_name = anno['img_gt'][i, 0][0][0] 55 | gt_crops = anno['img_gt'][i, 0][1] 56 | gt_crops = gt_crops[:, [1, 0, 3, 2]] 57 | keep_index = 
np.where((gt_crops < 0).sum(1) == 0) 58 | gt_crops = gt_crops[keep_index].tolist() 59 | image_crops[image_name] = gt_crops 60 | print('{} images'.format(len(image_crops))) 61 | return image_crops 62 | 63 | def __len__(self): 64 | return len(self.image_list) 65 | 66 | def __getitem__(self, index): 67 | image_name = self.image_list[index] 68 | image_file = os.path.join(self.image_dir, image_name) 69 | image = Image.open(image_file).convert('RGB') 70 | im_width, im_height = image.size 71 | if self.keep_aspect: 72 | scale = float(cfg.image_size[0]) / min(im_height, im_width) 73 | h = round(im_height * scale / 32.0) * 32 74 | w = round(im_width * scale / 32.0) * 32 75 | else: 76 | h = cfg.image_size[1] 77 | w = cfg.image_size[0] 78 | resized_image = image.resize((w, h), Image.ANTIALIAS) 79 | im = self.image_transformer(resized_image) 80 | crop = self.annos[image_name] 81 | crop = np.array(crop).reshape(-1, 4).astype(np.float32) 82 | return im, crop, im_width, im_height, image_file 83 | 84 | class GAICDataset(Dataset): 85 | def __init__(self, split): 86 | self.split = split 87 | assert self.split in ['train', 'test'], self.split 88 | self.keep_aspect = cfg.keep_aspect_ratio 89 | self.data_dir = cfg.GAIC_folder 90 | assert os.path.exists(self.data_dir), self.data_dir 91 | self.image_dir = os.path.join(self.data_dir, 'images', split) 92 | assert os.path.exists(self.image_dir), self.image_dir 93 | self.image_list = [file for file in os.listdir(self.image_dir) if file.endswith('.jpg')] 94 | # print('GAICD {} set contains {} images'.format(split, len(self.image_list))) 95 | self.anno_dir = os.path.join(self.data_dir, 'annotations') 96 | assert os.path.exists(self.anno_dir), self.anno_dir 97 | self.annos = self.parse_annotations() 98 | 99 | self.image_size = cfg.image_size 100 | self.augmentation = (cfg.data_augmentation and self.split == 'train') 101 | self.PhotometricDistort = transforms.ColorJitter( 102 | brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05) 103 | 
self.image_transformer = transforms.Compose([ 104 | transforms.ToTensor(), 105 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)]) 106 | 107 | def parse_annotations(self): 108 | image_annos = dict() 109 | for image_name in self.image_list: 110 | anno_file = os.path.join(self.anno_dir, image_name.replace('.jpg', '.txt')) 111 | assert os.path.exists(anno_file), anno_file 112 | with open(anno_file, 'r') as f: 113 | crops,scores = [],[] 114 | for line in f.readlines(): 115 | line = line.strip().split(' ') 116 | values = [s for s in line if is_number(s)] 117 | y1,x1,y2,x2 = [int(s) for s in values[0:4]] 118 | s = float(values[-1]) 119 | if s > -2: 120 | crops.append([x1,y1,x2,y2]) 121 | scores.append(s) 122 | if len(crops) == 0: 123 | print(image_name, anno_file) 124 | else: 125 | # rank all crops 126 | rank = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) 127 | scores = [scores[i] for i in rank] 128 | crops = [crops[i] for i in rank] 129 | image_annos[image_name] = {'crops':crops, 'scores':scores} 130 | return image_annos 131 | 132 | def __len__(self): 133 | return len(self.image_list) 134 | 135 | def __getitem__(self, index): 136 | image_name = self.image_list[index] 137 | image_file = os.path.join(self.image_dir, image_name) 138 | image = Image.open(image_file).convert('RGB') 139 | im_width, im_height = image.size 140 | if self.keep_aspect: 141 | scale = float(cfg.image_size[0]) / min(im_height, im_width) 142 | h = round(im_height * scale / 32.0) * 32 143 | w = round(im_width * scale / 32.0) * 32 144 | else: 145 | h = cfg.image_size[1] 146 | w = cfg.image_size[0] 147 | resized_image = image.resize((w, h), Image.ANTIALIAS) 148 | crop = self.annos[image_name]['crops'] 149 | rs_width, rs_height = resized_image.size 150 | ratio_w = float(rs_width) / im_width 151 | ratio_h = float(rs_height) / im_height 152 | crop = rescale_crops(crop, ratio_w, ratio_h) 153 | score = 
np.array(self.annos[image_name]['scores']).reshape((-1)).astype(np.float32) 154 | if self.augmentation: 155 | if random.uniform(0,1) > 0.5: 156 | resized_image = ImageOps.mirror(resized_image) 157 | temp_x1 = crop[:, 0].copy() 158 | crop[:, 0] = rs_width - crop[:, 2] 159 | crop[:, 2] = rs_width - temp_x1 160 | resized_image = self.PhotometricDistort(resized_image) 161 | im = self.image_transformer(resized_image) 162 | return im, crop, score, im_width, im_height, image_file 163 | 164 | if __name__ == '__main__': 165 | # FLMS_testset = FLMSDataset() 166 | # print('FLMS testset has {} images'.format(len(FLMS_testset))) 167 | # dataloader = DataLoader(FLMS_testset, batch_size=1, num_workers=4) 168 | # for batch_idx, data in enumerate(dataloader): 169 | # im, crop, w, h, file = data 170 | # print(im.shape, crop.shape, w.shape, h.shape) 171 | # print(crop[0,:,2].max(), crop[0,:,3].max(), w[0], h[0]) 172 | # break 173 | 174 | GAICD_testset = GAICDataset(split='train') 175 | print('GAICD training set has {} images'.format(len(GAICD_testset))) 176 | dataloader = DataLoader(GAICD_testset, batch_size=1, num_workers=0) 177 | for batch_idx, data in enumerate(dataloader): 178 | im, crops, scores, w, h, file = data 179 | print(im.shape, crops.shape, scores.shape, w.shape, h.shape) -------------------------------------------------------------------------------- /make_all.sh: -------------------------------------------------------------------------------- 1 | cd ./roi_align 2 | bash make.sh 3 | 4 | cd ../rod_align 5 | bash make.sh 6 | 7 | cd .. 
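With `keep_aspect_ratio` enabled, `GAICDataset.__getitem__` and `FLMSDataset.__getitem__` in `cropping_dataset.py` resize images the same way. The short side is scaled to `image_size[0]`, and both sides are then rounded to multiples of 32. For GAICD, the crop boxes are then mapped with `rescale_crops`, which floors the top-left corner and ceils the bottom-right corner. Under augmentation, they are also mirrored horizontally. The pure-Python sketch below restates that arithmetic for one box. It skips the PIL and NumPy parts, and the function names are hypothetical.

```python
import math


def resized_shape(im_width, im_height, image_size=(256, 256), keep_aspect=True):
    """(w, h) the loaders resize to (hypothetical restatement of __getitem__)."""
    if keep_aspect:
        # scale the short side to image_size[0], then round both sides to multiples of 32
        scale = float(image_size[0]) / min(im_height, im_width)
        h = round(im_height * scale / 32.0) * 32
        w = round(im_width * scale / 32.0) * 32
    else:
        w, h = image_size[0], image_size[1]
    return w, h


def rescale_box(box, ratio_w, ratio_h):
    """Map one [x1, y1, x2, y2] box into the resized image, like rescale_crops."""
    x1, y1, x2, y2 = box
    return [math.floor(x1 * ratio_w), math.floor(y1 * ratio_h),
            math.ceil(x2 * ratio_w), math.ceil(y2 * ratio_h)]


def mirror_box(box, width):
    """Horizontal flip of a box, matching the ImageOps.mirror branch."""
    x1, y1, x2, y2 = box
    return [width - x2, y1, width - x1, y2]


print(resized_shape(640, 480))                     # (352, 256)
print(rescale_box([10, 20, 30, 40], 0.5, 0.5))     # [5, 10, 15, 20]
print(mirror_box([10, 0, 30, 5], 100))             # [70, 0, 90, 5]
```

Flooring the start and ceiling the end makes a rescaled box slightly larger, never smaller. Rounding to multiples of 32 keeps the VGG-16 feature maps at whole-number sizes, since the backbone downsamples by 2 at each pooling stage.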
8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.1 2 | opencv_contrib_python==4.4.0.46 3 | Pillow==8.4.0 4 | scipy==1.5.2 5 | setuptools==52.0.0.post20210125 6 | tensorboardX==2.4 7 | torch==1.9.1 8 | torchvision==0.10.1+cu111 9 | tqdm==4.51.0 10 | -------------------------------------------------------------------------------- /rod_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/CGS-Pytorch/96a7db46ef7e671619a0dcc4746a36dbb4e897c6/rod_align/__init__.py -------------------------------------------------------------------------------- /rod_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/CGS-Pytorch/96a7db46ef7e671619a0dcc4746a36dbb4e897c6/rod_align/functions/__init__.py -------------------------------------------------------------------------------- /rod_align/functions/rod_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import rod_align_api 4 | 5 | class RoDAlignFunction(Function): 6 | @staticmethod 7 | def forward(ctx, features, rois, aligned_height, aligned_width, spatial_scale): 8 | batch_size, num_channels, data_height, data_width = features.size() 9 | ctx.save_for_backward(rois, 10 | torch.IntTensor([int(batch_size), 11 | int(num_channels), 12 | int(data_height), 13 | int(data_width), 14 | int(aligned_width), 15 | int(aligned_height)]), 16 | torch.FloatTensor([float(spatial_scale)])) 17 | 18 | num_rois = rois.size(0) 19 | 20 | output = features.new(num_rois, 21 | num_channels, 22 | int(aligned_height), 23 | int(aligned_width)).zero_() 24 | 25 | rod_align_api.forward(int(aligned_height), 26 |
int(aligned_width), 27 | float(spatial_scale), 28 | features, 29 | rois, output) 30 | 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | rois, core_size, scale = ctx.saved_tensors 36 | spatial_scale = scale[0] 37 | 38 | batch_size, num_channels, data_height, data_width, aligned_width, aligned_height = core_size 39 | 40 | grad_input = rois.new(batch_size, 41 | num_channels, 42 | data_height, 43 | data_width).zero_() 44 | 45 | rod_align_api.backward(int(aligned_height), 46 | int(aligned_width), 47 | float(spatial_scale), 48 | grad_output, 49 | rois, 50 | grad_input) 51 | 52 | return grad_input, None, None, None, None 53 | -------------------------------------------------------------------------------- /rod_align/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling rod_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_86 targets Ampere cards such as GeForce RTX 3090, RTX 3080 and RTX A6000; 7 | # use -arch=sm_75 for Turing cards (GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000/6000/5000, Tesla T4) 8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_86 10 | 11 | cd ../ 12 | # Export CUDA_HOME, then build and install the library. 13 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install 14 | 15 | -------------------------------------------------------------------------------- /rod_align/make_python2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling rod_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below.
6 | # -arch=sm_75 is compatible with the following NV GPU cards, 7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070 Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000 Tesla T4 8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75 10 | 11 | cd ../ 12 | # Export CUDA_HOME. And build and install the library. 13 | export CUDA_HOME=/usr/local/cuda-10.0 && python setup.py install 14 | 15 | -------------------------------------------------------------------------------- /rod_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/CGS-Pytorch/96a7db46ef7e671619a0dcc4746a36dbb4e897c6/rod_align/modules/__init__.py -------------------------------------------------------------------------------- /rod_align/modules/rod_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.nn.functional import avg_pool2d, max_pool2d 3 | from ..functions.rod_align import RoDAlignFunction 4 | 5 | 6 | class RoDAlign(Module): 7 | def __init__(self, aligned_height, aligned_width, spatial_scale): 8 | super(RoDAlign, self).__init__() 9 | 10 | self.aligned_width = int(aligned_width) 11 | self.aligned_height = int(aligned_height) 12 | self.spatial_scale = float(spatial_scale) 13 | 14 | def forward(self, features, rois): 15 | return RoDAlignFunction.apply(features, 16 | rois, 17 | self.aligned_height, 18 | self.aligned_width, 19 | self.spatial_scale) 20 | 21 | class RoDAlignAvg(Module): 22 | def __init__(self, aligned_height, aligned_width, spatial_scale): 23 | super(RoDAlignAvg, self).__init__() 24 | 25 | self.aligned_width = int(aligned_width) 26 | self.aligned_height = int(aligned_height) 27 | self.spatial_scale = float(spatial_scale) 28 | 29 | def forward(self, features, rois): 
30 | x = RoDAlignFunction.apply(features, 31 | rois, 32 | self.aligned_height+1, 33 | self.aligned_width+1, 34 | self.spatial_scale) 35 | return avg_pool2d(x, kernel_size=2, stride=1) 36 | 37 | class RoDAlignMax(Module): 38 | def __init__(self, aligned_height, aligned_width, spatial_scale): 39 | super(RoDAlignMax, self).__init__() 40 | 41 | self.aligned_width = int(aligned_width) 42 | self.aligned_height = int(aligned_height) 43 | self.spatial_scale = float(spatial_scale) 44 | 45 | def forward(self, features, rois): 46 | x = RoDAlignFunction.apply(features, 47 | rois, 48 | self.aligned_height+1, 49 | self.aligned_width+1, 50 | self.spatial_scale) 51 | return max_pool2d(x, kernel_size=2, stride=1) 52 | -------------------------------------------------------------------------------- /rod_align/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from pkg_resources import parse_version 5 | 6 | min_version = parse_version('1.0.0') 7 | current_version = parse_version(torch.__version__) 8 | 9 | 10 | if current_version < min_version: #PyTorch before 1.0 11 | from torch.utils.ffi import create_extension 12 | 13 | sources = ['src/rod_align.c'] 14 | headers = ['src/rod_align.h'] 15 | extra_objects = [] 16 | 17 | defines = [] 18 | with_cuda = False 19 | 20 | this_file = os.path.dirname(os.path.realpath(__file__)) 21 | print(this_file) 22 | 23 | if torch.cuda.is_available(): 24 | print('Including CUDA code.') 25 | sources += ['src/rod_align_cuda.c'] 26 | headers += ['src/rod_align_cuda.h'] 27 | defines += [('WITH_CUDA', None)] 28 | with_cuda = True 29 | 30 | extra_objects = ['src/rod_align_kernel.cu.o'] 31 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 32 | 33 | ffi = create_extension( 34 | '_ext.rod_align', 35 | headers=headers, 36 | sources=sources, 37 | define_macros=defines, 38 | relative_to=__file__, 39 | with_cuda=with_cuda, 40
| extra_objects=extra_objects 41 | ) 42 | 43 | if __name__ == '__main__': 44 | ffi.build() 45 | else: # PyTorch 1.0 or later 46 | from setuptools import setup 47 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 48 | 49 | print('Including CUDA code.') 50 | 51 | current_dir = os.path.dirname(os.path.realpath(__file__)) 52 | 53 | setup( 54 | name='rod_align_api', 55 | ext_modules=[ 56 | CUDAExtension( 57 | name='rod_align_api', 58 | sources=['src/rod_align_cuda.cpp', 'src/rod_align_kernel.cu'], 59 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True) 60 | ) 61 | ], 62 | cmdclass={ 63 | 'build_ext': BuildExtension 64 | }) 65 | 66 | -------------------------------------------------------------------------------- /rod_align/src/rod_align.cpp: -------------------------------------------------------------------------------- 1 | #include "rod_align.h" 2 | 3 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 4 | const int height, const int width, const int channels, 5 | const int aligned_height, const int aligned_width, const float * bottom_rois, 6 | float* top_data); 7 | 8 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 9 | const int height, const int width, const int channels, 10 | const int aligned_height, const int aligned_width, const float * bottom_rois, 11 | float* top_data); 12 | 13 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale, 14 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 15 | { 16 | //Grab the input tensor 17 | //float * data_flat = THFloatTensor_data(features); 18 | //float * rois_flat = THFloatTensor_data(rois); 19 | auto data_flat = features.data<float>(); 20 | auto rois_flat = rois.data<float>(); 21 | 22 | //float * output_flat = THFloatTensor_data(output); 23 | auto output_flat = output.data<float>(); 24 | 25 | // Number of ROIs 26 | //int num_rois = 
THFloatTensor_size(rois,
0); 27 | //int size_rois = THFloatTensor_size(rois, 1); 28 | auto rois_sz = rois.sizes(); 29 | int num_rois = rois_sz[0]; 30 | int size_rois = rois_sz[1]; 31 | 32 | if (size_rois != 5) 33 | { 34 | return 0; 35 | } 36 | 37 | // data height 38 | //int data_height = THFloatTensor_size(features, 2); 39 | // data width 40 | //int data_width = THFloatTensor_size(features, 3); 41 | // Number of channels 42 | //int num_channels = THFloatTensor_size(features, 1); 43 | auto feat_sz = features.sizes(); 44 | int data_height = feat_sz[2]; 45 | int data_width = feat_sz[3]; 46 | int num_channels = feat_sz[1]; 47 | 48 | // do ROIAlignForward 49 | RODAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels, 50 | aligned_height, aligned_width, rois_flat, output_flat); 51 | 52 | return 1; 53 | } 54 | 55 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale, 56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 57 | { 58 | //Grab the input tensor 59 | //float * top_grad_flat = THFloatTensor_data(top_grad); 60 | //float * rois_flat = THFloatTensor_data(rois); 61 | 62 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad); 63 | 64 | auto top_grad_flat = top_grad.data_ptr<float>(); 65 | auto rois_flat = rois.data_ptr<float>(); 66 | auto bottom_grad_flat = bottom_grad.data_ptr<float>(); 67 | 68 | 69 | // Number of ROIs 70 | //int num_rois = THFloatTensor_size(rois, 0); 71 | //int size_rois = THFloatTensor_size(rois, 1); 72 | 73 | auto rois_sz = rois.sizes(); 74 | int num_rois = rois_sz[0]; 75 | int size_rois = rois_sz[1]; 76 | 77 | if (size_rois != 5) 78 | { 79 | return 0; 80 | } 81 | 82 | // batch size 83 | // int batch_size = THFloatTensor_size(bottom_grad, 0); 84 | // data height 85 | //int data_height = THFloatTensor_size(bottom_grad, 2); 86 | // data width 87 | //int data_width = THFloatTensor_size(bottom_grad, 3); 88 | // Number of channels 89 | //int num_channels = THFloatTensor_size(bottom_grad, 1); 90 | auto grad_sz =
bottom_grad.sizes(); 91 | int data_height = grad_sz[2]; 92 | int data_width = grad_sz[3]; 93 | int num_channels = grad_sz[1]; 94 | 95 | // do ROIAlignBackward 96 | RODAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height, 97 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat); 98 | 99 | return 1; 100 | } 101 | 102 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 103 | const int height, const int width, const int channels, 104 | const int aligned_height, const int aligned_width, const float * bottom_rois, 105 | float* top_data) 106 | { 107 | const int output_size = num_rois * aligned_height * aligned_width * channels; 108 | 109 | int idx = 0; 110 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 111 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 112 | for (idx = 0; idx < output_size; ++idx) 113 | { 114 | // (n, c, ph, pw) is an element in the aligned output 115 | int pw = idx % aligned_width; 116 | int ph = (idx / aligned_width) % aligned_height; 117 | int c = (idx / aligned_width / aligned_height) % channels; 118 | int n = idx / aligned_width / aligned_height / channels; 119 | 120 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 121 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 122 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 123 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 124 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 125 | 126 | 127 | float h = (float)(ph) * bin_size_h; 128 | float w = (float)(pw) * bin_size_w; 129 | 130 | int hstart = fminf(floor(h), height - 2); 131 | int wstart = fminf(floor(w), width - 2); 132 | 133 | int img_start = roi_batch_ind * channels * height * width; 134 | 135 | // bilinear interpolation 136 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){ 137 | top_data[idx] = 0.; 138 | } else { 139 | float h_ratio 
= h - (float)(hstart); 140 | float w_ratio = w - (float)(wstart); 141 | int upleft = img_start + (c * height + hstart) * width + wstart; 142 | int upright = upleft + 1; 143 | int downleft = upleft + width; 144 | int downright = downleft + 1; 145 | 146 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 147 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 148 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 149 | + bottom_data[downright] * h_ratio * w_ratio; 150 | } 151 | } 152 | } 153 | 154 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 155 | const int height, const int width, const int channels, 156 | const int aligned_height, const int aligned_width, const float * bottom_rois, 157 | float* bottom_diff) 158 | { 159 | const int output_size = num_rois * aligned_height * aligned_width * channels; 160 | 161 | int idx = 0; 162 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 163 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 164 | for (idx = 0; idx < output_size; ++idx) 165 | { 166 | // (n, c, ph, pw) is an element in the aligned output 167 | int pw = idx % aligned_width; 168 | int ph = (idx / aligned_width) % aligned_height; 169 | int c = (idx / aligned_width / aligned_height) % channels; 170 | int n = idx / aligned_width / aligned_height / channels; 171 | 172 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 173 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 174 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 175 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 176 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 177 | 178 | float h = (float)(ph) * bin_size_h; 179 | float w = (float)(pw) * bin_size_w; 180 | 181 | int hstart = fminf(floor(h), height - 2); 182 | int wstart = fminf(floor(w), width - 2); 183 | 184 | int img_start = roi_batch_ind * channels * height * width; 185 | 186 | // bilinear 
interpolation 187 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) { 188 | float h_ratio = h - (float)(hstart); 189 | float w_ratio = w - (float)(wstart); 190 | int upleft = img_start + (c * height + hstart) * width + wstart; 191 | int upright = upleft + 1; 192 | int downleft = upleft + width; 193 | int downright = downleft + 1; 194 | 195 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio); 196 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio; 197 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. - w_ratio); 198 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio; 199 | } 200 | } 201 | } 202 | 203 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 204 | m.def("forward", &rod_align_forward, "rod_align forward"); 205 | m.def("backward", &rod_align_backward, "rod_align backward"); 206 | } 207 | -------------------------------------------------------------------------------- /rod_align/src/rod_align.h: -------------------------------------------------------------------------------- 1 | #ifndef ROD_ALIGN_H 2 | #define ROD_ALIGN_H 3 | 4 | #include 5 | 6 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "rod_align_kernel.h" 4 | #include "rod_align_cuda.h" 5 | 6 | 7 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 9 | { 10 | // Grab the input tensor 11 | //float 
* data_flat = THCudaTensor_data(state, features); 12 | //float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | //float * output_flat = THCudaTensor_data(state, output); 15 | 16 | auto data_flat = features.data_ptr<float>(); 17 | auto rois_flat = rois.data_ptr<float>(); 18 | auto output_flat = output.data_ptr<float>(); 19 | 20 | // Number of ROIs 21 | //int num_rois = THCudaTensor_size(state, rois, 0); 22 | //int size_rois = THCudaTensor_size(state, rois, 1); 23 | 24 | auto rois_sz = rois.sizes(); 25 | int num_rois = rois_sz[0]; 26 | int size_rois = rois_sz[1]; 27 | 28 | if (size_rois != 5) 29 | { 30 | return 0; 31 | } 32 | 33 | // data height 34 | //int data_height = THCudaTensor_size(state, features, 2); 35 | // data width 36 | //int data_width = THCudaTensor_size(state, features, 3); 37 | // Number of channels 38 | //int num_channels = THCudaTensor_size(state, features, 1); 39 | auto feat_sz = features.sizes(); 40 | int data_height = feat_sz[2]; 41 | int data_width = feat_sz[3]; 42 | int num_channels = feat_sz[1]; 43 | 44 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 45 | 46 | RODAlignForwardLaucher( 47 | data_flat, spatial_scale, num_rois, data_height, 48 | data_width, num_channels, aligned_height, 49 | aligned_width, rois_flat, 50 | output_flat, stream); 51 | 52 | return 1; 53 | } 54 | 55 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 57 | { 58 | // Grab the input tensor 59 | //float * top_grad_flat = THCudaTensor_data(state, top_grad); 60 | //float * rois_flat = THCudaTensor_data(state, rois); 61 | 62 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 63 | auto top_grad_flat = top_grad.data_ptr<float>(); 64 | auto rois_flat = rois.data_ptr<float>(); 65 | auto bottom_grad_flat = bottom_grad.data_ptr<float>(); 66 | 67 | // Number of ROIs 68 | //int num_rois = THCudaTensor_size(state, rois, 0); 69 | //int size_rois = THCudaTensor_size(state, rois, 1); 70 | auto rois_sz =
rois.sizes(); 71 | int num_rois = rois_sz[0]; 72 | int size_rois = rois_sz[1]; 73 | if (size_rois != 5) 74 | { 75 | return 0; 76 | } 77 | 78 | // batch size 79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0); 80 | // data height 81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2); 82 | // data width 83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3); 84 | // Number of channels 85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1); 86 | 87 | auto grad_sz = bottom_grad.sizes(); 88 | int batch_size = grad_sz[0]; 89 | int data_height = grad_sz[2]; 90 | int data_width = grad_sz[3]; 91 | int num_channels = grad_sz[1]; 92 | 93 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 94 | RODAlignBackwardLaucher( 95 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 96 | data_width, num_channels, aligned_height, 97 | aligned_width, rois_flat, 98 | bottom_grad_flat, stream); 99 | 100 | return 1; 101 | } 102 | 103 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 104 | m.def("forward", &rod_align_forward_cuda, "rod_align forward"); 105 | m.def("backward", &rod_align_backward_cuda, "rod_align backward"); 106 | } 107 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef ROD_ALIGN_CUDA_H 2 | #define ROD_ALIGN_CUDA_H 3 | 4 | #include 5 | 6 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_kernel.cu: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include "rod_align_kernel.h" 3 | 4 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 5 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 6 | i += blockDim.x * gridDim.x) 7 | 8 | 9 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width, 10 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) { 11 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 12 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 13 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 14 | // (n, c, ph, pw) is an element in the aligned output 15 | // int n = index; 16 | // int pw = n % aligned_width; 17 | // n /= aligned_width; 18 | // int ph = n % aligned_height; 19 | // n /= aligned_height; 20 | // int c = n % channels; 21 | // n /= channels; 22 | 23 | int pw = index % aligned_width; 24 | int ph = (index / aligned_width) % aligned_height; 25 | int c = (index / aligned_width / aligned_height) % channels; 26 | int n = index / aligned_width / aligned_height / channels; 27 | 28 | // bottom_rois += n * 5; 29 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 34 | 35 | 36 | float h = (float)(ph) * bin_size_h; 37 | float w = (float)(pw) * bin_size_w; 38 | 39 | int hstart = fminf(floor(h), height - 2); 40 | int wstart = fminf(floor(w), width - 2); 41 | 42 | int img_start = roi_batch_ind * channels * height * width; 43 | 44 | // bilinear interpolation 45 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){ 46 | top_data[index] = 0.; 47 | } else { 48 | float 
h_ratio = h - (float)(hstart); 49 | float w_ratio = w - (float)(wstart); 50 | int upleft = img_start + (c * height + hstart) * width + wstart; 51 | int upright = upleft + 1; 52 | int downleft = upleft + width; 53 | int downright = downleft + 1; 54 | 55 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 56 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 57 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 58 | + bottom_data[downright] * h_ratio * w_ratio; 59 | } 60 | } 61 | } 62 | 63 | 64 | int RODAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width, 65 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) { 66 | const int kThreadsPerBlock = 1024; 67 | const int output_size = num_rois * aligned_height * aligned_width * channels; 68 | cudaError_t err; 69 | 70 | 71 | RODAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 72 | output_size, bottom_data, spatial_scale, height, width, channels, 73 | aligned_height, aligned_width, bottom_rois, top_data); 74 | 75 | err = cudaGetLastError(); 76 | if(cudaSuccess != err) { 77 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 78 | exit( -1 ); 79 | } 80 | 81 | return 1; 82 | } 83 | 84 | 85 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width, 86 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) { 87 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 88 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 89 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 90 | 91 | // (n, c, ph, pw) is an element in the aligned output 92 | int pw = index % aligned_width; 93 | int ph = (index / 
aligned_width) % aligned_height; 94 | int c = (index / aligned_width / aligned_height) % channels; 95 | int n = index / aligned_width / aligned_height / channels; 96 | 97 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 98 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 99 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 100 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 101 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 102 | 103 | 104 | float h = (float)(ph) * bin_size_h; 105 | float w = (float)(pw) * bin_size_w; 106 | 107 | int hstart = fminf(floor(h), height - 2); 108 | int wstart = fminf(floor(w), width - 2); 109 | 110 | int img_start = roi_batch_ind * channels * height * width; 111 | 112 | // bilinear interpolation 113 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) { 114 | float h_ratio = h - (float)(hstart); 115 | float w_ratio = w - (float)(wstart); 116 | int upleft = img_start + (c * height + hstart) * width + wstart; 117 | int upright = upleft + 1; 118 | int downleft = upleft + width; 119 | int downright = downleft + 1; 120 | 121 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio)); 122 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. 
- h_ratio) * w_ratio); 123 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio)); 124 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio); 125 | } 126 | } 127 | } 128 | 129 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width, 130 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) { 131 | const int kThreadsPerBlock = 1024; 132 | const int output_size = num_rois * aligned_height * aligned_width * channels; 133 | cudaError_t err; 134 | 135 | RODAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 136 | output_size, top_diff, spatial_scale, height, width, channels, 137 | aligned_height, aligned_width, bottom_diff, bottom_rois); 138 | 139 | err = cudaGetLastError(); 140 | if(cudaSuccess != err) { 141 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 142 | exit( -1 ); 143 | } 144 | 145 | return 1; 146 | } 147 | -------------------------------------------------------------------------------- /rod_align/src/rod_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROD_ALIGN_KERNEL 2 | #define _ROD_ALIGN_KERNEL 3 | 4 | #include 5 | 6 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, 7 | const float spatial_scale, const int height, const int width, 8 | const int channels, const int aligned_height, const int aligned_width, 9 | const float* bottom_rois, float* top_data); 10 | 11 | int RODAlignForwardLaucher( 12 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 13 | const int width, const int channels, const int aligned_height, 14 | const int aligned_width, const float* bottom_rois, 15 | float* top_data, cudaStream_t stream); 16 | 17 
| __global__ void RODAlignBackward(const int nthreads, const float* top_diff, 18 | const float spatial_scale, const int height, const int width, 19 | const int channels, const int aligned_height, const int aligned_width, 20 | float* bottom_diff, const float* bottom_rois); 21 | 22 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 23 | const int height, const int width, const int channels, const int aligned_height, 24 | const int aligned_width, const float* bottom_rois, 25 | float* bottom_diff, cudaStream_t stream); 26 | 27 | #endif 28 | 29 | -------------------------------------------------------------------------------- /roi_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/CGS-Pytorch/96a7db46ef7e671619a0dcc4746a36dbb4e897c6/roi_align/__init__.py -------------------------------------------------------------------------------- /roi_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/CGS-Pytorch/96a7db46ef7e671619a0dcc4746a36dbb4e897c6/roi_align/functions/__init__.py -------------------------------------------------------------------------------- /roi_align/functions/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import roi_align_api 4 | 5 | class RoIAlignFunction(Function): 6 | @staticmethod 7 | def forward(ctx, features, rois, aligned_height, aligned_width, spatial_scale): 8 | batch_size, num_channels, data_height, data_width = features.size() 9 | ctx.save_for_backward(rois, 10 | torch.IntTensor([int(batch_size), 11 | int(num_channels), 12 | int(data_height), 13 | int(data_width), 14 | int(aligned_height), 15 | int(aligned_width)]), 16 | torch.FloatTensor([float(spatial_scale)])) 17 | 18 | 
num_rois = rois.size(0) 19 | 20 | output = features.new(num_rois, 21 | num_channels, 22 | int(aligned_height), 23 | int(aligned_width)).zero_() 24 | 25 | roi_align_api.forward(int(aligned_height), 26 | int(aligned_width), 27 | float(spatial_scale), 28 | features, 29 | rois, 30 | output) 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | rois, core_size, scale = ctx.saved_tensors 36 | 37 | batch_size, num_channels, data_height, data_width, aligned_height, aligned_width = core_size 38 | spatial_scale = scale[0] 39 | 40 | grad_input = rois.new(batch_size, 41 | num_channels, 42 | data_height, 43 | data_width).zero_() 44 | 45 | roi_align_api.backward(int(aligned_height), 46 | int(aligned_width), 47 | float(spatial_scale), 48 | grad_output, 49 | rois, 50 | grad_input) 51 | 52 | return grad_input, None, None, None, None 53 | -------------------------------------------------------------------------------- /roi_align/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling roi_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_86 is compatible with NVIDIA Ampere GA10x cards, e.g. 7 | # GeForce RTX 3090, RTX 3080, RTX 3070, RTX A6000, A40 (use -arch=sm_75 for Turing cards such as RTX 2080 Ti or Tesla T4) 8 | # See more https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_86 10 | 11 | cd ../ 12 | # Export CUDA_HOME. Build and install the library. 13 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install 14 | -------------------------------------------------------------------------------- /roi_align/make_python2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling roi_align kernels by nvcc..."
4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_75 is compatible with the following NV GPU cards, 7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, Tesla T4 8 | # See more https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75 10 | 11 | cd ../ 12 | # Export CUDA_HOME. Build and install the library. 13 | export CUDA_HOME=/usr/local/cuda-10.0 && python setup.py install 14 | -------------------------------------------------------------------------------- /roi_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/CGS-Pytorch/96a7db46ef7e671619a0dcc4746a36dbb4e897c6/roi_align/modules/__init__.py -------------------------------------------------------------------------------- /roi_align/modules/roi_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.nn.functional import avg_pool2d, max_pool2d 3 | from ..functions.roi_align import RoIAlignFunction 4 | 5 | 6 | class RoIAlign(Module): 7 | def __init__(self, aligned_height, aligned_width, spatial_scale): 8 | super(RoIAlign, self).__init__() 9 | 10 | self.aligned_width = int(aligned_width) 11 | self.aligned_height = int(aligned_height) 12 | self.spatial_scale = float(spatial_scale) 13 | 14 | def forward(self, features, rois): 15 | return RoIAlignFunction.apply(features, 16 | rois, 17 | self.aligned_height, 18 | self.aligned_width, 19 | self.spatial_scale) 20 | 21 | class RoIAlignAvg(Module): 22 | def __init__(self, aligned_height, aligned_width, spatial_scale): 23 | super(RoIAlignAvg, self).__init__() 24 | 25 | self.aligned_width = int(aligned_width) 26 | self.aligned_height = int(aligned_height) 27 | self.spatial_scale = float(spatial_scale)
28 | 29 | def forward(self, features, rois): 30 | x = RoIAlignFunction.apply(features, 31 | rois, 32 | self.aligned_height+1, 33 | self.aligned_width+1, 34 | self.spatial_scale) 35 | return avg_pool2d(x, kernel_size=2, stride=1) 36 | 37 | class RoIAlignMax(Module): 38 | def __init__(self, aligned_height, aligned_width, spatial_scale): 39 | super(RoIAlignMax, self).__init__() 40 | 41 | self.aligned_width = int(aligned_width) 42 | self.aligned_height = int(aligned_height) 43 | self.spatial_scale = float(spatial_scale) 44 | 45 | def forward(self, features, rois): 46 | x = RoIAlignFunction.apply(features, 47 | rois, 48 | self.aligned_height+1, 49 | self.aligned_width+1, 50 | self.spatial_scale) 51 | return max_pool2d(x, kernel_size=2, stride=1) 52 | -------------------------------------------------------------------------------- /roi_align/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from pkg_resources import parse_version 5 | 6 | min_version = parse_version('1.0.0') 7 | current_version = parse_version(torch.__version__) 8 | 9 | 10 | if current_version < min_version: #PyTorch before 1.0 11 | from torch.utils.ffi import create_extension 12 | 13 | sources = ['src/roi_align.c'] 14 | headers = ['src/roi_align.h'] 15 | extra_objects = [] 16 | #sources = [] 17 | #headers = [] 18 | defines = [] 19 | with_cuda = False 20 | 21 | this_file = os.path.dirname(os.path.realpath(__file__)) 22 | print(this_file) 23 | 24 | if torch.cuda.is_available(): 25 | print('Including CUDA code.') 26 | sources += ['src/roi_align_cuda.c'] 27 | headers += ['src/roi_align_cuda.h'] 28 | defines += [('WITH_CUDA', None)] 29 | with_cuda = True 30 | 31 | extra_objects = ['src/roi_align_kernel.cu.o'] 32 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 33 | 34 | ffi = create_extension( 35 | '_ext.roi_align', 36 | headers=headers, 37 | sources=sources, 38 | 
define_macros=defines, 39 | relative_to=__file__, 40 | with_cuda=with_cuda, 41 | extra_objects=extra_objects 42 | ) 43 | 44 | if __name__ == '__main__': 45 | ffi.build() 46 | else: # PyTorch 1.0 or later 47 | from setuptools import setup 48 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 49 | 50 | print('Including CUDA code.') 51 | current_dir = os.path.dirname(os.path.realpath(__file__)) 52 | #cuda_include = '/usr/local/cuda-10.0/include' 53 | 54 | #GPU version 55 | setup( 56 | name='roi_align_api', 57 | ext_modules=[ 58 | CUDAExtension( 59 | name='roi_align_api', 60 | sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu'], 61 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True) 62 | ) 63 | ], 64 | cmdclass={ 65 | 'build_ext': BuildExtension 66 | }) 67 | -------------------------------------------------------------------------------- /roi_align/src/roi_align.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "roi_align.h" 5 | 6 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 7 | const int height, const int width, const int channels, 8 | const int aligned_height, const int aligned_width, const float * bottom_rois, 9 | float* top_data); 10 | 11 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 12 | const int height, const int width, const int channels, 13 | const int aligned_height, const int aligned_width, const float * bottom_rois, 14 | float* bottom_diff); 15 | 16 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale, 17 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 18 | { 19 | //Grab the input tensor 20 | //float * data_flat = THFloatTensor_data(features); 21 | //float * rois_flat = THFloatTensor_data(rois); 22 | auto data_flat = features.data_ptr<float>(); 23 | auto rois_flat =
rois.data_ptr<float>(); 24 | 25 | //float * output_flat = THFloatTensor_data(output); 26 | auto output_flat = output.data_ptr<float>(); 27 | 28 | // Number of ROIs 29 | //int num_rois = THFloatTensor_size(rois, 0); 30 | //int size_rois = THFloatTensor_size(rois, 1); 31 | 32 | auto rois_sz = rois.sizes(); 33 | int num_rois = rois_sz[0]; 34 | int size_rois = rois_sz[1]; 35 | 36 | if (size_rois != 5) 37 | { 38 | return 0; 39 | } 40 | 41 | // data height 42 | //int data_height = THFloatTensor_size(features, 2); 43 | // data width 44 | //int data_width = THFloatTensor_size(features, 3); 45 | // Number of channels 46 | //int num_channels = THFloatTensor_size(features, 1); 47 | auto feat_sz = features.sizes(); 48 | int data_height = feat_sz[2]; 49 | int data_width = feat_sz[3]; 50 | int num_channels = feat_sz[1]; 51 | 52 | // do ROIAlignForward 53 | ROIAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels, 54 | aligned_height, aligned_width, rois_flat, output_flat); 55 | 56 | return 1; 57 | } 58 | 59 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale, 60 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 61 | { 62 | //Grab the input tensor 63 | //float * top_grad_flat = THFloatTensor_data(top_grad); 64 | //float * rois_flat = THFloatTensor_data(rois); 65 | 66 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad); 67 | auto top_grad_flat = top_grad.data_ptr<float>(); 68 | auto rois_flat = rois.data_ptr<float>(); 69 | auto bottom_grad_flat = bottom_grad.data_ptr<float>(); 70 | 71 | // Number of ROIs 72 | //int num_rois = THFloatTensor_size(rois, 0); 73 | //int size_rois = THFloatTensor_size(rois, 1); 74 | auto rois_sz = rois.sizes(); 75 | int num_rois = rois_sz[0]; 76 | int size_rois = rois_sz[1]; 77 | if (size_rois != 5) 78 | { 79 | return 0; 80 | } 81 | 82 | // batch size 83 | // int batch_size = THFloatTensor_size(bottom_grad, 0); 84 | // data height 85 | //int data_height = THFloatTensor_size(bottom_grad, 2); 86 | // data
width 87 | //int data_width = THFloatTensor_size(bottom_grad, 3); 88 | // Number of channels 89 | //int num_channels = THFloatTensor_size(bottom_grad, 1); 90 | 91 | auto grad_sz = bottom_grad.sizes(); 92 | int data_height = grad_sz[2]; 93 | int data_width = grad_sz[3]; 94 | int num_channels = grad_sz[1]; 95 | 96 | // do ROIAlignBackward 97 | ROIAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height, 98 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat); 99 | 100 | return 1; 101 | } 102 | 103 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 104 | const int height, const int width, const int channels, 105 | const int aligned_height, const int aligned_width, const float * bottom_rois, 106 | float* top_data) 107 | { 108 | const int output_size = num_rois * aligned_height * aligned_width * channels; 109 | 110 | int idx = 0; 111 | for (idx = 0; idx < output_size; ++idx) 112 | { 113 | // (n, c, ph, pw) is an element in the aligned output 114 | int pw = idx % aligned_width; 115 | int ph = (idx / aligned_width) % aligned_height; 116 | int c = (idx / aligned_width / aligned_height) % channels; 117 | int n = idx / aligned_width / aligned_height / channels; 118 | 119 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 120 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 121 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 122 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 123 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 124 | 125 | // Force malformed ROI to be 1x1 126 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 127 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 128 | float bin_size_h = roi_height / (aligned_height - 1.); 129 | float bin_size_w = roi_width / (aligned_width - 1.); 130 | 131 | float h = (float)(ph) * bin_size_h + roi_start_h; 132 | float w = (float)(pw) * bin_size_w + roi_start_w; 
133 | 134 | int hstart = fminf(floor(h), height - 2); 135 | int wstart = fminf(floor(w), width - 2); 136 | 137 | int img_start = roi_batch_ind * channels * height * width; 138 | 139 | // bilinear interpolation 140 | if (h < 0 || h >= height || w < 0 || w >= width) 141 | { 142 | top_data[idx] = 0.; 143 | } 144 | else 145 | { 146 | float h_ratio = h - (float)(hstart); 147 | float w_ratio = w - (float)(wstart); 148 | int upleft = img_start + (c * height + hstart) * width + wstart; 149 | int upright = upleft + 1; 150 | int downleft = upleft + width; 151 | int downright = downleft + 1; 152 | 153 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 154 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 155 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 156 | + bottom_data[downright] * h_ratio * w_ratio; 157 | } 158 | } 159 | } 160 | 161 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 162 | const int height, const int width, const int channels, 163 | const int aligned_height, const int aligned_width, const float * bottom_rois, 164 | float* bottom_diff) 165 | { 166 | const int output_size = num_rois * aligned_height * aligned_width * channels; 167 | 168 | int idx = 0; 169 | for (idx = 0; idx < output_size; ++idx) 170 | { 171 | // (n, c, ph, pw) is an element in the aligned output 172 | int pw = idx % aligned_width; 173 | int ph = (idx / aligned_width) % aligned_height; 174 | int c = (idx / aligned_width / aligned_height) % channels; 175 | int n = idx / aligned_width / aligned_height / channels; 176 | 177 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 178 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 179 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 180 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 181 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 182 | 183 | // Force malformed ROI to be 1x1 184 | float roi_width = fmaxf(roi_end_w - 
roi_start_w + 1., 0.); 185 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 186 | float bin_size_h = roi_height / (aligned_height - 1.); 187 | float bin_size_w = roi_width / (aligned_width - 1.); 188 | 189 | float h = (float)(ph) * bin_size_h + roi_start_h; 190 | float w = (float)(pw) * bin_size_w + roi_start_w; 191 | 192 | int hstart = fminf(floor(h), height - 2); 193 | int wstart = fminf(floor(w), width - 2); 194 | 195 | int img_start = roi_batch_ind * channels * height * width; 196 | 197 | // bilinear interpolation: scatter the gradient only for in-bounds samples (same test as the CUDA kernel) 198 | if (!(h < 0 || h >= height || w < 0 || w >= width)) 199 | { 200 | float h_ratio = h - (float)(hstart); 201 | float w_ratio = w - (float)(wstart); 202 | int upleft = img_start + (c * height + hstart) * width + wstart; 203 | int upright = upleft + 1; 204 | int downleft = upleft + width; 205 | int downright = downleft + 1; 206 | 207 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio); 208 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio; 209 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1.
- w_ratio); 210 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio; 211 | } 212 | } 213 | } 214 | 215 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 216 | m.def("forward", &roi_align_forward, "roi_align forward"); 217 | m.def("backward", &roi_align_backward, "roi_align backward"); 218 | } 219 | -------------------------------------------------------------------------------- /roi_align/src/roi_align.h: -------------------------------------------------------------------------------- 1 | #ifndef ROI_ALIGN_H 2 | #define ROI_ALIGN_H 3 | 4 | #include 5 | 6 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /roi_align/src/roi_align_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "roi_align_kernel.h" 5 | 6 | 7 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 9 | { 10 | // Grab the input tensor 11 | //float * data_flat = THCudaTensor_data(state, features); 12 | //float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | //float * output_flat = THCudaTensor_data(state, output); 15 | 16 | auto data_flat = features.data(); 17 | auto rois_flat = rois.data(); 18 | auto output_flat = output.data(); 19 | 20 | // Number of ROIs 21 | //int num_rois = THCudaTensor_size(state, rois, 0); 22 | //int size_rois = THCudaTensor_size(state, rois, 1); 23 | auto rois_sz = rois.sizes(); 24 | int num_rois = rois_sz[0]; 25 | int size_rois = rois_sz[1]; 26 | if (size_rois != 5) 27 | { 28 | return 0; 29 | } 30 | 31 | 
// data height 32 | //int data_height = THCudaTensor_size(state, features, 2); 33 | // data width 34 | //int data_width = THCudaTensor_size(state, features, 3); 35 | // Number of channels 36 | //int num_channels = THCudaTensor_size(state, features, 1); 37 | auto feat_sz = features.sizes(); 38 | int data_height = feat_sz[2]; 39 | int data_width = feat_sz[3]; 40 | int num_channels = feat_sz[1]; 41 | 42 | 43 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 44 | 45 | ROIAlignForwardLaucher( 46 | data_flat, spatial_scale, num_rois, data_height, 47 | data_width, num_channels, aligned_height, 48 | aligned_width, rois_flat, 49 | output_flat, stream); 50 | 51 | return 1; 52 | } 53 | 54 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 55 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 56 | { 57 | // Grab the input tensor 58 | //float * top_grad_flat = THCudaTensor_data(state, top_grad); 59 | //float * rois_flat = THCudaTensor_data(state, rois); 60 | 61 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 62 | auto top_grad_flat = top_grad.data(); 63 | auto rois_flat = rois.data(); 64 | auto bottom_grad_flat = bottom_grad.data(); 65 | 66 | // Number of ROIs 67 | //int num_rois = THCudaTensor_size(state, rois, 0); 68 | //int size_rois = THCudaTensor_size(state, rois, 1); 69 | auto rois_sz = rois.sizes(); 70 | int num_rois = rois_sz[0]; 71 | int size_rois = rois_sz[1]; 72 | 73 | if (size_rois != 5) 74 | { 75 | return 0; 76 | } 77 | 78 | // batch size 79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0); 80 | // data height 81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2); 82 | // data width 83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3); 84 | // Number of channels 85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1); 86 | auto grad_sz = bottom_grad.sizes(); 87 | int batch_size = grad_sz[0]; 88 | int data_height = grad_sz[2]; 89 
| int data_width = grad_sz[3]; 90 | int num_channels = grad_sz[1]; 91 | 92 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 93 | ROIAlignBackwardLaucher( 94 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 95 | data_width, num_channels, aligned_height, 96 | aligned_width, rois_flat, 97 | bottom_grad_flat, stream); 98 | 99 | return 1; 100 | } 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 103 | m.def("forward", &roi_align_forward_cuda, "roi_align forward"); 104 | m.def("backward", &roi_align_backward_cuda, "roi_align backward"); 105 | } 106 | -------------------------------------------------------------------------------- /roi_align/src/roi_align_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef ROI_ALIGN_CUDA_H 2 | #define ROI_ALIGN_CUDA_H 3 | 4 | #include 5 | 6 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /roi_align/src/roi_align_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "roi_align_kernel.h" 5 | 6 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 7 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 8 | i += blockDim.x * gridDim.x) 9 | 10 | 11 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width, 12 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) { 13 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 14 | // (n, c, ph, pw) is an element in the aligned 
output 15 | // int n = index; 16 | // int pw = n % aligned_width; 17 | // n /= aligned_width; 18 | // int ph = n % aligned_height; 19 | // n /= aligned_height; 20 | // int c = n % channels; 21 | // n /= channels; 22 | 23 | int pw = index % aligned_width; 24 | int ph = (index / aligned_width) % aligned_height; 25 | int c = (index / aligned_width / aligned_height) % channels; 26 | int n = index / aligned_width / aligned_height / channels; 27 | 28 | // bottom_rois += n * 5; 29 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 34 | 35 | // Force malformed ROIs to be 1x1 36 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 37 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 38 | float bin_size_h = roi_height / (aligned_height - 1.); 39 | float bin_size_w = roi_width / (aligned_width - 1.); 40 | 41 | float h = (float)(ph) * bin_size_h + roi_start_h; 42 | float w = (float)(pw) * bin_size_w + roi_start_w; 43 | 44 | int hstart = fminf(floor(h), height - 2); 45 | int wstart = fminf(floor(w), width - 2); 46 | 47 | int img_start = roi_batch_ind * channels * height * width; 48 | 49 | // bilinear interpolation 50 | if (h < 0 || h >= height || w < 0 || w >= width) { 51 | top_data[index] = 0.; 52 | } else { 53 | float h_ratio = h - (float)(hstart); 54 | float w_ratio = w - (float)(wstart); 55 | int upleft = img_start + (c * height + hstart) * width + wstart; 56 | int upright = upleft + 1; 57 | int downleft = upleft + width; 58 | int downright = downleft + 1; 59 | 60 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 61 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 62 | + bottom_data[downleft] * h_ratio * (1. 
- w_ratio) 63 | + bottom_data[downright] * h_ratio * w_ratio; 64 | } 65 | } 66 | } 67 | 68 | 69 | int ROIAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width, 70 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) { 71 | const int kThreadsPerBlock = 1024; 72 | const int output_size = num_rois * aligned_height * aligned_width * channels; 73 | cudaError_t err; 74 | 75 | 76 | ROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 77 | output_size, bottom_data, spatial_scale, height, width, channels, 78 | aligned_height, aligned_width, bottom_rois, top_data); 79 | 80 | err = cudaGetLastError(); 81 | if(cudaSuccess != err) { 82 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 83 | exit( -1 ); 84 | } 85 | 86 | return 1; 87 | } 88 | 89 | 90 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width, 91 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) { 92 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 93 | 94 | // (n, c, ph, pw) is an element in the aligned output 95 | int pw = index % aligned_width; 96 | int ph = (index / aligned_width) % aligned_height; 97 | int c = (index / aligned_width / aligned_height) % channels; 98 | int n = index / aligned_width / aligned_height / channels; 99 | 100 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 101 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 102 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 103 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 104 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 105 | /* int roi_start_w = round(bottom_rois[1] * spatial_scale); */ 106 | /* int roi_start_h = 
round(bottom_rois[2] * spatial_scale); */ 107 | /* int roi_end_w = round(bottom_rois[3] * spatial_scale); */ 108 | /* int roi_end_h = round(bottom_rois[4] * spatial_scale); */ 109 | 110 | // Force malformed ROIs to be 1x1 111 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 112 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 113 | float bin_size_h = roi_height / (aligned_height - 1.); 114 | float bin_size_w = roi_width / (aligned_width - 1.); 115 | 116 | float h = (float)(ph) * bin_size_h + roi_start_h; 117 | float w = (float)(pw) * bin_size_w + roi_start_w; 118 | 119 | int hstart = fminf(floor(h), height - 2); 120 | int wstart = fminf(floor(w), width - 2); 121 | 122 | int img_start = roi_batch_ind * channels * height * width; 123 | 124 | // bilinear interpolation 125 | if (!(h < 0 || h >= height || w < 0 || w >= width)) { 126 | float h_ratio = h - (float)(hstart); 127 | float w_ratio = w - (float)(wstart); 128 | int upleft = img_start + (c * height + hstart) * width + wstart; 129 | int upright = upleft + 1; 130 | int downleft = upleft + width; 131 | int downright = downleft + 1; 132 | 133 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio)); 134 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. 
- h_ratio) * w_ratio); 135 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio)); 136 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio); 137 | } 138 | } 139 | } 140 | 141 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width, 142 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) { 143 | const int kThreadsPerBlock = 1024; 144 | const int output_size = num_rois * aligned_height * aligned_width * channels; 145 | cudaError_t err; 146 | 147 | ROIAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 148 | output_size, top_diff, spatial_scale, height, width, channels, 149 | aligned_height, aligned_width, bottom_diff, bottom_rois); 150 | 151 | err = cudaGetLastError(); 152 | if(cudaSuccess != err) { 153 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 154 | exit( -1 ); 155 | } 156 | 157 | return 1; 158 | } 159 | -------------------------------------------------------------------------------- /roi_align/src/roi_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROI_ALIGN_KERNEL 2 | #define _ROI_ALIGN_KERNEL 3 | 4 | 5 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, 6 | const float spatial_scale, const int height, const int width, 7 | const int channels, const int aligned_height, const int aligned_width, 8 | const float* bottom_rois, float* top_data); 9 | 10 | int ROIAlignForwardLaucher( 11 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 12 | const int width, const int channels, const int aligned_height, 13 | const int aligned_width, const float* bottom_rois, 14 | float* top_data, cudaStream_t stream); 15 | 16 | __global__ 
void ROIAlignBackward(const int nthreads, const float* top_diff, 17 | const float spatial_scale, const int height, const int width, 18 | const int channels, const int aligned_height, const int aligned_width, 19 | float* bottom_diff, const float* bottom_rois); 20 | 21 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 22 | const int height, const int width, const int channels, const int aligned_height, 23 | const int aligned_width, const float* bottom_rois, 24 | float* bottom_diff, cudaStream_t stream); 25 | 26 | #endif 27 | 28 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | import pickle 6 | from scipy.stats import spearmanr 7 | import random 8 | import cv2 9 | import json 10 | from cropping_dataset import FLMSDataset, GAICDataset 11 | from config import cfg 12 | from croppingModel import RegionFeatureExtractor,CroppingGraph 13 | 14 | device = torch.device('cuda:{}'.format(cfg.gpu_id)) 15 | torch.cuda.set_device(cfg.gpu_id) 16 | SEED = 0 17 | random.seed(SEED) 18 | 19 | save_dir = './results' 20 | os.makedirs(save_dir, exist_ok=True) 21 | 22 | def compute_acc(gt_scores, pr_scores): 23 | assert (len(gt_scores) == len(pr_scores)), '{} vs. {}'.format(len(gt_scores), len(pr_scores)) 24 | sample_cnt = 0 25 | acc4_5 = [0 for i in range(4)] 26 | acc4_10 = [0 for i in range(4)] 27 | for i in range(len(gt_scores)): 28 | gts, preds = gt_scores[i], pr_scores[i] 29 | id_gt = sorted(range(len(gts)), key=lambda j : gts[j], reverse=True) 30 | id_pr = sorted(range(len(preds)), key=lambda j : preds[j], reverse=True) 31 | for k in range(4): 32 | temp_acc4_5 = 0. 33 | temp_acc4_10 = 0. 
34 | for j in range(k+1): 35 | if gts[id_pr[j]] >= gts[id_gt[4]]: 36 | temp_acc4_5 += 1.0 37 | if gts[id_pr[j]] >= gts[id_gt[9]]: 38 | temp_acc4_10 += 1.0 39 | acc4_5[k] += (temp_acc4_5 / (k+1.0)) 40 | acc4_10[k] += ((temp_acc4_10) / (k+1.0)) 41 | sample_cnt += 1 42 | acc4_5 = [i / sample_cnt for i in acc4_5] 43 | acc4_10 = [i / sample_cnt for i in acc4_10] 44 | # print('acc4_5', acc4_5) 45 | # print('acc4_10', acc4_10) 46 | avg_acc4_5 = sum(acc4_5) / len(acc4_5) 47 | avg_acc4_10 = sum(acc4_10) / len(acc4_10) 48 | return avg_acc4_5, avg_acc4_10 49 | 50 | def compute_iou_and_disp(gt_crop, pre_crop, im_w, im_h): 51 | ''' 52 | :param gt_crop: [[x1,y1,x2,y2]], one row per ground-truth box; rows with x1 < 0 are ignored 53 | :param pre_crop: [[x1,y1,x2,y2]], the predicted crop; im_w, im_h normalise the displacement 54 | :return: (iou, disp) against the best-matching ground-truth box (highest IoU, ties broken by smallest displacement) 55 | ''' 56 | gt_crop = gt_crop[gt_crop[:,0] >= 0] 57 | zero_t = torch.zeros(gt_crop.shape[0]) 58 | over_x1 = torch.maximum(gt_crop[:,0], pre_crop[:,0]) 59 | over_y1 = torch.maximum(gt_crop[:,1], pre_crop[:,1]) 60 | over_x2 = torch.minimum(gt_crop[:,2], pre_crop[:,2]) 61 | over_y2 = torch.minimum(gt_crop[:,3], pre_crop[:,3]) 62 | over_w = torch.maximum(zero_t, over_x2 - over_x1) 63 | over_h = torch.maximum(zero_t, over_y2 - over_y1) 64 | inter = over_w * over_h 65 | area1 = (gt_crop[:,2] - gt_crop[:,0]) * (gt_crop[:,3] - gt_crop[:,1]) 66 | area2 = (pre_crop[:,2] - pre_crop[:,0]) * (pre_crop[:,3] - pre_crop[:,1]) 67 | union = area1 + area2 - inter 68 | iou = inter / union 69 | disp = (torch.abs(gt_crop[:, 0] - pre_crop[:, 0]) + torch.abs(gt_crop[:, 2] - pre_crop[:, 2])) / im_w + \ 70 | (torch.abs(gt_crop[:, 1] - pre_crop[:, 1]) + torch.abs(gt_crop[:, 3] - pre_crop[:, 3])) / im_h 71 | iou_idx = torch.argmax(iou, dim=-1) 72 | dis_idx = torch.argmin(disp, dim=-1) 73 | index = dis_idx if (iou[iou_idx] == iou[dis_idx]) else iou_idx 74 | return iou[index].item(), disp[index].item() 75 | 76 | def evaluate_on_GAICD(extracor, gnn, save_results=False): 77 | extracor.eval() 78 | gnn.eval() 79 | print('='*5, 'Evaluating on GAICD dataset', '='*5) 80 | srcc_list = [] 81 |
gt_scores = [] 82 | pr_scores = [] 83 | count = 0 84 | test_dataset = GAICDataset(split='test') 85 | test_loader = torch.utils.data.DataLoader( 86 | test_dataset, batch_size=1, 87 | shuffle=False, num_workers=cfg.num_workers, 88 | drop_last=False) 89 | if save_results: 90 | image_results = dict() 91 | result_dir = os.path.join(save_dir, 'GAICD') 92 | os.makedirs(result_dir, exist_ok=True) 93 | 94 | with torch.no_grad(): 95 | for batch_idx, batch_data in enumerate(tqdm(test_loader)): 96 | im = batch_data[0].to(device) 97 | rois = batch_data[1].to(device) 98 | scores = batch_data[2].cpu().numpy().reshape(-1) 99 | width = batch_data[3].item() 100 | height = batch_data[4].item() 101 | image_file = batch_data[5][0] 102 | image_name = os.path.basename(image_file) 103 | count += im.shape[0] 104 | 105 | region_feat = extracor(im, rois) 106 | _,pre_scores = gnn(region_feat) 107 | 108 | pre_scores = pre_scores.cpu().detach().numpy().reshape(-1) 109 | srcc_list.append(spearmanr(scores, pre_scores)[0]) 110 | gt_scores.append(scores) 111 | pr_scores.append(pre_scores) 112 | 113 | if save_results: 114 | pre_index = np.argmax(pre_scores) 115 | cand_crop = rois.squeeze().cpu().detach().numpy().reshape(-1,4) 116 | cand_crop[:, 0::2] *= (float(width) / im.shape[-1]) 117 | cand_crop[:, 1::2] *= (float(height) / im.shape[-2]) 118 | cand_crop = cand_crop.astype(np.int32) 119 | pred_crop = cand_crop[pre_index] # x1,y1,x2,y2 120 | image_results[image_name] = pred_crop.tolist() 121 | # save predicted best crop 122 | src_img = cv2.imread(image_file) 123 | crop_img = src_img[pred_crop[1] : pred_crop[3], pred_crop[0] : pred_crop[2]] 124 | result_file = os.path.join(result_dir, image_name) 125 | cv2.imwrite(result_file, crop_img) 126 | if save_results: 127 | with open(os.path.join(save_dir, 'GAICD.json'), 'w') as f: 128 | json.dump(image_results, f) 129 | 130 | srcc = sum(srcc_list) / len(srcc_list) 131 | acc5, acc10 = compute_acc(gt_scores, pr_scores) 132 | print('Test on GAICD {} images, 
SRCC={:.3f}, acc5={:.3f}, acc10={:.3f}'.format( 133 | count, srcc, acc5, acc10 134 | )) 135 | return srcc, acc5, acc10 136 | 137 | def get_pdefined_anchor(): 138 | # get predefined boxes(x1, y1, x2, y2) 139 | pdefined_anchors = np.array(pickle.load(open(cfg.predefined_pkl, 'rb'), encoding='iso-8859-1')).astype(np.float32) 140 | print('num of pre-defined anchors: ', pdefined_anchors.shape) 141 | return pdefined_anchors 142 | 143 | def evaluate_on_FLMS(extractor, gnn, save_results=False): 144 | print('=' * 5, f'Evaluating on FLMS', '=' * 5) 145 | extractor.eval() 146 | gnn.eval() 147 | pdefined_anchors = get_pdefined_anchor() # n,4, (x1,y1,x2,y2) 148 | 149 | accum_disp = 0 150 | accum_iou = 0 151 | crop_cnt = 0 152 | alpha = 0.75 153 | alpha_cnt = 0 154 | cnt = 0 155 | 156 | if save_results: 157 | image_results = dict() 158 | result_dir = os.path.join(save_dir, 'FLMS') 159 | os.makedirs(result_dir, exist_ok=True) 160 | 161 | with torch.no_grad(): 162 | test_dataset= FLMSDataset() 163 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, 164 | shuffle=False, num_workers=cfg.num_workers, 165 | drop_last=False) 166 | for batch_idx, batch_data in enumerate(tqdm(test_loader)): 167 | im = batch_data[0].to(device) 168 | gt_crop = batch_data[1] # x1,y1,w,h 169 | width = batch_data[2].item() 170 | height = batch_data[3].item() 171 | image_file = batch_data[4][0] 172 | image_name = os.path.basename(image_file) 173 | 174 | rois = np.zeros((len(pdefined_anchors), 4), dtype=np.float32) 175 | rois[:, 0::2] = pdefined_anchors[:, 0::2] * im.shape[-1] 176 | rois[:, 1::2] = pdefined_anchors[:, 1::2] * im.shape[-2] 177 | rois = torch.from_numpy(rois).unsqueeze(0).to(device) # 1,n,4 178 | 179 | region_feat = extractor(im, rois) 180 | adj, scores = gnn(region_feat) 181 | scores = scores.reshape(-1) 182 | scores = scores.cpu().detach().numpy() 183 | idx = np.argmax(scores) 184 | 185 | pred_x1 = int(pdefined_anchors[idx][0] * width) 186 | pred_y1 = 
int(pdefined_anchors[idx][1] * height) 187 | pred_x2 = int(pdefined_anchors[idx][2] * width) 188 | pred_y2 = int(pdefined_anchors[idx][3] * height) 189 | pred_crop = torch.tensor([[pred_x1, pred_y1, pred_x2, pred_y2]]) 190 | gt_crop = gt_crop.reshape(-1, 4) 191 | iou, disp = compute_iou_and_disp(gt_crop, pred_crop, width, height) 192 | if iou >= alpha: 193 | alpha_cnt += 1 194 | accum_iou += iou 195 | accum_disp += disp 196 | cnt += 1 197 | 198 | if save_results: 199 | image_results[image_name] = [pred_x1, pred_y1, pred_x2, pred_y2] 200 | src_img = cv2.imread(image_file) 201 | pred_crop = src_img[pred_y1: pred_y2, pred_x1 : pred_x2] 202 | result_file = os.path.join(result_dir, image_name) 203 | cv2.imwrite(result_file, pred_crop) 204 | if save_results: 205 | with open(os.path.join(save_dir, 'FLMS.json'), 'w') as f: 206 | json.dump(image_results, f) 207 | avg_iou = accum_iou / cnt 208 | avg_disp = accum_disp / (cnt * 4.0) 209 | avg_recall = float(alpha_cnt) / cnt 210 | print('Test on {} images, IoU={:.4f}, Disp={:.4f}, recall={:.4f}(iou>={:.2f})'.format( 211 | cnt, avg_iou, avg_disp, avg_recall, alpha 212 | )) 213 | return avg_iou, avg_disp 214 | 215 | 216 | if __name__ == '__main__': 217 | extractor = RegionFeatureExtractor(loadweight=False) 218 | extractor_weight = './pretrained_model/extractor-best-srcc.pth' 219 | extractor.load_state_dict(torch.load(extractor_weight)) 220 | extractor = extractor.to(device).eval() 221 | 222 | gnn = CroppingGraph() 223 | gnn_weight = './pretrained_model/gnn-best-srcc.pth' 224 | gnn.load_state_dict(torch.load(gnn_weight)) 225 | gnn = gnn.eval().to(device) 226 | evaluate_on_GAICD(extractor, gnn, save_results=True) 227 | evaluate_on_FLMS(extractor, gnn, save_results=True) 228 | 229 | 230 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tensorboardX import 
SummaryWriter 4 | import torch 5 | import time 6 | import datetime 7 | import csv 8 | import shutil 9 | import random 10 | import torch.utils.data as data 11 | import math 12 | 13 | from croppingModel import RegionFeatureExtractor,CroppingGraph 14 | from croppingModel import cropping_rank_loss, cropping_regression_loss, score_feature_correlation 15 | from cropping_dataset import GAICDataset 16 | from config import cfg 17 | from test import evaluate_on_FLMS, evaluate_on_GAICD 18 | 19 | # os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 20 | 21 | device = torch.device('cuda:{}'.format(cfg.gpu_id)) 22 | torch.cuda.set_device(cfg.gpu_id) 23 | MOS_MEAN = 2.95 24 | SEED = 0 25 | torch.manual_seed(SEED) 26 | np.random.seed(SEED) 27 | random.seed(SEED) 28 | 29 | def create_dataloader(): 30 | dataset = GAICDataset(split='train') 31 | if cfg.keep_aspect_ratio: 32 | assert cfg.batch_size == 1, 'batch size must be 1 when keeping image aspect ratio' 33 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, 34 | shuffle=True, num_workers=cfg.num_workers, 35 | drop_last=False) # worker RNGs are seeded by PyTorch from the torch.manual_seed(SEED) state; random.seed(SEED) returns None, so it never acted as a worker_init_fn 36 | print('training set has {} samples, {} batches'.format(len(dataset), len(dataloader))) 37 | return dataloader 38 | 39 | class Trainer: 40 | def __init__(self, feature_extractor, cropping_gnn): 41 | self.extractor = feature_extractor 42 | self.gnn = cropping_gnn 43 | self.epoch = 0 44 | self.iters = 0 45 | self.max_epoch = cfg.max_epoch 46 | self.writer = SummaryWriter(log_dir=cfg.log_dir) 47 | self.optimizer, self.lr_scheduler = self.get_optimizer() 48 | self.train_loader = create_dataloader() 49 | self.eval_results = [] 50 | self.best_results = {'srcc': 0, 'acc5': 0., 'acc10': 0., 51 | 'FLMS_iou':0., 'FLMS_disp':1.} 52 | 53 | def get_optimizer(self): 54 | params = [ 55 | {'params': self.extractor.parameters(), 'lr': cfg.lr}, 56 | {'params': self.gnn.parameters(), 'lr': cfg.lr} 57 | ] 58 | optimizer = torch.optim.Adam( 59 | params,
weight_decay=cfg.weight_decay 60 | ) 61 | # warm_up_with_cosine_lr 62 | warm_up_epochs = 5 63 | warm_up_with_cosine_lr = lambda epoch: epoch / warm_up_epochs if epoch <= warm_up_epochs else 0.5 * ( 64 | math.cos((epoch - warm_up_epochs) / (self.max_epoch - warm_up_epochs) * math.pi) + 1) 65 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr) 66 | return optimizer, lr_scheduler 67 | 68 | def run(self): 69 | print(("======== Begin Training =========")) 70 | self.lr_scheduler.step() 71 | for epoch in range(self.max_epoch): 72 | self.epoch = epoch 73 | self.train() 74 | if (epoch+1) % cfg.eval_freq == 0 or epoch == (self.max_epoch-1): 75 | self.eval() 76 | self.record_eval_results() 77 | self.lr_scheduler.step() 78 | 79 | def train(self): 80 | self.extractor.train() 81 | self.gnn.train() 82 | start = time.time() 83 | batch_idx = 0 84 | running_reg_loss = 0. 85 | running_rank_loss = 0. 86 | running_adj_corr = 0. 87 | running_total_loss = 0. 88 | total_batch = len(self.train_loader) 89 | data_iter = iter(self.train_loader) 90 | view_per_image = 64 91 | 92 | while batch_idx < total_batch: 93 | try: 94 | # torch.autograd.set_detect_anomaly(True) 95 | batch_idx += 1 96 | self.iters += 1 97 | batch_data = next(data_iter) 98 | im = batch_data[0].to(device) 99 | rois = batch_data[1].to(device) 100 | gt_scores = batch_data[2].to(device) 101 | 102 | random_ID = list(range(0, rois.shape[1])) 103 | random.shuffle(random_ID) 104 | chosen_ID = random_ID[:view_per_image] 105 | rois = rois[:,chosen_ID] 106 | gt_scores = gt_scores[:,chosen_ID] 107 | region_feat = self.extractor(im, rois) 108 | 109 | if random.uniform(0,1) <= 0.3: 110 | batch_data2 = next(data_iter) 111 | im2 = batch_data2[0].to(device) 112 | rois2 = batch_data2[1].to(device) 113 | gt_scores2 = batch_data2[2].to(device) 114 | 115 | random_ID = list(range(0, rois2.shape[1])) 116 | random.shuffle(random_ID) 117 | chosen_ID = random_ID[:view_per_image] 118 | 119 | rois2 = 
rois2[:,chosen_ID] 120 | gt_scores2 = gt_scores2[:, chosen_ID] 121 | region_feat2 = self.extractor(im2, rois2) 122 | region_feat = torch.cat([region_feat, region_feat2], dim=0) 123 | gt_scores = torch.cat([gt_scores, gt_scores2], dim=-1) 124 | random_ID = list(range(0, region_feat.shape[0])) 125 | random.shuffle(random_ID) 126 | chosen_ID = random_ID[:view_per_image] 127 | 128 | region_feat = region_feat[chosen_ID] 129 | gt_scores = gt_scores[:,chosen_ID] 130 | 131 | adj, pre_scores = self.gnn(region_feat) 132 | loss_reg = cropping_regression_loss(pre_scores, gt_scores, MOS_MEAN) 133 | loss_rank = cropping_rank_loss(pre_scores, gt_scores) 134 | adj_corr = score_feature_correlation(gt_scores, adj) 135 | total_loss= loss_reg + loss_rank - adj_corr 136 | 137 | self.optimizer.zero_grad() 138 | total_loss.backward() 139 | self.optimizer.step() 140 | 141 | running_reg_loss += loss_reg.item() 142 | running_rank_loss += loss_rank.item() 143 | running_adj_corr += adj_corr.item() 144 | running_total_loss+= total_loss.item() 145 | except StopIteration: 146 | data_iter = iter(self.train_loader) 147 | 148 | if batch_idx % cfg.display_freq == 0: 149 | avg_reg_loss = running_reg_loss / batch_idx 150 | avg_rank_loss = running_rank_loss / batch_idx 151 | avg_adj_corr = running_adj_corr / batch_idx 152 | avg_total_loss= running_total_loss/ batch_idx 153 | 154 | cur_lr = self.optimizer.param_groups[0]['lr'] 155 | self.writer.add_scalar('train/reg_loss', avg_reg_loss, self.iters) 156 | self.writer.add_scalar('train/rank_loss', avg_rank_loss, self.iters) 157 | self.writer.add_scalar('train/adj_corr', avg_adj_corr, self.iters) 158 | self.writer.add_scalar('train/total_loss', avg_total_loss, self.iters) 159 | self.writer.add_scalar('train/lr', cur_lr, self.iters) 160 | 161 | time_per_batch = (time.time() - start) / (batch_idx + 1.) 
162 | last_batches = (self.max_epoch - self.epoch - 1) * total_batch + (total_batch - batch_idx - 1) 163 | last_time = int(last_batches * time_per_batch) 164 | time_str = str(datetime.timedelta(seconds=last_time)) 165 | 166 | print('=== epoch:{}/{}, step:{}/{} | Total_Loss:{:.4f} | Adj_Corr: {:.4f} | lr:{:.6f} | estimated last time:{} ==='.format( 167 | self.epoch, self.max_epoch, batch_idx, total_batch, avg_total_loss, avg_adj_corr, cur_lr, time_str 168 | )) 169 | 170 | def eval(self): 171 | srcc, acc5, acc10 = evaluate_on_GAICD(self.extractor, self.gnn) 172 | iou, disp = evaluate_on_FLMS(self.extractor, self.gnn) 173 | self.eval_results.append([self.epoch, srcc, acc5, acc10, iou, disp]) 174 | epoch_result = {'srcc': srcc, 'acc5': acc5, 'acc10': acc10, 175 | 'FLMS_iou': iou, 'FLMS_disp': disp} 176 | for m in self.best_results.keys(): 177 | update = False 178 | if ('disp' not in m) and (epoch_result[m] > self.best_results[m]): 179 | update = True 180 | elif ('disp' in m) and (epoch_result[m] < self.best_results[m]): 181 | update = True 182 | if update: 183 | self.best_results[m] = epoch_result[m] 184 | checkpoint_path = os.path.join(cfg.checkpoint_dir, 'extractor-best-{}.pth'.format(m)) 185 | torch.save(self.extractor.state_dict(), checkpoint_path) 186 | 187 | checkpoint_path = os.path.join(cfg.checkpoint_dir, 'gnn-best-{}.pth'.format(m)) 188 | torch.save(self.gnn.state_dict(), checkpoint_path) 189 | print('Update best {} model, best {}={:.4f}'.format(m, m, self.best_results[m])) 190 | if 'FLMS' in m: 191 | self.writer.add_scalar('eval_FLMS/{}'.format(m), epoch_result[m], self.epoch) 192 | else: 193 | self.writer.add_scalar('eval_GAICD/{}'.format(m), epoch_result[m], self.epoch) 194 | if m == 'srcc': 195 | self.writer.add_scalar('eval_GAICD/best-srcc', self.best_results[m], self.epoch) 196 | 197 | # if self.epoch % cfg.save_freq == 0: 198 | # checkpoint_path = os.path.join(cfg.checkpoint_dir, 'epoch-{}.pth'.format(self.epoch)) 199 | # 
torch.save(self.model.state_dict(), checkpoint_path) 200 | 201 | def record_eval_results(self): 202 | csv_path = os.path.join(cfg.exp_path, '..', '{}.csv'.format(cfg.exp_name)) 203 | header = ['epoch', 'srcc', 'acc5', 'acc10', 204 | 'FLMS_iou', 'FLMS_disp'] 205 | rows = [header] 206 | for i in range(len(self.eval_results)): 207 | new_results = [] 208 | for j in range(len(self.eval_results[i])): 209 | new_results.append(round(self.eval_results[i][j], 3)) 210 | self.eval_results[i] = new_results 211 | rows += self.eval_results 212 | metrics = [[] for i in header] 213 | for result in self.eval_results: 214 | for i, r in enumerate(result): 215 | metrics[i].append(r) 216 | for name, m in zip(header, metrics): 217 | if name == 'epoch': 218 | continue 219 | index = m.index(max(m)) 220 | if 'disp' in name: 221 | index = m.index(min(m)) 222 | title = 'best {}(epoch-{})'.format(name, metrics[0][index]) # report the epoch, not the row index 223 | row = [l[index] for l in metrics] 224 | row[0] = title 225 | rows.append(row) 226 | with open(csv_path, 'w', newline='') as f: 227 | cw = csv.writer(f) 228 | cw.writerows(rows) 229 | print('Save result to ', csv_path) 230 | 231 | if __name__ == '__main__': 232 | cfg.create_path() 233 | for file in os.listdir('./'): 234 | if file.endswith('.py'): 235 | shutil.copy(file, cfg.exp_path) 236 | print('backup', file) 237 | FeatureExtractor = RegionFeatureExtractor(loadweight=True).to(device) 238 | GNN = CroppingGraph().to(device) 239 | trainer = Trainer(FeatureExtractor, GNN) 240 | trainer.run() --------------------------------------------------------------------------------
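The RoIAlign forward and backward kernels in `roi_align` (CPU and CUDA) share one sampling rule: `aligned_height x aligned_width` points are spread corner-to-corner over each ROI, each point is bilinearly interpolated from a 2x2 neighbourhood, and points outside the feature map produce zero output and receive no gradient. The NumPy sketch below restates that rule so the compiled extensions can be checked against it. It is an illustrative reference only, assuming NCHW `features`, `rois` rows of `[batch_index, x1, y1, x2, y2]` (the layout the kernels read through `bottom_rois[n * 5 + k]`), and `aligned_height`, `aligned_width` of at least 2; the helper names `_samples`, `_weights`, `roi_align_forward_ref` and `roi_align_backward_ref` are hypothetical and not part of the repository.

```python
import numpy as np


def _samples(rois, aligned_h, aligned_w, spatial_scale, height, width):
    """Yield (n, b, ph, pw, hs, ws, dh, dw) for every in-bounds sample point.

    Mirrors the index arithmetic of the RoIAlign kernels: aligned_h x aligned_w
    points spread corner-to-corner over the ROI (so aligned_h and aligned_w
    must be >= 2); points outside the feature map are skipped.
    """
    for n, roi in enumerate(rois):
        b = int(roi[0])
        x1, y1, x2, y2 = (float(v) * spatial_scale for v in roi[1:5])
        bin_h = max(y2 - y1 + 1.0, 0.0) / (aligned_h - 1.0)
        bin_w = max(x2 - x1 + 1.0, 0.0) / (aligned_w - 1.0)
        for ph in range(aligned_h):
            for pw in range(aligned_w):
                h = ph * bin_h + y1
                w = pw * bin_w + x1
                if h < 0 or h >= height or w < 0 or w >= width:
                    continue  # kernels write 0 here and scatter no gradient
                # clamp so the 2x2 neighbourhood stays inside the map
                hs = int(min(np.floor(h), height - 2))
                ws = int(min(np.floor(w), width - 2))
                yield n, b, ph, pw, hs, ws, h - hs, w - ws


def _weights(dh, dw):
    # bilinear weights laid out as [[up-left, up-right], [down-left, down-right]]
    return np.array([[(1 - dh) * (1 - dw), (1 - dh) * dw],
                     [dh * (1 - dw), dh * dw]])


def roi_align_forward_ref(features, rois, aligned_h, aligned_w, spatial_scale):
    """Return an (R, C, aligned_h, aligned_w) array of interpolated features."""
    _, channels, height, width = features.shape
    out = np.zeros((len(rois), channels, aligned_h, aligned_w))
    for n, b, ph, pw, hs, ws, dh, dw in _samples(
            rois, aligned_h, aligned_w, spatial_scale, height, width):
        patch = features[b, :, hs:hs + 2, ws:ws + 2]  # (C, 2, 2)
        out[n, :, ph, pw] = (patch * _weights(dh, dw)).sum(axis=(1, 2))
    return out


def roi_align_backward_ref(top_grad, rois, feat_shape, spatial_scale):
    """Scatter top_grad back onto a zero feature map with the same weights."""
    _, _, height, width = feat_shape
    aligned_h, aligned_w = top_grad.shape[2:]
    grad = np.zeros(feat_shape)
    for n, b, ph, pw, hs, ws, dh, dw in _samples(
            rois, aligned_h, aligned_w, spatial_scale, height, width):
        # (C, 1, 1) * (2, 2) -> (C, 2, 2), accumulated in place on the view
        grad[b, :, hs:hs + 2, ws:ws + 2] += top_grad[n, :, ph, pw][:, None, None] * _weights(dh, dw)
    return grad


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.standard_normal((2, 3, 8, 10))
    rois = np.array([[0, 1.0, 2.0, 6.0, 5.0], [1, -2.0, 0.5, 12.0, 9.0]])
    out = roi_align_forward_ref(feats, rois, 4, 4, spatial_scale=1.0)
    top = rng.standard_normal(out.shape)
    grad = roi_align_backward_ref(top, rois, feats.shape, spatial_scale=1.0)
    # the forward pass is linear in `features`, so backward must be its adjoint
    print(np.isclose((top * out).sum(), (grad * feats).sum()))  # True
```

Because the forward pass is linear in `features`, a correct backward pass satisfies `<top, forward(x)> == <backward(top), x>` for any `top` and `x`. Feeding the same random tensors to the `forward`/`backward` functions exported by the extensions' `PYBIND11_MODULE` blocks and checking that identity is a cheap way to catch a backward pass whose bounds test disagrees with the forward pass.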