├── LICENSE
├── README.md
├── config
│   ├── GAIC_config.py
│   └── GAIC_params.yaml
├── dataset
│   ├── candidate_generation.py
│   └── cropping_dataset.py
├── evaluate
│   ├── demo.py
│   └── test.py
├── networks
│   ├── GAIC_model.py
│   └── __init__.py
├── requirements.txt
├── result_images
│   ├── 211958.jpg
│   ├── 265813.jpg
│   └── 297406.jpg
├── test_images
│   ├── 211958.jpg
│   ├── 265813.jpg
│   └── 297406.jpg
├── train
│   └── train.py
└── untils
    ├── make_all.sh
    ├── rod_align
    │   ├── __init__.py
    │   ├── functions
    │   │   ├── __init__.py
    │   │   └── rod_align.py
    │   ├── make.sh
    │   ├── make_python2.sh
    │   ├── modules
    │   │   ├── __init__.py
    │   │   └── rod_align.py
    │   ├── setup.py
    │   └── src
    │       ├── rod_align.cpp
    │       ├── rod_align.h
    │       ├── rod_align_cuda.cpp
    │       ├── rod_align_cuda.h
    │       ├── rod_align_kernel.cu
    │       └── rod_align_kernel.h
    └── roi_align
        ├── __init__.py
        ├── functions
        │   ├── __init__.py
        │   └── roi_align.py
        ├── make.sh
        ├── make_python2.sh
        ├── modules
        │   ├── __init__.py
        │   └── roi_align.py
        ├── setup.py
        └── src
            ├── roi_align.cpp
            ├── roi_align.h
            ├── roi_align_cuda.cpp
            ├── roi_align_cuda.h
            ├── roi_align_kernel.cu
            └── roi_align_kernel.h
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 bo-zhang
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # GAIC-PyTorch
 2 | This is an unofficial PyTorch implementation of [Grid Anchor based Image Cropping: A New Benchmark and An Efficient Model](https://arxiv.org/pdf/1909.08989.pdf), which is the journal version and extension of [Reliable and Efficient Image Cropping: A Grid Anchor based Approach](https://arxiv.org/pdf/1904.04441.pdf).
 3 | We provide demo code to produce the best cropping results with different aspect ratios (1:1, 4:3, and 16:9) for arbitrary test images.
 4 | Moreover, this code is also able to generate crops with arbitrary specified aspect ratios.
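Both modes rest on the grid-anchor candidate generator in ``dataset/candidate_generation.py``; a minimal sketch of querying it directly (the 384×256 image size and the bin counts are just example values):

```python
from dataset.candidate_generation import (
    generate_anchors, generate_anchors_aspect_ratio_specific)

# Aspect-ratio-agnostic candidate crops for a 384x256 image:
# an (N, 4) array of [x1, y1, x2, y2] boxes placed on the grid anchors.
anchors = generate_anchors(384, 256, bins=12)

# Candidate crops restricted to a 16:9 aspect ratio.
anchors_16_9 = generate_anchors_aspect_ratio_specific(384, 256, (16, 9), bins=15)
print(anchors.shape, anchors_16_9.shape)
```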
[Result previews: ``result_images/211958.jpg``, ``265813.jpg``, ``297406.jpg`` — the source images from ``test_images/`` shown with their predicted best crops.]
15 | 
16 | ### Requirements
17 | - PyTorch>=1.0
18 | 
19 | You can install packages using pip according to [``requirements.txt``](./requirements.txt):
20 | 
21 | ```Shell
22 | pip install -r requirements.txt
23 | ```
24 | 
25 | ### Installation
26 | 1. Get the code. We will call the directory that you cloned into `$PRJ_ROOT`.
27 | ```Shell
28 | git clone https://github.com/bo-zhang-cs/GAIC-Pytorch.git
29 | ```
30 | 
31 | 2. Build the RoI&RoDAlign libraries. The source code of RoI&RoDAlign is from [[here]](https://github.com/lld533/Grid-Anchor-based-Image-Cropping-Pytorch) and is compatible with PyTorch 1.0 or later. If you use PyTorch 0.4.1, please refer to the [[official implementation]](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch).
32 | ```Shell
33 | cd $PRJ_ROOT/untils
34 | # Change **CUDA_HOME** and **-arch=sm_86** in ``roi_align/make.sh`` and ``rod_align/make.sh`` according to your environment.
35 | # Make sure these bash files (``make_all.sh, roi_align/make.sh, rod_align/make.sh``) are in Unix text file format, e.g. by running ``:set ff=unix`` in Vim.
36 | sudo bash make_all.sh
37 | ```
38 | 
39 | ### Running Demo
40 | 1. Download the pretrained models (~98MB) from [[Google Drive]](https://drive.google.com/file/d/1U_8C9oWOBT64LHtxZP1_0Uo9Ndw0QLsJ/view?usp=sharing) [[Baidu Cloud]](https://pan.baidu.com/s/1fmy18FD5_0v6vrab6OZKmQ) (access code: *webf*). By default, we assume the models (``*.pth``) are stored in `$PRJ_ROOT/pretrained_models`.
41 | 
42 | 2. Predict the best crops for your own images.
43 | ```Shell
44 | cd $PRJ_ROOT
45 | python evaluate/demo.py --gpu 0 --image_dir $IMAGE_PATH/IMAGE_FOLDER --save_dir $RESULT_FOLDER
46 | # or execute python evaluate/demo.py to predict best crops for the images in the test_images folder.
47 | ```
48 | 
49 | ### Train/Evaluate
50 | 1. Download the [GAICD dataset](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping). Then set ``GAIC_folder`` in ``config/GAIC_config.py`` and check the paths by running:
51 | ```Shell
52 | cd $PRJ_ROOT
53 | python dataset/cropping_dataset.py
54 | ```
55 | 
56 | 2. Train your model and evaluate it on the fly.
57 | ```Shell
58 | cd $PRJ_ROOT
59 | python train/train.py --gpu 0 --backbone vgg16
60 | # or just run python train/train.py
61 | ```
62 | The model performance for each epoch is also recorded in a ``*.csv`` file under the produced ``experiments`` folder.
63 | More model parameters and experimental settings can be found in ``config/GAIC_params.yaml``.
64 | 
65 | 3. Evaluate the pretrained model and reproduce the quantitative results below.
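If you would rather drive the evaluation from your own Python code, here is a minimal sketch (assuming the RoI/RoDAlign extensions are built, the GAICD dataset is configured as in step 1 above, and the pretrained VGG16 weights sit in ``pretrained_models/``):

```python
import torch
from networks.GAIC_model import build_crop_model
from evaluate.test import evaluate_on_GAICD_official

# Build the multi-scale cropping model with the same settings as the released weights.
model = build_crop_model(scale='multi', alignsize=9, reddim=32,
                         loadweight=False, model='vgg16')
model.load_state_dict(torch.load('pretrained_models/GAIC-vgg16-reddim32.pth'))
model = model.eval().cuda()

# Iterates over the GAICD test split and prints SRCC/PCC and the Acc metrics.
evaluate_on_GAICD_official(model)
```

Or simply run the bundled script: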
66 | ```Shell
67 | cd $PRJ_ROOT
68 | python evaluate/test.py
69 | ```
70 | 
71 | ### Performance on GAICD Dataset
72 | | Metric | Backbone | SRCC↑ | PCC↑ | Acc5↑ | Acc1_5↑ | Acc4_5↑ | Acc10↑ | Acc1_10↑ | Acc4_10↑ |
73 | |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
74 | | Paper | VGG16 | 0.777 | 0.800 | - | 60.5 | 50.2 | - | 77.5 | 70.6 |
75 | | This code | VGG16 | 0.778 | 0.808 | 54.8 | 60.0 | 51.6 | 73.3 | 77.5 | 69.9 |
76 | | Paper | MobileNetV2 | 0.783 | 0.806 | - | **62.5** | 52.5 | - | 78.5 | **72.3** |
77 | | This code | MobileNetV2 | 0.783 | 0.810 | **58.0** | **62.5** | **53.0** | **74.2** | **80.0** | 72.0 |
78 | | Paper | ShuffleNetV2 | 0.774 | 0.801 | - | 61.5 | 52.0 | - | 78.5 | 71.3 |
79 | | This code | ShuffleNetV2 | **0.787** | **0.811** | 55.4 | 62.0 | 50.0 | 73.4 | 78.0 | 69.6 |
80 | 
81 | ### Citation
82 | ```
83 | @inproceedings{zhang2019deep,
84 |   title={Reliable and Efficient Image Cropping: A Grid Anchor based Approach},
85 |   author={Zeng, Hui and Li, Lida and Cao, Zisheng and Zhang, Lei},
86 |   booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
87 |   year={2019}
88 | }
89 | @article{zeng2020cropping,
90 |   title={Grid Anchor based Image Cropping: A New Benchmark and An Efficient Model},
91 |   author={Zeng, Hui and Li, Lida and Cao, Zisheng and Zhang, Lei},
92 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
93 |   volume={},
94 |   number={},
95 |   pages={},
96 |   year={2020},
97 |   publisher={IEEE}
98 | }
99 | ```
--------------------------------------------------------------------------------
/config/GAIC_config.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import yaml
 3 | 
 4 | yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'GAIC_params.yaml')
 5 | assert os.path.exists(yaml_path), yaml_path
 6 | 
 7 | def load_yaml_params():
 8 |     with open(yaml_path, 'r') as yaml_file:
 9 |         params = yaml.full_load(yaml_file.read())
10 |     return params
11 | 
12 | def refresh_yaml_params(args):
13 |     yaml_params = load_yaml_params()
14 |     for arg in vars(args):
15 |         # print(arg, type(arg), getattr(args, arg))
16 |         assert arg in yaml_params, arg
17 |         yaml_params[arg] = getattr(args, arg)
18 | 
19 |     with open(yaml_path, 'w') as yaml_file:
20 |         yaml.dump(yaml_params, yaml_file)
21 | 
22 | class Config:
23 |     data_root = '../../dataset'
24 |     GAIC_folder = os.path.join(data_root, 'GAICD')
25 | 
26 |     def __init__(self):
27 |         self.refresh_params()
28 | 
29 |     def refresh_params(self):
30 |         self.load_params_from_yaml()
31 |         self.generate_path()
32 | 
33 |     def load_params_from_yaml(self):
34 |         # add parameters from yaml file
35 |         names = self.__dict__
36 |         params = load_yaml_params()
37 |         for k, v in params.items():
38 |             # print(v, type(v))
39 |             names[k] = v
40 | 
41 |     def generate_path(self):
42 |         prefix = 'GAIC-{}-re{}dim'.format(self.backbone, self.reddim)
43 |         exp_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'experiments')
44 |         exp_name = prefix
45 |         exp_path = os.path.join(exp_root, prefix)
46 |         while os.path.exists(exp_path):
47 |             index = os.path.basename(exp_path).split(prefix)[-1].split('repeat')[-1]
48 |             try:
49 |                 index = int(index) + 1
50 |             except:
51 |                 index = 1
52 |             exp_name = prefix + ('_repeat{}'.format(index))
53 |             exp_path = os.path.join(exp_root, exp_name)
54 |         # print('Experiment name {} \n'.format(os.path.basename(exp_path)))
55 |         self.exp_name = exp_name
56 |         self.exp_path = exp_path
57 |         self.checkpoint_dir = os.path.join(exp_path,
'checkpoints') 58 | self.log_dir = os.path.join(exp_path, 'logs') 59 | self.code_dir = os.path.join(exp_path, 'code') 60 | 61 | def create_path(self): 62 | print('Create experiment directory: ', self.exp_path) 63 | os.makedirs(self.exp_path) 64 | os.makedirs(self.checkpoint_dir) 65 | os.makedirs(self.log_dir) 66 | os.makedirs(self.code_dir) 67 | 68 | cfg = Config() -------------------------------------------------------------------------------- /config/GAIC_params.yaml: -------------------------------------------------------------------------------- 1 | alignsize: 9 2 | backbone: vgg16 3 | batch_size: 1 4 | data_augmentation: true 5 | display_freq: 200 6 | eval_freq: 1 7 | gpu_id: 0 8 | image_size: 9 | - 256 10 | - 256 11 | keep_aspect_ratio: true 12 | lr: 1.0e-4 13 | lr_decay: 0.1 14 | lr_decay_epoch: [81] 15 | max_epoch: 80 16 | num_workers: 8 17 | reddim: 32 18 | save_freq: 81 19 | weight_decay: 1.0e-4 20 | -------------------------------------------------------------------------------- /dataset/candidate_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # this code is modified from https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch 4 | def generate_anchors(im_w, im_h, bins=12): 5 | assert (im_w > 100) and (im_h > 100), (im_w, im_h) 6 | step_h = im_h / bins 7 | step_w = im_w / bins 8 | pdefined_anchors = [] 9 | for x1 in range(0, int(bins / 3)): 10 | for y1 in range(0, int(bins / 3)): 11 | for x2 in range(int(bins / 3 * 2), bins): 12 | for y2 in range(int(bins / 3 * 2), bins): 13 | area = (x2 - x1) * (y2 - y1) / float(bins * bins) 14 | aspect_ratio = (y2 - y1) * step_h / ((x2 - x1) * step_w) 15 | if area > 0.4999 and aspect_ratio > 0.5 and aspect_ratio < 2.0: 16 | crop_x1 = int(step_w * (0.5+x1)) 17 | crop_y1 = int(step_h * (0.5+y1)) 18 | crop_x2 = int(step_w * (0.5 + x2)) 19 | crop_y2 = int(step_h * (0.5+y2)) 20 | pdefined_anchors.append([crop_x1, crop_y1, crop_x2, crop_y2]) 21 | pdefined_anchors = np.array(pdefined_anchors).reshape(-1,4) 22 | # print('image size:({},{}), obtain {} pre-defined anchors.'.format( 23 | # im_w, im_h, pdefined_anchors.shape[0])) 24 | return pdefined_anchors 25 | 26 | 27 | def generate_anchors_aspect_ratio_specific(im_w, im_h, aspect_ratio, bins=20): 28 | assert (im_w > 100) and (im_h > 100), (im_w, im_h) 29 | assert isinstance(aspect_ratio, tuple), \ 30 | 'undefined aspect ratio type: {}'.format(aspect_ratio) 31 | assert aspect_ratio[0] >= 1 and aspect_ratio[1] >= 1, \ 32 | 'undefined aspect ratio type: {}'.format(aspect_ratio) 33 | w_step, h_step = int(aspect_ratio[0]), int(aspect_ratio[1]) 34 | 35 | max_step = int(min(im_w / w_step, im_h / h_step)) 36 | # limit the search space by increasing the step size 37 | if max_step > bins: 38 | scale = int(max(im_w / w_step / bins, im_h / h_step / bins)) 39 | h_step *= scale 40 | w_step *= scale 41 | max_step = int(min(im_w / w_step, im_h / h_step)) 42 | # print('image_size:{}, aspect_ratio: {}, step:{}, max_steps:{}'.format( 43 | # (im_w, im_h), aspect_ratio, (w_step, h_step), max_step)) 44 | min_step = int(max_step / 2. 
- 1) 45 | pdefined_anchors = [] 46 | for i in range(min_step, max_step): 47 | out_h = h_step * i 48 | out_w = w_step * i 49 | if out_h < im_h and out_w < im_w and (out_w * out_h > 0.4 * im_w * im_h): 50 | for w_start in range(0, im_w - out_w, w_step): 51 | for h_start in range(0, im_h - out_h, h_step): 52 | x1 = int(w_start) 53 | y1 = int(h_start) 54 | x2 = int(w_start + out_w - 1) 55 | y2 = int(h_start + out_h - 1) 56 | pdefined_anchors.append([x1, y1, x2, y2]) 57 | pdefined_anchors = np.array(pdefined_anchors).reshape(-1, 4) 58 | # print('aspect-ratio:{}, image size:({},{}), obtain {} pre-defined anchors.'.format( 59 | # aspect_ratio, im_w, im_h, pdefined_anchors.shape[0])) 60 | return pdefined_anchors 61 | 62 | if __name__ == '__main__': 63 | print(generate_anchors(384, 256)) 64 | # print(generate_anchors_aspect_ratio_specific(128, 128, (1, 1), bins=20)) 65 | # print(generate_anchors_aspect_ratio_specific(512, 512, (3, 4), bins=20)) 66 | # print(generate_anchors_aspect_ratio_specific(512, 512, (16, 9), bins=20)) -------------------------------------------------------------------------------- /dataset/cropping_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image, ImageOps 4 | from torch.utils.data import DataLoader, Dataset 5 | import torchvision.transforms as transforms 6 | import random 7 | import cv2 8 | import json 9 | from config.GAIC_config import cfg 10 | 11 | MOS_MEAN = 2.95 12 | MOS_STD = 0.8 13 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406] 14 | IMAGE_NET_STD = [0.229, 0.224, 0.225] 15 | 16 | def rescale_crops(boxes, ratio_w, ratio_h): 17 | boxes = np.array(boxes).reshape(-1, 4) 18 | boxes[:, 0] = np.floor(boxes[:, 0] * ratio_w) 19 | boxes[:, 1] = np.floor(boxes[:, 1] * ratio_h) 20 | boxes[:, 2] = np.ceil(boxes[:, 2] * ratio_w) 21 | boxes[:, 3] = np.ceil(boxes[:, 3] * ratio_h) 22 | return boxes.astype(np.float32) 23 | 24 | def is_number(s): 25 | if not isinstance(s, str): 26 | return False 27 | if s.isdigit(): 28 | return True 29 | else: 30 | try: 31 | float(s) 32 | return True 33 | except: 34 | return False 35 | 36 | class GAICDataset(Dataset): 37 | def __init__(self, split): 38 | self.split = split 39 | assert self.split in ['train', 'test'], self.split 40 | self.keep_aspect = cfg.keep_aspect_ratio 41 | self.data_dir = cfg.GAIC_folder 42 | assert os.path.exists(self.data_dir), self.data_dir 43 | self.image_dir = os.path.join(self.data_dir, 'images', split) 44 | assert os.path.exists(self.image_dir), self.image_dir 45 | self.image_list = [file for file in os.listdir(self.image_dir) if file.endswith('.jpg')] 46 | # print('GAICD {} set contains {} images'.format(split, len(self.image_list))) 47 | self.anno_dir = os.path.join(self.data_dir, 'annotations') 48 | assert os.path.exists(self.anno_dir), self.anno_dir 49 | self.annos = self.parse_annotations() 50 | 51 | self.image_size = cfg.image_size 52 | self.augmentation = (cfg.data_augmentation and self.split == 'train') 53 | self.PhotometricDistort = transforms.ColorJitter( 54 | brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05) 55 | self.image_transformer = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)]) 58 | 59 | def parse_annotations(self): 60 | image_annos = dict() 61 | for image_name in self.image_list: 62 | anno_file = os.path.join(self.anno_dir, image_name.replace('.jpg', '.txt')) 63 | assert os.path.exists(anno_file), anno_file 64 | with open(anno_file, 
'r') as f: 65 | crops,scores = [],[] 66 | for line in f.readlines(): 67 | line = line.strip().split(' ') 68 | values = [s for s in line if is_number(s)] 69 | y1,x1,y2,x2 = [int(s) for s in values[0:4]] 70 | s = float(values[-1]) 71 | if s > -2: 72 | crops.append([x1,y1,x2,y2]) 73 | scores.append(s) 74 | if len(crops) == 0: 75 | print(image_name, anno_file) 76 | else: 77 | # rank all crops 78 | rank = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) 79 | scores = [scores[i] for i in rank] 80 | crops = [crops[i] for i in rank] 81 | image_annos[image_name] = {'crops':crops, 'scores':scores} 82 | return image_annos 83 | 84 | def __len__(self): 85 | return len(self.image_list) 86 | 87 | def __getitem__(self, index): 88 | image_name = self.image_list[index] 89 | image_file = os.path.join(self.image_dir, image_name) 90 | image = Image.open(image_file).convert('RGB') 91 | im_width, im_height = image.size 92 | if self.keep_aspect: 93 | scale = float(cfg.image_size[0]) / min(im_height, im_width) 94 | h = round(im_height * scale / 32.0) * 32 95 | w = round(im_width * scale / 32.0) * 32 96 | else: 97 | h = cfg.image_size[1] 98 | w = cfg.image_size[0] 99 | resized_image = image.resize((w, h), Image.ANTIALIAS) 100 | crop = self.annos[image_name]['crops'] 101 | rs_width, rs_height = resized_image.size 102 | ratio_w = float(rs_width) / im_width 103 | ratio_h = float(rs_height) / im_height 104 | crop = rescale_crops(crop, ratio_w, ratio_h) 105 | score = np.array(self.annos[image_name]['scores']).reshape((-1)).astype(np.float32) 106 | if self.augmentation: 107 | if random.uniform(0,1) > 0.5: 108 | resized_image = ImageOps.mirror(resized_image) 109 | temp_x1 = crop[:, 0].copy() 110 | crop[:, 0] = rs_width - crop[:, 2] 111 | crop[:, 2] = rs_width - temp_x1 112 | resized_image = self.PhotometricDistort(resized_image) 113 | im = self.image_transformer(resized_image) 114 | return im, crop, score, im_width, im_height, image_file 115 | 116 | if __name__ == '__main__': 117 | GAICD_testset = GAICDataset(split='train') 118 | print('GAICD training set has {} images'.format(len(GAICD_testset))) 119 | dataloader = DataLoader(GAICD_testset, batch_size=1, num_workers=0) 120 | for batch_idx, data in enumerate(dataloader): 121 | im, crops, scores, w, h, file = data 122 | print(im.shape, crops.shape, scores.shape, w.shape, h.shape) -------------------------------------------------------------------------------- /evaluate/demo.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | PROJECT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 3 | sys.path.insert(0, PROJECT_PATH) 4 | import numpy as np 5 | import torch 6 | from networks.GAIC_model import build_crop_model 7 | import argparse 8 | import cv2 9 | from PIL import Image 10 | from dataset.candidate_generation import generate_anchors, generate_anchors_aspect_ratio_specific 11 | import torchvision.transforms as transforms 12 | import matplotlib.pyplot as plt 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406] 17 | IMAGE_NET_STD = [0.229, 0.224, 0.225] 18 | image_transformer = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)]) 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Run cropping model on images") 25 | parser.add_argument('--gpu', type=int, dest='gpu_id', 26 | help='gpu_id', default=0) 27 | parser.add_argument('--backbone', 
type=str, choices=['vgg16', 'mobilenetv2', 'shufflenetv2'], 28 | help='the architecture of backbone network', default='vgg16') 29 | parser.add_argument('--image_dir', type=str, default='test_images', 30 | help='the directory of test images') 31 | parser.add_argument('--save_dir', type=str, default='result_images', 32 | help='the directory of saving resulting images') 33 | args = parser.parse_args() 34 | assert os.path.exists(args.image_dir), args.image_dir 35 | os.makedirs(args.save_dir, exist_ok=True) 36 | return args 37 | 38 | def build_network(backbone): 39 | if backbone in ['vgg16','shufflenetv2']: 40 | reddim = 32 41 | elif backbone == 'mobilenetv2': 42 | reddim = 16 43 | else: 44 | raise Exception('undefined backbone architecture', backbone) 45 | net = build_crop_model(scale='multi', alignsize=9, reddim=reddim, 46 | loadweight=False, model=backbone) 47 | weights_path = 'pretrained_models/GAIC-{}-reddim{}.pth'.format(backbone, reddim) 48 | assert os.path.exists(weights_path), weights_path 49 | print('load pretrained weights from ', weights_path) 50 | net.load_state_dict(torch.load(weights_path)) 51 | return net 52 | 53 | def get_image_list(args): 54 | img_list = [] 55 | if os.path.isdir(args.image_dir): 56 | for file in os.listdir(args.image_dir): 57 | if file.endswith(('.jpg', '.png', '.jpeg', '.bmp')): 58 | img_list.append(os.path.join(args.image_dir, file)) 59 | else: 60 | if args.image_dir.endswith(('.jpg', '.png', '.jpeg', '.bmp')): 61 | img_list.append(args.image_dir) 62 | print('find total {} images'.format(len(img_list))) 63 | return img_list 64 | 65 | def image_preprocessing(im): 66 | im_width,im_height = im.size 67 | scale = 256. / min(im_height, im_width) 68 | h = round(im_height * scale / 32.0) * 32 69 | w = round(im_width * scale / 32.0) * 32 70 | resized_image = im.resize((w, h), Image.ANTIALIAS) 71 | im_tensor = image_transformer(resized_image).unsqueeze(0) 72 | return im_tensor 73 | 74 | def predict_best_crop(model, im_tensor, anchors, im): 75 | im_width, im_height = im.size 76 | with torch.no_grad(): 77 | rois = anchors.astype(np.float32) 78 | rois = torch.from_numpy(rois).unsqueeze(0).to(im_tensor.device) 79 | scores = model(im_tensor, rois) 80 | scores = scores.detach().cpu().numpy().reshape(-1) 81 | pr_idx = np.argmax(scores) 82 | # mapping the coordinates of predefined anchors to source image 83 | rescale_anchors = anchors.astype(np.float32) 84 | rescale_anchors[:,0::2] = rescale_anchors[:,0::2] / im_tensor.shape[-1] * im_width 85 | rescale_anchors[:,1::2] = rescale_anchors[:,1::2] / im_tensor.shape[-2] * im_height 86 | rescale_anchors = rescale_anchors.astype(np.int32) 87 | pr_bbox = rescale_anchors[pr_idx].tolist() 88 | x1, y1, x2, y2 = pr_bbox 89 | pr_crop = im.crop((x1, y1, x2, y2)) 90 | # pr_crop = np.asarray(pr_crop)[:,:,::-1] # convert to opencv format 91 | return pr_crop, pr_bbox 92 | 93 | 94 | if __name__ == '__main__': 95 | args = parse_args() 96 | device = torch.device('cuda:{}'.format(args.gpu_id)) 97 | net = build_network(args.backbone) 98 | net = net.eval().to(device) 99 | img_list = get_image_list(args) 100 | 101 | for i,img in enumerate(img_list): 102 | im_name = os.path.basename(img) 103 | src = Image.open(img).convert('RGB') 104 | src_tensor = image_preprocessing(src).to(device) 105 | input_w, input_h = src_tensor.shape[-1], src_tensor.shape[-2] 106 | # generate aspect-ratio-agnostic crops 107 | anchors = generate_anchors(input_w, input_h) 108 | best_crop, bbox = predict_best_crop(net, src_tensor, anchors, src) 109 | print('source image:{} {}, 
num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format( 110 | im_name, src.size, anchors.shape[0], bbox, best_crop.size)) 111 | 112 | # generage aspect-ratio-specific crops 113 | anchors_1_1 = generate_anchors_aspect_ratio_specific(input_w, input_h, (1,1), bins=30) 114 | crop_1_1, bbox_1_1 = predict_best_crop(net, src_tensor, anchors_1_1, src) 115 | print('aspect_ratio=1:1, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format( 116 | anchors_1_1.shape[0], bbox_1_1, crop_1_1.size)) 117 | 118 | anchors_4_3 = generate_anchors_aspect_ratio_specific(input_w, input_h, (4,3), bins=20) 119 | crop_4_3, bbox_4_3 = predict_best_crop(net, src_tensor, anchors_4_3, src) 120 | print('aspect_ratio=4:3, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format( 121 | anchors_4_3.shape[0], bbox_4_3, crop_4_3.size)) 122 | 123 | anchors_16_9 = generate_anchors_aspect_ratio_specific(input_w, input_h, (16,9), bins=15) 124 | crop_16_9, bbox_16_9 = predict_best_crop(net, src_tensor, anchors_16_9, src) 125 | print('aspect_ratio=16:9, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format( 126 | anchors_16_9.shape[0], bbox_16_9, crop_16_9.size)) 127 | 128 | # results visualization 129 | crop_list = [src, best_crop, crop_1_1, crop_4_3, crop_16_9] 130 | title_list = ['source image', 'best crop', '1:1', '4:3', '16:9'] 131 | fig_cols = 5 132 | fig_rows = (len(crop_list) + fig_cols - 1) // fig_cols 133 | fig = plt.figure(figsize=(20,5)) 134 | for i in range(len(crop_list)): 135 | ax = fig.add_subplot(fig_rows, fig_cols, i+1) 136 | ax.imshow(crop_list[i]) 137 | ax.set_axis_off() 138 | ax.set_title(title_list[i]) 139 | fig.tight_layout() 140 | result_file = os.path.join(args.save_dir, im_name) 141 | plt.savefig(result_file) 142 | plt.close() 143 | print('Save results to ', result_file) 144 | print() 145 | 146 | -------------------------------------------------------------------------------- /evaluate/test.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | import pickle 7 | from scipy.stats import spearmanr, pearsonr 8 | import math 9 | from dataset.cropping_dataset import GAICDataset 10 | from config.GAIC_config import cfg 11 | from networks.GAIC_model import build_crop_model 12 | from thop import profile 13 | 14 | def compute_acc(gt_scores, pr_scores): 15 | assert (len(gt_scores) == len(pr_scores)), '{} vs. {}'.format(len(gt_scores), len(pr_scores)) 16 | sample_cnt = 0 17 | acc4_5 = [0 for i in range(4)] 18 | acc4_10 = [0 for i in range(4)] 19 | for i in range(len(gt_scores)): 20 | gts, preds = gt_scores[i], pr_scores[i] 21 | id_gt = sorted(range(len(gts)), key=lambda j : gts[j], reverse=True) 22 | id_pr = sorted(range(len(preds)), key=lambda j : preds[j], reverse=True) 23 | for k in range(4): 24 | temp_acc4_5 = 0. 25 | temp_acc4_10 = 0. 
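                # For each k, the top-(k+1) crops ranked by predicted score count as
                # hits if their ground-truth score reaches that of the 5th-best (acc4_5)
                # or 10th-best (acc4_10) human-rated crop; hit rates are averaged over
                # the test images below.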
26 | for j in range(k+1): 27 | if gts[id_pr[j]] >= gts[id_gt[4]]: 28 | temp_acc4_5 += 1.0 29 | if gts[id_pr[j]] >= gts[id_gt[9]]: 30 | temp_acc4_10 += 1.0 31 | acc4_5[k] += (temp_acc4_5 / (k+1.0)) 32 | acc4_10[k] += ((temp_acc4_10) / (k+1.0)) 33 | sample_cnt += 1 34 | acc4_5 = [round(i / sample_cnt,3) for i in acc4_5] 35 | acc4_10 = [round(i / sample_cnt,3) for i in acc4_10] 36 | # print('acc4_5', acc4_5) 37 | # print('acc4_10', acc4_10) 38 | return acc4_5, acc4_10 39 | 40 | def evaluate_on_GAICD_official(model): 41 | # https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch 42 | model.eval() 43 | device = next(model.parameters()).device 44 | print('='*5, 'Evaluating on GAICD dataset', '='*5) 45 | count = 0 46 | test_dataset = GAICDataset(split='test') 47 | test_loader = torch.utils.data.DataLoader( 48 | test_dataset, batch_size=1, 49 | shuffle=False, num_workers=cfg.num_workers, 50 | drop_last=False) 51 | acc4_5 = [] 52 | acc4_10 = [] 53 | wacc4_5 = [] 54 | wacc4_10 = [] 55 | srcc = [] 56 | pcc = [] 57 | for n in range(4): 58 | acc4_5.append(0) 59 | acc4_10.append(0) 60 | wacc4_5.append(0) 61 | wacc4_10.append(0) 62 | 63 | with torch.no_grad(): 64 | for batch_idx, batch_data in enumerate(tqdm(test_loader)): 65 | im = batch_data[0].to(device) 66 | rois = batch_data[1].to(device) 67 | MOS = batch_data[2].reshape(-1,1) 68 | width = batch_data[3] 69 | height = batch_data[4] 70 | count += im.shape[0] 71 | 72 | out = model(im, rois) 73 | id_MOS = sorted(range(len(MOS)), key=lambda k: MOS[k], reverse=True) 74 | id_out = sorted(range(len(out)), key=lambda k: out[k], reverse=True) 75 | 76 | rank_of_returned_crop = [] 77 | for k in range(4): 78 | rank_of_returned_crop.append(id_MOS.index(id_out[k])) 79 | 80 | for k in range(4): 81 | temp_acc_4_5 = 0.0 82 | temp_acc_4_10 = 0.0 83 | for j in range(k + 1): 84 | if MOS[id_out[j]] >= MOS[id_MOS[4]]: 85 | temp_acc_4_5 += 1.0 86 | if MOS[id_out[j]] >= MOS[id_MOS[9]]: 87 | temp_acc_4_10 += 1.0 88 | acc4_5[k] += temp_acc_4_5 / (k + 1.0) 89 | acc4_10[k] += temp_acc_4_10 / (k + 1.0) 90 | 91 | for k in range(4): 92 | temp_wacc_4_5 = 0.0 93 | temp_wacc_4_10 = 0.0 94 | temp_rank_of_returned_crop = rank_of_returned_crop[:(k + 1)] 95 | temp_rank_of_returned_crop.sort() 96 | for j in range(k + 1): 97 | if temp_rank_of_returned_crop[j] <= 4: 98 | temp_wacc_4_5 += 1.0 * math.exp(-0.2 * (temp_rank_of_returned_crop[j] - j)) 99 | if temp_rank_of_returned_crop[j] <= 9: 100 | temp_wacc_4_10 += 1.0 * math.exp(-0.1 * (temp_rank_of_returned_crop[j] - j)) 101 | wacc4_5[k] += temp_wacc_4_5 / (k + 1.0) 102 | wacc4_10[k] += temp_wacc_4_10 / (k + 1.0) 103 | 104 | MOS_arr = [] 105 | out = torch.squeeze(out).cpu().detach().numpy() 106 | for k in range(len(MOS)): 107 | MOS_arr.append(MOS[k].numpy()[0]) 108 | srcc.append(spearmanr(MOS_arr, out)[0]) 109 | pcc.append(pearsonr(MOS_arr, out)[0]) 110 | 111 | for k in range(4): 112 | acc4_5[k] = acc4_5[k] / count 113 | acc4_10[k] = acc4_10[k] / count 114 | wacc4_5[k] = wacc4_5[k] / count 115 | wacc4_10[k] = wacc4_10[k] / count 116 | 117 | avg_srcc = sum(srcc) / count 118 | avg_pcc = sum(pcc) / count 119 | avg_acc5 = sum(acc4_5) / len(acc4_5) 120 | avg_acc10 = sum(acc4_10) / len(acc4_10) 121 | 122 | sys.stdout.write('Acc4_5:[%.3f, %.3f, %.3f, %.3f] Acc4_10:[%.3f, %.3f, %.3f, %.3f]\n' % ( 123 | acc4_5[0], acc4_5[1], acc4_5[2], acc4_5[3], acc4_10[0], acc4_10[1], acc4_10[2], acc4_10[3])) 124 | sys.stdout.write('WAcc4_5:[%.3f, %.3f, %.3f, %.3f] WAcc4_10:[%.3f, %.3f, %.3f, %.3f]\n' % ( 125 | wacc4_5[0], wacc4_5[1], wacc4_5[2], 
wacc4_5[3], wacc4_10[0], wacc4_10[1], wacc4_10[2], wacc4_10[3])) 126 | sys.stdout.write('[Avg SRCC: %.3f] [Avg PCC: %.3f] [Acc5: %.3f] [Acc10: %.3f]\n' % ( 127 | avg_srcc, avg_pcc, avg_acc5, avg_acc10)) 128 | return avg_srcc, avg_pcc, avg_acc5, avg_acc10, acc4_5, acc4_10 129 | 130 | def evaluate_on_GAICD(model): 131 | device = next(model.parameters()).device 132 | model.eval() 133 | print('='*5, 'Evaluating on GAICD dataset', '='*5) 134 | srcc_list = [] 135 | pcc_list = [] 136 | gt_scores = [] 137 | pr_scores = [] 138 | count = 0 139 | test_dataset = GAICDataset(split='test') 140 | test_loader = torch.utils.data.DataLoader( 141 | test_dataset, batch_size=1, 142 | shuffle=False, num_workers=cfg.num_workers, 143 | drop_last=False) 144 | test_results = dict() 145 | with torch.no_grad(): 146 | for batch_idx, batch_data in enumerate(tqdm(test_loader)): 147 | im = batch_data[0].to(device) 148 | rois = batch_data[1].to(device) 149 | scores = batch_data[2].cpu().numpy().reshape(-1) 150 | width = batch_data[3] 151 | height = batch_data[4] 152 | image_name = batch_data[5][0] 153 | count += im.shape[0] 154 | 155 | pre_scores = model(im, rois) 156 | pre_scores = pre_scores.cpu().detach().numpy().reshape(-1) 157 | srcc_list.append(spearmanr(scores, pre_scores)[0]) 158 | pcc_list.append(pearsonr(scores, pre_scores)[0]) 159 | gt_scores.append(scores) 160 | pr_scores.append(pre_scores) 161 | test_results[image_name] = pre_scores.tolist() 162 | avg_srcc = sum(srcc_list) / len(srcc_list) 163 | avg_pcc = sum(pcc_list) / len(pcc_list) 164 | acc4_5, acc4_10 = compute_acc(gt_scores, pr_scores) 165 | avg_acc5 = sum(acc4_5) / len(acc4_5) 166 | avg_acc10 = sum(acc4_10) / len(acc4_10) 167 | sys.stdout.write('Acc4_5:[%.3f, %.3f, %.3f, %.3f] Acc4_10:[%.3f, %.3f, %.3f, %.3f]\n' % ( 168 | acc4_5[0], acc4_5[1], acc4_5[2], acc4_5[3], acc4_10[0], acc4_10[1], acc4_10[2], acc4_10[3])) 169 | sys.stdout.write('[Avg SRCC: %.3f] [Avg PCC: %.3f] [Acc5: %.3f] [Acc10: %.3f]\n' % ( 170 | avg_srcc, avg_pcc, avg_acc5, avg_acc10)) 171 | return avg_srcc, avg_pcc, avg_acc5, avg_acc10, acc4_5, acc4_10 172 | 173 | 174 | if __name__ == '__main__': 175 | device = torch.device('cuda:{}'.format(cfg.gpu_id)) 176 | torch.cuda.set_device(device) 177 | backbone, reddim = 'vgg16', 32 178 | pretrained_weight = 'pretrained_models/GAIC-{}-reddim{}.pth'.format(backbone, reddim) 179 | model = build_crop_model(scale='multi', alignsize=9, reddim=32, 180 | loadweight=False, model=backbone) 181 | model = model.eval().to(device) 182 | print('load pretrained weights from ', pretrained_weight) 183 | model.load_state_dict(torch.load(pretrained_weight), strict=False) 184 | evaluate_on_GAICD_official(model) 185 | 186 | # roi = torch.tensor([[0, 0, 128, 128], [64, 64, 223, 223]]).float() 187 | # roi = roi.unsqueeze(0).to(device) 188 | # img = torch.randn((1, 3, 256, 256)).to(device) 189 | # flops, params = profile(model, inputs=(img, roi)) 190 | # print("params: %.2fMB flops: %.2fG" % (params / (1000 ** 2), flops / (1000 ** 3))) -------------------------------------------------------------------------------- /networks/GAIC_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | from untils.roi_align.modules.roi_align import RoIAlignAvg, RoIAlign 6 | from untils.rod_align.modules.rod_align import RoDAlignAvg, RoDAlign 7 | from torchvision.models.mobilenetv2 import mobilenet_v2 as MobileNetV2 8 | from 
torchvision.models.shufflenetv2 import shufflenet_v2_x1_0 as ShuffleNetV2 9 | import torch.nn.init as init 10 | from thop import profile 11 | import warnings 12 | warnings.filterwarnings('ignore') 13 | 14 | class vgg_base(nn.Module): 15 | def __init__(self, loadweights=True): 16 | super(vgg_base, self).__init__() 17 | 18 | vgg = models.vgg16(pretrained=loadweights) 19 | self.feature3 = nn.Sequential(vgg.features[:23]) 20 | self.feature4 = nn.Sequential(vgg.features[23:30]) 21 | self.feature5 = nn.Sequential(vgg.features[30:]) 22 | 23 | # img = torch.randn((1, 3, 256, 256)) 24 | # flops, params = profile(vgg.features[:-1], inputs=(img,)) 25 | # print("params: %.2fMB flops: %.2fG" % (params / (1000 ** 2), flops / (1000 ** 3))) 26 | # params: 14.71MB flops: 20.06G 27 | 28 | def forward(self, x): 29 | #return self.feature(x) 30 | f3 = self.feature3(x) 31 | f4 = self.feature4(f3) 32 | f5 = self.feature5(f4) 33 | return f3, f4, f5 34 | 35 | class resnet50_base(nn.Module): 36 | def __init__(self, loadweights=True): 37 | super(resnet50_base, self).__init__() 38 | 39 | resnet50 = models.resnet50(pretrained=True) 40 | 41 | self.feature3 = nn.Sequential(resnet50.conv1, resnet50.bn1, 42 | resnet50.relu,resnet50.maxpool, 43 | resnet50.layer1,resnet50.layer2) 44 | self.feature4 = nn.Sequential(resnet50.layer3) 45 | self.feature5 = nn.Sequential(resnet50.layer4) 46 | 47 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256)) 48 | 49 | def forward(self, x): 50 | #return self.feature(x) 51 | f3 = self.feature3(x) 52 | f4 = self.feature4(f3) 53 | f5 = self.feature5(f4) 54 | return f3, f4, f5 55 | 56 | 57 | class mobilenetv2_base(nn.Module): 58 | 59 | def __init__(self, loadweights=True): 60 | super(mobilenetv2_base, self).__init__() 61 | 62 | model = MobileNetV2(pretrained=loadweights) 63 | 64 | self.feature3 = nn.Sequential(model.features[:7]) 65 | self.feature4 = nn.Sequential(model.features[7:14]) 66 | self.feature5 = nn.Sequential(model.features[14:-1]) 67 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256)) 68 | 69 | def forward(self, x): 70 | #return self.feature(x) 71 | f3 = self.feature3(x) 72 | f4 = self.feature4(f3) 73 | f5 = self.feature5(f4) 74 | return f3, f4, f5 75 | 76 | 77 | class shufflenetv2_base(nn.Module): 78 | 79 | def __init__(self, loadweights=True): 80 | super(shufflenetv2_base, self).__init__() 81 | 82 | model = ShuffleNetV2(pretrained=loadweights) 83 | 84 | self.feature3 = nn.Sequential(model.conv1, model.maxpool, model.stage2) 85 | self.feature4 = nn.Sequential(model.stage3) 86 | self.feature5 = nn.Sequential(model.stage4) 87 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256)) 88 | 89 | def forward(self, x): 90 | #return self.feature(x) 91 | f3 = self.feature3(x) 92 | f4 = self.feature4(f3) 93 | f5 = self.feature5(f4) 94 | return f3, f4, f5 95 | 96 | def fc_layers(reddim=32, alignsize=8): 97 | conv1 = nn.Sequential(nn.Conv2d(reddim, 768, kernel_size=alignsize, padding=0), 98 | nn.ReLU(inplace=True)) 99 | conv2 = nn.Sequential(nn.Conv2d(768, 128, kernel_size=1), 100 | nn.ReLU(inplace=True)) 101 | conv3 = nn.Conv2d(128, 1, kernel_size=1) 102 | layers = nn.Sequential(conv1, conv2, conv3) 103 | return layers 104 | 105 | class crop_model_single_scale(nn.Module): 106 | 107 | def __init__(self, alignsize = 8, reddim = 8, loadweight = True, model = None): 108 | super(crop_model_single_scale, self).__init__() 109 | 110 | if model == 'shufflenetv2': 111 | self.Feat_ext = shufflenetv2_base(loadweight) 112 | self.DimRed = nn.Conv2d(232, reddim, 
kernel_size=1, padding=0) 113 | elif model == 'mobilenetv2': 114 | self.Feat_ext = mobilenetv2_base(loadweight) 115 | self.DimRed = nn.Conv2d(96, reddim, kernel_size=1, padding=0) 116 | elif model == 'vgg16': 117 | self.Feat_ext = vgg_base(loadweight) 118 | self.DimRed = nn.Conv2d(512, reddim, kernel_size=1, padding=0) 119 | elif model == 'resnet50': 120 | self.Feat_ext = resnet50_base(loadweight) 121 | self.DimRed = nn.Conv2d(1024, reddim, kernel_size=1, padding=0) 122 | downsample = 4 123 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0/2**downsample) 124 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0/2**downsample) 125 | self.FC_layers = fc_layers(reddim*2, alignsize) 126 | 127 | #flops, params = profile(self.FC_layers, input_size=(1,reddim*2,9,9)) 128 | 129 | def forward(self, im_data, boxes): 130 | 131 | f3,base_feat,f5 = self.Feat_ext(im_data) 132 | red_feat = self.DimRed(base_feat) 133 | RoI_feat = self.RoIAlign(red_feat, boxes) 134 | RoD_feat = self.RoDAlign(red_feat, boxes) 135 | final_feat = torch.cat((RoI_feat, RoD_feat), 1) 136 | prediction = self.FC_layers(final_feat) 137 | return prediction 138 | 139 | def _init_weights(self): 140 | print('Initializing weights...') 141 | self.DimRed.apply(weights_init) 142 | self.FC_layers.apply(weights_init) 143 | 144 | class crop_model_multi_scale_shared(nn.Module): 145 | 146 | def __init__(self, alignsize = 8, reddim = 32, loadweight = True, model = None, ): 147 | super(crop_model_multi_scale_shared, self).__init__() 148 | downsample = 4 149 | if model == 'shufflenetv2': 150 | self.Feat_ext = shufflenetv2_base(loadweight) 151 | self.DimRed = nn.Conv2d(812, reddim, kernel_size=1, padding=0) 152 | elif model == 'mobilenetv2': 153 | self.Feat_ext = mobilenetv2_base(loadweight) 154 | self.DimRed = nn.Conv2d(448, reddim, kernel_size=1, padding=0) 155 | elif model == 'vgg16': 156 | self.Feat_ext = vgg_base(loadweight) 157 | self.DimRed = nn.Conv2d(1536, reddim, kernel_size=1, padding=0) 158 | elif model == 'resnet50': 159 | self.Feat_ext = resnet50_base(loadweight) 160 | self.DimRed = nn.Conv2d(3584, reddim, kernel_size=1, padding=0) 161 | 162 | self.downsample2 = nn.UpsamplingBilinear2d(scale_factor=1.0/2.0) 163 | self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2.0) 164 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0/2**downsample) 165 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0/2**downsample) 166 | self.FC_layers = fc_layers(reddim*2, alignsize) 167 | 168 | def forward(self, im_data, boxes): 169 | # print(im_data.shape, im_data.dtype, im_data.device, boxes.shape, boxes.dtype, boxes.device) 170 | B, N, _ = boxes.shape 171 | if boxes.shape[-1] == 4: 172 | index = torch.arange(B).view(-1, 1).repeat(1, N).reshape(B, N, 1).to(boxes.device) 173 | boxes = torch.cat((index, boxes),dim=-1).contiguous() 174 | if boxes.dim() == 3: 175 | boxes = boxes.view(-1,5) 176 | 177 | f3,f4,f5 = self.Feat_ext(im_data) 178 | f3 = F.interpolate(f3, size=f4.shape[2:], mode='bilinear', align_corners=True) 179 | f5 = F.interpolate(f5, size=f4.shape[2:], mode='bilinear', align_corners=True) 180 | cat_feat = torch.cat((f3,f4,0.5*f5),1) 181 | 182 | red_feat = self.DimRed(cat_feat) 183 | RoI_feat = self.RoIAlign(red_feat, boxes) 184 | RoD_feat = self.RoDAlign(red_feat, boxes) 185 | 186 | final_feat = torch.cat((RoI_feat, RoD_feat), 1) 187 | prediction = self.FC_layers(final_feat) 188 | return prediction 189 | 190 | def _init_weights(self): 191 | print('Initializing weights...') 192 | self.DimRed.apply(weights_init) 193 | 
self.FC_layers.apply(weights_init) 194 | 195 | def cropping_rank_loss(pre_score, gt_score): 196 | ''' 197 | :param pre_score: 198 | :param gt_score: 199 | :return: 200 | ''' 201 | if pre_score.dim() > 1: 202 | pre_score = pre_score.reshape(-1) 203 | if gt_score.dim() > 1: 204 | gt_score = gt_score.reshape(-1) 205 | assert pre_score.shape == gt_score.shape, '{} vs. {}'.format(pre_score.shape, gt_score.shape) 206 | N = pre_score.shape[0] 207 | pair_num = N * (N-1) / 2 208 | pre_diff = pre_score[:,None] - pre_score[None,:] 209 | gt_diff = gt_score[:,None] - gt_score[None,:] 210 | indicat = -1 * torch.sign(gt_diff) * (pre_diff - gt_diff) 211 | diff = torch.maximum(indicat, torch.zeros_like(indicat)) 212 | rank_loss= torch.sum(diff) / pair_num 213 | return rank_loss 214 | 215 | def xavier(param): 216 | init.xavier_uniform_(param) 217 | 218 | def weights_init(m): 219 | if isinstance(m, nn.Conv2d): 220 | xavier(m.weight.data) 221 | m.bias.data.zero_() 222 | 223 | 224 | def build_crop_model(scale='single', alignsize=9, reddim=32, loadweight=True, model=None): 225 | if scale=='single': 226 | return crop_model_single_scale(alignsize, reddim, loadweight, model) 227 | elif scale=='multi': 228 | return crop_model_multi_scale_shared(alignsize, reddim, loadweight, model) 229 | 230 | 231 | if __name__ == '__main__': 232 | net = build_crop_model(scale='multi', alignsize=9, 233 | reddim=32, loadweight=True, 234 | model='vgg16') 235 | net = net.eval().cuda() 236 | roi = torch.tensor([[0, 0, 128, 128], [64, 64, 223, 223]]).float() 237 | roi = roi.unsqueeze(0).cuda() 238 | roi = roi.repeat(2,1,1) 239 | img = torch.randn((2, 3, 224, 224)).cuda() 240 | out = net(img, roi) 241 | print(out.shape, out) 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/networks/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.2.2 2 | numpy==1.19.1 3 | opencv_python==4.5.4.60 4 | Pillow==9.0.1 5 | PyYAML==6.0 6 | scipy==1.5.2 7 | setuptools==52.0.0.post20210125 8 | tensorboardX==2.4.1 9 | thop==0.0.31.post2005241907 10 | torch==1.9.0+cu111 11 | torchvision==0.10.0+cu111 12 | tqdm==4.51.0 13 | -------------------------------------------------------------------------------- /result_images/211958.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/result_images/211958.jpg -------------------------------------------------------------------------------- /result_images/265813.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/result_images/265813.jpg -------------------------------------------------------------------------------- /result_images/297406.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/result_images/297406.jpg -------------------------------------------------------------------------------- /test_images/211958.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/test_images/211958.jpg -------------------------------------------------------------------------------- /test_images/265813.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/test_images/265813.jpg -------------------------------------------------------------------------------- /test_images/297406.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/test_images/297406.jpg -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | from tensorboardX import SummaryWriter 5 | import torch 6 | import time 7 | import datetime 8 | import csv 9 | import random 10 | import shutil 11 | from networks.GAIC_model import build_crop_model 12 | from dataset.cropping_dataset import GAICDataset 13 | from config.GAIC_config import cfg, refresh_yaml_params 14 | from evaluate.test import evaluate_on_GAICD_official as evaluate_on_GAICD 15 | import argparse 16 | import warnings 17 | warnings.filterwarnings('ignore') 18 | 19 | def create_dataloader(): 20 | dataset = GAICDataset(split='train') 21 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, 22 | shuffle=True, num_workers=cfg.num_workers, 23 | drop_last=False, pin_memory=False) 24 | print('training set has {} samples, {} batches'.format(len(dataset), len(dataloader))) 25 | return dataloader 26 | 27 | 28 | class Trainer: 29 | def __init__(self, model): 30 | self.model = model 31 | self.epoch = 0 32 | self.iters = 0 33 | self.max_epoch = cfg.max_epoch 34 | self.writer = SummaryWriter(log_dir=cfg.log_dir) 35 | self.optimizer, self.lr_scheduler = self.get_optimizer() 36 | self.train_loader = create_dataloader() 37 | self.eval_results = [] 38 | self.best_results = {'acc1_5':0., 'acc2_5':0., 'acc3_5':0., 'acc4_5':0., 'acc5':0., 39 | 'acc1_10':0., 'acc2_10':0., 'acc3_10':0, 'acc4_10':0, 'acc10':0., 40 | 'srcc':0., 'pcc':0.} 41 | self.criterion = torch.nn.SmoothL1Loss(reduction='mean') 42 | self.contain_BN = False 43 | for name,m in self.model.Feat_ext.named_modules(): 44 | if isinstance(m, torch.nn.BatchNorm2d): 45 | self.contain_BN = True 46 | break 47 | 48 | def get_optimizer(self): 49 | optim = torch.optim.Adam( 50 | self.model.parameters(), 51 | lr=cfg.lr 52 | ) 53 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 54 | optim, milestones=cfg.lr_decay_epoch, gamma=cfg.lr_decay 55 | ) 56 | return optim, lr_scheduler 57 | 58 | def run(self): 59 | print(("======== Begin Training =========")) 60 | for epoch in range(self.max_epoch): 61 | self.epoch = epoch 62 | self.train() 63 | if epoch % cfg.eval_freq == 0: 64 | self.eval() 65 | self.record_eval_results() 66 | self.lr_scheduler.step() 67 | 68 | def train(self): 69 | self.model.train() 70 | if self.contain_BN: 71 | self.model.Feat_ext.eval() 72 | device = next(self.model.parameters()).device 73 | start = time.time() 74 | batch_loss = 0 75 | total_batch = len(self.train_loader) 76 | total_loss = 0 77 | for 
batch_idx, batch_data in enumerate(self.train_loader): 78 | self.iters += 1 79 | im = batch_data[0].to(device) 80 | rois = batch_data[1].to(device) 81 | scores = batch_data[2].to(device) 82 | width = batch_data[3].to(device) 83 | height = batch_data[4].to(device) 84 | 85 | random_ID = list(range(0, rois.shape[1])) 86 | random.shuffle(random_ID) 87 | chosen_ID = random_ID[:64] 88 | 89 | rois = rois[:,chosen_ID] 90 | scores = scores[:, chosen_ID] 91 | 92 | pred_scores = self.model(im, rois) 93 | loss = self.criterion(pred_scores.squeeze(), scores.squeeze()) 94 | batch_loss += loss.item() 95 | self.optimizer.zero_grad() 96 | loss.backward() 97 | self.optimizer.step() 98 | 99 | if batch_idx % cfg.display_freq == 0: 100 | avg_loss = batch_loss / (1 + batch_idx) 101 | cur_lr = self.optimizer.param_groups[0]['lr'] 102 | self.writer.add_scalar('train/loss', avg_loss, self.iters) 103 | self.writer.add_scalar('train/lr', cur_lr, self.iters) 104 | 105 | time_per_batch = (time.time() - start) / (batch_idx + 1.) 106 | last_batches = (self.max_epoch - self.epoch - 1) * total_batch + (total_batch - batch_idx - 1) 107 | last_time = int(last_batches * time_per_batch) 108 | time_str = str(datetime.timedelta(seconds=last_time)) 109 | 110 | print('=== epoch:{}/{}, step:{}/{} | Loss:{:.4f} | lr:{:.6f} | estimated remaining time:{} ==='.format( 111 | self.epoch, self.max_epoch, batch_idx, total_batch, avg_loss, cur_lr, time_str 112 | )) 113 | 114 | def eval(self): 115 | self.model.eval() 116 | avg_srcc, avg_pcc, avg_acc5, avg_acc10, acc4_5, acc4_10 = evaluate_on_GAICD(self.model) 117 | self.eval_results.append([self.epoch, avg_srcc, avg_pcc, avg_acc5, avg_acc10, 118 | acc4_5[0], acc4_5[1], acc4_5[2], acc4_5[3], 119 | acc4_10[0], acc4_10[1], acc4_10[2], acc4_10[3]]) 120 | epoch_result = {'srcc': avg_srcc, 'pcc': avg_pcc, 'acc5': avg_acc5, 'acc10': avg_acc10, 121 | 'acc1_5': acc4_5[0], 'acc2_5': acc4_5[1], 'acc3_5': acc4_5[2], 'acc4_5': acc4_5[3], 122 | 'acc1_10': acc4_10[0], 'acc2_10': acc4_10[1], 'acc3_10': acc4_10[2], 'acc4_10': acc4_10[3]} 123 | for m in epoch_result.keys(): 124 | update = False 125 | if (epoch_result[m] > self.best_results[m]): 126 | update = True 127 | if update: 128 | self.best_results[m] = epoch_result[m] 129 | checkpoint_path = os.path.join(cfg.checkpoint_dir, 'best-{}.pth'.format(m)) 130 | torch.save(self.model.state_dict(), checkpoint_path) 131 | print('Update best {} model, best {}={:.4f}'.format(m, m, self.best_results[m])) 132 | if m in ['srcc', 'acc5']: 133 | self.writer.add_scalar('test/{}'.format(m), epoch_result[m], self.epoch) 134 | self.writer.add_scalar('test/best-{}'.format(m), self.best_results[m], self.epoch) 135 | if self.epoch % cfg.save_freq == 0: 136 | checkpoint_path = os.path.join(cfg.checkpoint_dir, 'epoch-{}.pth'.format(self.epoch)) 137 | torch.save(self.model.state_dict(), checkpoint_path) 138 | 139 | def record_eval_results(self): 140 | csv_path = os.path.join(cfg.exp_path, '..', '{}.csv'.format(cfg.exp_name)) 141 | header = ['epoch', 'srcc', 'pcc', 'acc5', 'acc10', 142 | 'acc1_5', 'acc2_5', 'acc3_5', 'acc4_5', 143 | 'acc1_10', 'acc2_10', 'acc3_10', 'acc4_10'] 144 | # Limit the number of decimal places in the result 145 | limit_results = [] 146 | for epoch, result in enumerate(self.eval_results): 147 | limit_results.append([]) 148 | for i,r in enumerate(result): 149 | if i == 0: # epoch 150 | limit_results[epoch].append(r) 151 | else: 152 | limit_results[epoch].append(round(r, 3)) 153 | # find the best results 154 | rows = [header] + limit_results 155 | metrics = 
[[] for i in header] 156 | for result in limit_results: 157 | for i, r in enumerate(result): 158 | metrics[i].append(r) 159 | for name, m in zip(header, metrics): 160 | if name == 'epoch': 161 | continue 162 | index = m.index(max(m)) 163 | title = 'best {}(epoch-{})'.format(name, index) 164 | row = [l[index] for l in metrics] 165 | row[0] = title 166 | rows.append(row) 167 | with open(csv_path, 'w') as f: 168 | cw = csv.writer(f) 169 | cw.writerows(rows) 170 | print('Save result to ', csv_path) 171 | 172 | def parse_args(): 173 | parser = argparse.ArgumentParser(description="Train a GAIC model") 174 | parser.add_argument('--gpu', type=int, dest='gpu_id', 175 | help='gpu_id', default=0) 176 | parser.add_argument('--backbone', type=str, choices=['vgg16', 'mobilenetv2', 'resnet50', 'shufflenetv2'], 177 | help='the architecture of backbone network', default='vgg16') 178 | parser.add_argument('--reddim', type=int, choices=[64, 32, 16, 8], 179 | help='the reduced channel dimension of the feature map', default=32) 180 | parser.add_argument('--alignsize', type=int, choices=[3, 5, 9], 181 | help='RoIAlign and RoDAlign output size', default=9) 182 | parser.add_argument('--num_workers', type=int, help='number of dataloader workers', default=8) 183 | args = parser.parse_args() 184 | refresh_yaml_params(args) 185 | 186 | 187 | if __name__ == '__main__': 188 | parse_args() 189 | cfg.refresh_params() 190 | cfg.create_path() 191 | device = torch.device('cuda:{}'.format(cfg.gpu_id)) 192 | torch.cuda.set_device(device) 193 | for file in ['config/GAIC_config.py', 'config/GAIC_params.yaml', 'dataset/cropping_dataset.py', 194 | 'evaluate/test.py', 'networks/GAIC_model.py', 'train/train.py']: 195 | if not os.path.exists(file): 196 | file = os.path.join('..', file) 197 | shutil.copy(file, cfg.code_dir) 198 | print('backup', file) 199 | net = build_crop_model(scale='multi', alignsize=cfg.alignsize, reddim=cfg.reddim, 200 | loadweight=True, model=cfg.backbone) 201 | net = net.to(device) 202 | trainer = Trainer(net) 203 | trainer.run() 204 | -------------------------------------------------------------------------------- /untils/make_all.sh: -------------------------------------------------------------------------------- 1 | cd ./roi_align 2 | bash make.sh 3 | 4 | cd ../rod_align 5 | bash make.sh 6 | 7 | cd .. 
8 | -------------------------------------------------------------------------------- /untils/rod_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/rod_align/__init__.py -------------------------------------------------------------------------------- /untils/rod_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/rod_align/functions/__init__.py -------------------------------------------------------------------------------- /untils/rod_align/functions/rod_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import rod_align_api 4 | 5 | class RoDAlignFunction(Function): 6 | @staticmethod 7 | def forward(ctx, features, rois, aligned_width, aligned_height, spatial_scale): 8 | batch_size, num_channels, data_height, data_width = features.size() 9 | ctx.save_for_backward(rois, 10 | torch.IntTensor([int(batch_size), 11 | int(num_channels), 12 | int(data_height), 13 | int(data_width), 14 | int(aligned_width), 15 | int(aligned_height)]), 16 | torch.FloatTensor([float(spatial_scale)])) 17 | 18 | num_rois = rois.size(0) 19 | 20 | output = features.new(num_rois, 21 | num_channels, 22 | int(aligned_height), 23 | int(aligned_width)).zero_() 24 | 25 | rod_align_api.forward(int(aligned_height), 26 | int(aligned_width), 27 | float(spatial_scale), 28 | features, 29 | rois, output) 30 | 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | rois, core_size, scale = ctx.saved_tensors 36 | spatial_scale = scale[0] 37 | 38 | batch_size, num_channels, data_height, data_width, aligned_width, aligned_height = core_size 39 | 40 | grad_input = rois.new(batch_size, 41 | num_channels, 42 | data_height, 43 | data_width).zero_() 44 | 45 | rod_align_api.backward(int(aligned_height), 46 | int(aligned_width), 47 | float(spatial_scale), 48 | grad_output, 49 | rois, 50 | grad_input) 51 | 52 | return grad_input, None, None, None, None 53 | -------------------------------------------------------------------------------- /untils/rod_align/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling rod_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_75 is compatible with the following NV GPU cards, 7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070 Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000 Tesla T4 8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_86 10 | 11 | cd ../ 12 | # Export CUDA_HOME. And build and install the library. 13 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install 14 | 15 | -------------------------------------------------------------------------------- /untils/rod_align/make_python2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling rod_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 
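# For newer Ampere cards (e.g. the RTX 30 series), use -arch=sm_86 as in make.sh.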
6 | # -arch=sm_75 is compatible with the following NV GPU cards, 7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070 Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000 Tesla T4 8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75 10 | 11 | cd ../ 12 | # Export CUDA_HOME. And build and install the library. 13 | export CUDA_HOME=/usr/local/cuda-10.0 && python setup.py install 14 | 15 | -------------------------------------------------------------------------------- /untils/rod_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/rod_align/modules/__init__.py -------------------------------------------------------------------------------- /untils/rod_align/modules/rod_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.nn.functional import avg_pool2d, max_pool2d 3 | from ..functions.rod_align import RoDAlignFunction 4 | 5 | 6 | class RoDAlign(Module): 7 | def __init__(self, aligned_height, aligned_width, spatial_scale): 8 | super(RoDAlign, self).__init__() 9 | 10 | self.aligned_width = int(aligned_width) 11 | self.aligned_height = int(aligned_height) 12 | self.spatial_scale = float(spatial_scale) 13 | 14 | def forward(self, features, rois): 15 | return RoDAlignFunction.apply(features, 16 | rois, 17 | self.aligned_height, 18 | self.aligned_width, 19 | self.spatial_scale) 20 | 21 | class RoDAlignAvg(Module): 22 | def __init__(self, aligned_height, aligned_width, spatial_scale): 23 | super(RoDAlignAvg, self).__init__() 24 | 25 | self.aligned_width = int(aligned_width) 26 | self.aligned_height = int(aligned_height) 27 | self.spatial_scale = float(spatial_scale) 28 | 29 | def forward(self, features, rois): 30 | x = RoDAlignFunction.apply(features, 31 | rois, 32 | self.aligned_height+1, 33 | self.aligned_width+1, 34 | self.spatial_scale) 35 | return avg_pool2d(x, kernel_size=2, stride=1) 36 | 37 | class RoDAlignMax(Module): 38 | def __init__(self, aligned_height, aligned_width, spatial_scale): 39 | super(RoDAlignMax, self).__init__() 40 | 41 | self.aligned_width = int(aligned_width) 42 | self.aligned_height = int(aligned_height) 43 | self.spatial_scale = float(spatial_scale) 44 | 45 | def forward(self, features, rois): 46 | x = RoDAlignFunction.apply(features, 47 | rois, 48 | self.aligned_height+1, 49 | self.aligned_width+1, 50 | self.spatial_scale) 51 | return max_pool2d(x, kernel_size=2, stride=1) 52 | -------------------------------------------------------------------------------- /untils/rod_align/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from pkg_resources import parse_version 5 | 6 | min_version = parse_version('1.0.0') 7 | current_version = parse_version(torch.__version__) 8 | 9 | 10 | if current_version < min_version: #PyTorch before 1.0 11 | from torch.utils.ffi import create_extension 12 | 13 | sources = ['src/roi_align.c'] 14 | headers = ['src/roi_align.h'] 15 | extra_objects = [] 16 | 17 | defines = [] 18 | with_cuda = False 19 | 20 | this_file = os.path.dirname(os.path.realpath(__file__)) 21 | print(this_file) 22 | 23 | if torch.cuda.is_available(): 24 | print('Including CUDA 
25 | sources += ['src/rod_align_cuda.c'] 26 | headers += ['src/rod_align_cuda.h'] 27 | defines += [('WITH_CUDA', None)] 28 | with_cuda = True 29 | 30 | extra_objects = ['src/rod_align_kernel.cu.o'] 31 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 32 | 33 | ffi = create_extension( 34 | '_ext.rod_align', 35 | headers=headers, 36 | sources=sources, 37 | define_macros=defines, 38 | relative_to=__file__, 39 | with_cuda=with_cuda, 40 | extra_objects=extra_objects 41 | ) 42 | 43 | if __name__ == '__main__': 44 | ffi.build() 45 | else: # PyTorch 1.0 or later 46 | from setuptools import setup 47 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 48 | 49 | print('Including CUDA code.') 50 | 51 | current_dir = os.path.dirname(os.path.realpath(__file__)) 52 | 53 | setup( 54 | name='rod_align_api', 55 | ext_modules=[ 56 | CUDAExtension( 57 | name='rod_align_api', 58 | sources=['src/rod_align_cuda.cpp', 'src/rod_align_kernel.cu'], 59 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True) 60 | ) 61 | ], 62 | cmdclass={ 63 | 'build_ext': BuildExtension 64 | }) 65 | 66 | -------------------------------------------------------------------------------- /untils/rod_align/src/rod_align.cpp: -------------------------------------------------------------------------------- 1 | #include "rod_align.h" 2 | 3 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 4 | const int height, const int width, const int channels, 5 | const int aligned_height, const int aligned_width, const float * bottom_rois, 6 | float* top_data); 7 | 8 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 9 | const int height, const int width, const int channels, 10 | const int aligned_height, const int aligned_width, const float * bottom_rois, 11 | float* top_data); 12 | 13 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale, 14 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 15 | { 16 | //Grab the input tensor 17 | //float * data_flat = THFloatTensor_data(features); 18 | //float * rois_flat = THFloatTensor_data(rois); 19 | auto data_flat = features.data<float>(); 20 | auto rois_flat = rois.data<float>(); 21 | 22 | //float * output_flat = THFloatTensor_data(output); 23 | auto output_flat = output.data<float>(); 24 | 25 | // Number of ROIs 26 | //int num_rois = THFloatTensor_size(rois, 0); 27 | //int size_rois = THFloatTensor_size(rois, 1); 28 | auto rois_sz = rois.sizes(); 29 | int num_rois = rois_sz[0]; 30 | int size_rois = rois_sz[1]; 31 | 32 | if (size_rois != 5) 33 | { 34 | return 0; 35 | } 36 | 37 | // data height 38 | //int data_height = THFloatTensor_size(features, 2); 39 | // data width 40 | //int data_width = THFloatTensor_size(features, 3); 41 | // Number of channels 42 | //int num_channels = THFloatTensor_size(features, 1); 43 | auto feat_sz = features.sizes(); 44 | int data_height = feat_sz[2]; 45 | int data_width = feat_sz[3]; 46 | int num_channels = feat_sz[1]; 47 | 48 | // do ROIAlignForward 49 | RODAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels, 50 | aligned_height, aligned_width, rois_flat, output_flat); 51 | 52 | return 1; 53 | } 54 | 55 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale, 56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 57 | { 58 | //Grab the input tensor 59 | //float * top_grad_flat = THFloatTensor_data(top_grad); 60 | //float * rois_flat = THFloatTensor_data(rois); 61 | 62 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad); 63 | 64 | auto top_grad_flat = top_grad.data<float>(); 65 | auto rois_flat = rois.data<float>(); 66 | auto bottom_grad_flat = bottom_grad.data<float>(); 67 | 68 | 69 | // Number of ROIs 70 | //int num_rois = THFloatTensor_size(rois, 0); 71 | //int size_rois = THFloatTensor_size(rois, 1); 72 | 73 | auto rois_sz = rois.sizes(); 74 | int num_rois = rois_sz[0]; 75 | int size_rois = rois_sz[1]; 76 | 77 | if (size_rois != 5) 78 | { 79 | return 0; 80 | } 81 | 82 | // batch size 83 | // int batch_size = THFloatTensor_size(bottom_grad, 0); 84 | // data height 85 | //int data_height = THFloatTensor_size(bottom_grad, 2); 86 | // data width 87 | //int data_width = THFloatTensor_size(bottom_grad, 3); 88 | // Number of channels 89 | //int num_channels = THFloatTensor_size(bottom_grad, 1); 90 | auto grad_sz = bottom_grad.sizes(); 91 | int data_height = grad_sz[2]; 92 | int data_width = grad_sz[3]; 93 | int num_channels = grad_sz[1]; 94 | 95 | // do ROIAlignBackward 96 | RODAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height, 97 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat); 98 | 99 | return 1; 100 | } 101 | 102 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 103 | const int height, const int width, const int channels, 104 | const int aligned_height, const int aligned_width, const float * bottom_rois, 105 | float* top_data) 106 | { 107 | const int output_size = num_rois * aligned_height * aligned_width * channels; 108 | 109 | int idx = 0; 110 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 111 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 112 | for (idx = 0; idx < output_size; ++idx) 113 | { 114 | // (n, c, ph, pw) is an element in the aligned output 115 | int pw = idx % aligned_width; 116 | int ph = (idx / aligned_width) % aligned_height; 117 | int c = (idx / aligned_width / aligned_height) % channels; 118 | int n = idx / aligned_width / aligned_height / channels; 119 | 120 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 121 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 122 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 123 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 124 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 125 | 126 | 127 | float h = (float)(ph) * bin_size_h; 128 | float w = (float)(pw) * bin_size_w; 129 | 130 | int hstart = fminf(floor(h), height - 2); 131 | int wstart = fminf(floor(w), width - 2); 132 | 133 | int img_start = roi_batch_ind * channels * height * width; 134 | 135 | // bilinear interpolation 136 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){ 137 | top_data[idx] = 0.; 138 | } else { 139 | float h_ratio = h - (float)(hstart); 140 | float w_ratio = w - (float)(wstart); 141 | int upleft = img_start + (c * height + hstart) * width + wstart; 142 | int upright = upleft + 1; 143 | int downleft = upleft + width; 144 | int downright = downleft + 1; 145 | 146 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 147 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 148 | + bottom_data[downleft] * h_ratio * (1.
- w_ratio) 149 | + bottom_data[downright] * h_ratio * w_ratio; 150 | } 151 | } 152 | } 153 | 154 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 155 | const int height, const int width, const int channels, 156 | const int aligned_height, const int aligned_width, const float * bottom_rois, 157 | float* bottom_diff) 158 | { 159 | const int output_size = num_rois * aligned_height * aligned_width * channels; 160 | 161 | int idx = 0; 162 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 163 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 164 | for (idx = 0; idx < output_size; ++idx) 165 | { 166 | // (n, c, ph, pw) is an element in the aligned output 167 | int pw = idx % aligned_width; 168 | int ph = (idx / aligned_width) % aligned_height; 169 | int c = (idx / aligned_width / aligned_height) % channels; 170 | int n = idx / aligned_width / aligned_height / channels; 171 | 172 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 173 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 174 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 175 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 176 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 177 | 178 | float h = (float)(ph) * bin_size_h; 179 | float w = (float)(pw) * bin_size_w; 180 | 181 | int hstart = fminf(floor(h), height - 2); 182 | int wstart = fminf(floor(w), width - 2); 183 | 184 | int img_start = roi_batch_ind * channels * height * width; 185 | 186 | // bilinear interpolation 187 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) { 188 | float h_ratio = h - (float)(hstart); 189 | float w_ratio = w - (float)(wstart); 190 | int upleft = img_start + (c * height + hstart) * width + wstart; 191 | int upright = upleft + 1; 192 | int downleft = upleft + width; 193 | int downright = downleft + 1; 194 | 195 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio); 196 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio; 197 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. 
- w_ratio); 198 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio; 199 | } 200 | } 201 | } 202 | 203 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 204 | m.def("forward", &rod_align_forward, "rod_align forward"); 205 | m.def("backward", &rod_align_backward, "rod_align backward"); 206 | } 207 | -------------------------------------------------------------------------------- /untils/rod_align/src/rod_align.h: -------------------------------------------------------------------------------- 1 | #ifndef ROD_ALIGN_H 2 | #define ROD_ALIGN_H 3 | 4 | #include <torch/extension.h> 5 | 6 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /untils/rod_align/src/rod_align_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/cuda/CUDAContext.h> 3 | #include "rod_align_kernel.h" 4 | #include "rod_align_cuda.h" 5 | 6 | 7 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 9 | { 10 | // Grab the input tensor 11 | //float * data_flat = THCudaTensor_data(state, features); 12 | //float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | //float * output_flat = THCudaTensor_data(state, output); 15 | 16 | auto data_flat = features.data<float>(); 17 | auto rois_flat = rois.data<float>(); 18 | auto output_flat = output.data<float>(); 19 | 20 | // Number of ROIs 21 | //int num_rois = THCudaTensor_size(state, rois, 0); 22 | //int size_rois = THCudaTensor_size(state, rois, 1); 23 | 24 | auto rois_sz = rois.sizes(); 25 | int num_rois = rois_sz[0]; 26 | int size_rois = rois_sz[1]; 27 | 28 | if (size_rois != 5) 29 | { 30 | return 0; 31 | } 32 | 33 | // data height 34 | //int data_height = THCudaTensor_size(state, features, 2); 35 | // data width 36 | //int data_width = THCudaTensor_size(state, features, 3); 37 | // Number of channels 38 | //int num_channels = THCudaTensor_size(state, features, 1); 39 | auto feat_sz = features.sizes(); 40 | int data_height = feat_sz[2]; 41 | int data_width = feat_sz[3]; 42 | int num_channels = feat_sz[1]; 43 | 44 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 45 | 46 | RODAlignForwardLaucher( 47 | data_flat, spatial_scale, num_rois, data_height, 48 | data_width, num_channels, aligned_height, 49 | aligned_width, rois_flat, 50 | output_flat, stream); 51 | 52 | return 1; 53 | } 54 | 55 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 57 | { 58 | // Grab the input tensor 59 | //float * top_grad_flat = THCudaTensor_data(state, top_grad); 60 | //float * rois_flat = THCudaTensor_data(state, rois); 61 | 62 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 63 | auto top_grad_flat = top_grad.data<float>(); 64 | auto rois_flat = rois.data<float>(); 65 | auto bottom_grad_flat = bottom_grad.data<float>(); 66 | 67 | // Number of ROIs 68 | //int num_rois = THCudaTensor_size(state, rois, 0); 69 | //int size_rois = THCudaTensor_size(state, rois, 1); 70 | auto rois_sz = rois.sizes(); 71 | int num_rois = rois_sz[0]; 72 | int size_rois = rois_sz[1]; 73 | if (size_rois != 5) 74 | { 75 | return 0;
76 | } 77 | 78 | // batch size 79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0); 80 | // data height 81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2); 82 | // data width 83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3); 84 | // Number of channels 85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1); 86 | 87 | auto grad_sz = bottom_grad.sizes(); 88 | int batch_size = grad_sz[0]; 89 | int data_height = grad_sz[2]; 90 | int data_width = grad_sz[3]; 91 | int num_channels = grad_sz[1]; 92 | 93 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 94 | RODAlignBackwardLaucher( 95 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 96 | data_width, num_channels, aligned_height, 97 | aligned_width, rois_flat, 98 | bottom_grad_flat, stream); 99 | 100 | return 1; 101 | } 102 | 103 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 104 | m.def("forward", &rod_align_forward_cuda, "rod_align forward"); 105 | m.def("backward", &rod_align_backward_cuda, "rod_align backward"); 106 | } 107 | -------------------------------------------------------------------------------- /untils/rod_align/src/rod_align_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef ROD_ALIGN_CUDA_H 2 | #define ROD_ALIGN_CUDA_H 3 | 4 | #include <torch/extension.h> 5 | 6 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /untils/rod_align/src/rod_align_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include "rod_align_kernel.h" 3 | 4 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 5 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 6 | i += blockDim.x * gridDim.x) 7 | 8 | 9 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width, 10 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) { 11 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 12 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 13 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 14 | // (n, c, ph, pw) is an element in the aligned output 15 | // int n = index; 16 | // int pw = n % aligned_width; 17 | // n /= aligned_width; 18 | // int ph = n % aligned_height; 19 | // n /= aligned_height; 20 | // int c = n % channels; 21 | // n /= channels; 22 | 23 | int pw = index % aligned_width; 24 | int ph = (index / aligned_width) % aligned_height; 25 | int c = (index / aligned_width / aligned_height) % channels; 26 | int n = index / aligned_width / aligned_height / channels; 27 | 28 | // bottom_rois += n * 5; 29 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 34 | 35 | 36 | float h = (float)(ph) * bin_size_h; 37 | float w = (float)(pw) * bin_size_w; 38 | 39 | int hstart = fminf(floor(h), height - 2); 40 | int
wstart = fminf(floor(w), width - 2); 41 | 42 | int img_start = roi_batch_ind * channels * height * width; 43 | 44 | // bilinear interpolation 45 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){ 46 | top_data[index] = 0.; 47 | } else { 48 | float h_ratio = h - (float)(hstart); 49 | float w_ratio = w - (float)(wstart); 50 | int upleft = img_start + (c * height + hstart) * width + wstart; 51 | int upright = upleft + 1; 52 | int downleft = upleft + width; 53 | int downright = downleft + 1; 54 | 55 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 56 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 57 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 58 | + bottom_data[downright] * h_ratio * w_ratio; 59 | } 60 | } 61 | } 62 | 63 | 64 | int RODAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width, 65 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) { 66 | const int kThreadsPerBlock = 1024; 67 | const int output_size = num_rois * aligned_height * aligned_width * channels; 68 | cudaError_t err; 69 | 70 | 71 | RODAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 72 | output_size, bottom_data, spatial_scale, height, width, channels, 73 | aligned_height, aligned_width, bottom_rois, top_data); 74 | 75 | err = cudaGetLastError(); 76 | if(cudaSuccess != err) { 77 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 78 | exit( -1 ); 79 | } 80 | 81 | return 1; 82 | } 83 | 84 | 85 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width, 86 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) { 87 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.); 88 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.); 89 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 90 | 91 | // (n, c, ph, pw) is an element in the aligned output 92 | int pw = index % aligned_width; 93 | int ph = (index / aligned_width) % aligned_height; 94 | int c = (index / aligned_width / aligned_height) % channels; 95 | int n = index / aligned_width / aligned_height / channels; 96 | 97 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 98 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 99 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 100 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 101 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 102 | 103 | 104 | float h = (float)(ph) * bin_size_h; 105 | float w = (float)(pw) * bin_size_w; 106 | 107 | int hstart = fminf(floor(h), height - 2); 108 | int wstart = fminf(floor(w), width - 2); 109 | 110 | int img_start = roi_batch_ind * channels * height * width; 111 | 112 | // bilinear interpolation 113 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) { 114 | float h_ratio = h - (float)(hstart); 115 | float w_ratio = w - (float)(wstart); 116 | int upleft = img_start + (c * height + hstart) * width + wstart; 117 | int upright = upleft + 1; 118 | int downleft = upleft + width; 119 | int downright = downleft + 1; 120 | 121 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. 
- h_ratio) * (1. - w_ratio)); 122 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. - h_ratio) * w_ratio); 123 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1. - w_ratio)); 124 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio); 125 | } 126 | } 127 | } 128 | 129 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width, 130 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) { 131 | const int kThreadsPerBlock = 1024; 132 | const int output_size = num_rois * aligned_height * aligned_width * channels; 133 | cudaError_t err; 134 | 135 | RODAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 136 | output_size, top_diff, spatial_scale, height, width, channels, 137 | aligned_height, aligned_width, bottom_diff, bottom_rois); 138 | 139 | err = cudaGetLastError(); 140 | if(cudaSuccess != err) { 141 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 142 | exit( -1 ); 143 | } 144 | 145 | return 1; 146 | } 147 | -------------------------------------------------------------------------------- /untils/rod_align/src/rod_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROD_ALIGN_KERNEL 2 | #define _ROD_ALIGN_KERNEL 3 | 4 | #include <cuda_runtime.h> 5 | 6 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, 7 | const float spatial_scale, const int height, const int width, 8 | const int channels, const int aligned_height, const int aligned_width, 9 | const float* bottom_rois, float* top_data); 10 | 11 | int RODAlignForwardLaucher( 12 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 13 | const int width, const int channels, const int aligned_height, 14 | const int aligned_width, const float* bottom_rois, 15 | float* top_data, cudaStream_t stream); 16 | 17 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff, 18 | const float spatial_scale, const int height, const int width, 19 | const int channels, const int aligned_height, const int aligned_width, 20 | float* bottom_diff, const float* bottom_rois); 21 | 22 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 23 | const int height, const int width, const int channels, const int aligned_height, 24 | const int aligned_width, const float* bottom_rois, 25 | float* bottom_diff, cudaStream_t stream); 26 | 27 | #endif 28 | 29 | -------------------------------------------------------------------------------- /untils/roi_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/roi_align/__init__.py -------------------------------------------------------------------------------- /untils/roi_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/roi_align/functions/__init__.py -------------------------------------------------------------------------------- /untils/roi_align/functions/roi_align.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import roi_align_api 4 | 5 | class RoIAlignFunction(Function): 6 | @staticmethod 7 | def forward(ctx, features, rois, aligned_height, aligned_width, spatial_scale): 8 | batch_size, num_channels, data_height, data_width = features.size() 9 | ctx.save_for_backward(rois, 10 | torch.IntTensor([int(batch_size), 11 | int(num_channels), 12 | int(data_height), 13 | int(data_width), 14 | int(aligned_height), 15 | int(aligned_width)]), 16 | torch.FloatTensor([float(spatial_scale)])) 17 | 18 | num_rois = rois.size(0) 19 | 20 | output = features.new(num_rois, 21 | num_channels, 22 | int(aligned_height), 23 | int(aligned_width)).zero_() 24 | 25 | roi_align_api.forward(int(aligned_height), 26 | int(aligned_width), 27 | float(spatial_scale), 28 | features, 29 | rois, 30 | output) 31 | return output 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | rois, core_size, scale = ctx.saved_tensors 36 | 37 | batch_size, num_channels, data_height, data_width, aligned_height, aligned_width = core_size 38 | spatial_scale = scale[0] 39 | 40 | grad_input = rois.new(batch_size, 41 | num_channels, 42 | data_height, 43 | data_width).zero_() 44 | 45 | roi_align_api.backward(int(aligned_height), 46 | int(aligned_width), 47 | float(spatial_scale), 48 | grad_output, 49 | rois, 50 | grad_input) 51 | 52 | return grad_input, None, None, None, None 53 | -------------------------------------------------------------------------------- /untils/roi_align/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling roi_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_86 (used below) targets Ampere cards such as the GeForce RTX 30 series. 7 | # Use -arch=sm_75 for Turing cards: GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, and Tesla T4. 8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_86 10 | 11 | cd ../ 12 | # Export CUDA_HOME, then build and install the library. 13 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install 14 | -------------------------------------------------------------------------------- /untils/roi_align/make_python2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd src 3 | echo "Compiling roi_align kernels by nvcc..." 4 | 5 | # Specify the architecture of your NV card below. 6 | # -arch=sm_75 is compatible with the following NV GPU cards, 7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, and Tesla T4. 8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile 9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75 10 | 11 | cd ../ 12 | # Export CUDA_HOME, then build and install the library.
13 | export CUDA_HOME=/usr/local/cuda-10.0 && python setup.py install 14 | -------------------------------------------------------------------------------- /untils/roi_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/roi_align/modules/__init__.py -------------------------------------------------------------------------------- /untils/roi_align/modules/roi_align.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.nn.functional import avg_pool2d, max_pool2d 3 | from ..functions.roi_align import RoIAlignFunction 4 | 5 | 6 | class RoIAlign(Module): 7 | def __init__(self, aligned_height, aligned_width, spatial_scale): 8 | super(RoIAlign, self).__init__() 9 | 10 | self.aligned_width = int(aligned_width) 11 | self.aligned_height = int(aligned_height) 12 | self.spatial_scale = float(spatial_scale) 13 | 14 | def forward(self, features, rois): 15 | return RoIAlignFunction.apply(features, 16 | rois, 17 | self.aligned_height, 18 | self.aligned_width, 19 | self.spatial_scale) 20 | 21 | class RoIAlignAvg(Module): 22 | def __init__(self, aligned_height, aligned_width, spatial_scale): 23 | super(RoIAlignAvg, self).__init__() 24 | 25 | self.aligned_width = int(aligned_width) 26 | self.aligned_height = int(aligned_height) 27 | self.spatial_scale = float(spatial_scale) 28 | 29 | def forward(self, features, rois): 30 | x = RoIAlignFunction.apply(features, 31 | rois, 32 | self.aligned_height+1, 33 | self.aligned_width+1, 34 | self.spatial_scale) 35 | return avg_pool2d(x, kernel_size=2, stride=1) 36 | 37 | class RoIAlignMax(Module): 38 | def __init__(self, aligned_height, aligned_width, spatial_scale): 39 | super(RoIAlignMax, self).__init__() 40 | 41 | self.aligned_width = int(aligned_width) 42 | self.aligned_height = int(aligned_height) 43 | self.spatial_scale = float(spatial_scale) 44 | 45 | def forward(self, features, rois): 46 | x = RoIAlignFunction.apply(features, 47 | rois, 48 | self.aligned_height+1, 49 | self.aligned_width+1, 50 | self.spatial_scale) 51 | return max_pool2d(x, kernel_size=2, stride=1) 52 | -------------------------------------------------------------------------------- /untils/roi_align/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from pkg_resources import parse_version 5 | 6 | min_version = parse_version('1.0.0') 7 | current_version = parse_version(torch.__version__) 8 | 9 | 10 | if current_version < min_version: #PyTorch before 1.0 11 | from torch.utils.ffi import create_extension 12 | 13 | sources = ['src/roi_align.c'] 14 | headers = ['src/roi_align.h'] 15 | extra_objects = [] 16 | #sources = [] 17 | #headers = [] 18 | defines = [] 19 | with_cuda = False 20 | 21 | this_file = os.path.dirname(os.path.realpath(__file__)) 22 | print(this_file) 23 | 24 | if torch.cuda.is_available(): 25 | print('Including CUDA code.') 26 | sources += ['src/roi_align_cuda.c'] 27 | headers += ['src/roi_align_cuda.h'] 28 | defines += [('WITH_CUDA', None)] 29 | with_cuda = True 30 | 31 | extra_objects = ['src/roi_align_kernel.cu.o'] 32 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 33 | 34 | ffi = create_extension( 35 | '_ext.roi_align', 36 | headers=headers, 37 | sources=sources, 38 | 
define_macros=defines, 39 | relative_to=__file__, 40 | with_cuda=with_cuda, 41 | extra_objects=extra_objects 42 | ) 43 | 44 | if __name__ == '__main__': 45 | ffi.build() 46 | else: # PyTorch 1.0 or later 47 | from setuptools import setup 48 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 49 | 50 | print('Including CUDA code.') 51 | current_dir = os.path.dirname(os.path.realpath(__file__)) 52 | #cuda_include = '/usr/local/cuda-10.0/include' 53 | 54 | #GPU version 55 | setup( 56 | name='roi_align_api', 57 | ext_modules=[ 58 | CUDAExtension( 59 | name='roi_align_api', 60 | sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu'], 61 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True) 62 | ) 63 | ], 64 | cmdclass={ 65 | 'build_ext': BuildExtension 66 | }) 67 | -------------------------------------------------------------------------------- /untils/roi_align/src/roi_align.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <math.h> 3 | #include <stdio.h> 4 | #include "roi_align.h" 5 | 6 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 7 | const int height, const int width, const int channels, 8 | const int aligned_height, const int aligned_width, const float * bottom_rois, 9 | float* top_data); 10 | 11 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 12 | const int height, const int width, const int channels, 13 | const int aligned_height, const int aligned_width, const float * bottom_rois, 14 | float* top_data); 15 | 16 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale, 17 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 18 | { 19 | //Grab the input tensor 20 | //float * data_flat = THFloatTensor_data(features); 21 | //float * rois_flat = THFloatTensor_data(rois); 22 | auto data_flat = features.data<float>(); 23 | auto rois_flat = rois.data<float>(); 24 | 25 | //float * output_flat = THFloatTensor_data(output); 26 | auto output_flat = output.data<float>(); 27 | 28 | // Number of ROIs 29 | //int num_rois = THFloatTensor_size(rois, 0); 30 | //int size_rois = THFloatTensor_size(rois, 1); 31 | 32 | auto rois_sz = rois.sizes(); 33 | int num_rois = rois_sz[0]; 34 | int size_rois = rois_sz[1]; 35 | 36 | if (size_rois != 5) 37 | { 38 | return 0; 39 | } 40 | 41 | // data height 42 | //int data_height = THFloatTensor_size(features, 2); 43 | // data width 44 | //int data_width = THFloatTensor_size(features, 3); 45 | // Number of channels 46 | //int num_channels = THFloatTensor_size(features, 1); 47 | auto feat_sz = features.sizes(); 48 | int data_height = feat_sz[2]; 49 | int data_width = feat_sz[3]; 50 | int num_channels = feat_sz[1]; 51 | 52 | // do ROIAlignForward 53 | ROIAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels, 54 | aligned_height, aligned_width, rois_flat, output_flat); 55 | 56 | return 1; 57 | } 58 | 59 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale, 60 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 61 | { 62 | //Grab the input tensor 63 | //float * top_grad_flat = THFloatTensor_data(top_grad); 64 | //float * rois_flat = THFloatTensor_data(rois); 65 | 66 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad); 67 | auto top_grad_flat = top_grad.data<float>(); 68 | auto rois_flat = rois.data<float>(); 69 | auto bottom_grad_flat = bottom_grad.data<float>(); 70 | 71 | // Number of
ROIs 72 | //int num_rois = THFloatTensor_size(rois, 0); 73 | //int size_rois = THFloatTensor_size(rois, 1); 74 | auto rois_sz = rois.sizes(); 75 | int num_rois = rois_sz[0]; 76 | int size_rois = rois_sz[1]; 77 | if (size_rois != 5) 78 | { 79 | return 0; 80 | } 81 | 82 | // batch size 83 | // int batch_size = THFloatTensor_size(bottom_grad, 0); 84 | // data height 85 | //int data_height = THFloatTensor_size(bottom_grad, 2); 86 | // data width 87 | //int data_width = THFloatTensor_size(bottom_grad, 3); 88 | // Number of channels 89 | //int num_channels = THFloatTensor_size(bottom_grad, 1); 90 | 91 | auto grad_sz = bottom_grad.sizes(); 92 | int data_height = grad_sz[2]; 93 | int data_width = grad_sz[3]; 94 | int num_channels = grad_sz[1]; 95 | 96 | // do ROIAlignBackward 97 | ROIAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height, 98 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat); 99 | 100 | return 1; 101 | } 102 | 103 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois, 104 | const int height, const int width, const int channels, 105 | const int aligned_height, const int aligned_width, const float * bottom_rois, 106 | float* top_data) 107 | { 108 | const int output_size = num_rois * aligned_height * aligned_width * channels; 109 | 110 | int idx = 0; 111 | for (idx = 0; idx < output_size; ++idx) 112 | { 113 | // (n, c, ph, pw) is an element in the aligned output 114 | int pw = idx % aligned_width; 115 | int ph = (idx / aligned_width) % aligned_height; 116 | int c = (idx / aligned_width / aligned_height) % channels; 117 | int n = idx / aligned_width / aligned_height / channels; 118 | 119 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 120 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 121 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 122 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 123 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 124 | 125 | // Force malformed ROI to be 1x1 126 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 127 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 128 | float bin_size_h = roi_height / (aligned_height - 1.); 129 | float bin_size_w = roi_width / (aligned_width - 1.); 130 | 131 | float h = (float)(ph) * bin_size_h + roi_start_h; 132 | float w = (float)(pw) * bin_size_w + roi_start_w; 133 | 134 | int hstart = fminf(floor(h), height - 2); 135 | int wstart = fminf(floor(w), width - 2); 136 | 137 | int img_start = roi_batch_ind * channels * height * width; 138 | 139 | // bilinear interpolation 140 | if (h < 0 || h >= height || w < 0 || w >= width) 141 | { 142 | top_data[idx] = 0.; 143 | } 144 | else 145 | { 146 | float h_ratio = h - (float)(hstart); 147 | float w_ratio = w - (float)(wstart); 148 | int upleft = img_start + (c * height + hstart) * width + wstart; 149 | int upright = upleft + 1; 150 | int downleft = upleft + width; 151 | int downright = downleft + 1; 152 | 153 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 154 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 155 | + bottom_data[downleft] * h_ratio * (1. 
- w_ratio) 156 | + bottom_data[downright] * h_ratio * w_ratio; 157 | } 158 | } 159 | } 160 | 161 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois, 162 | const int height, const int width, const int channels, 163 | const int aligned_height, const int aligned_width, const float * bottom_rois, 164 | float* bottom_diff) 165 | { 166 | const int output_size = num_rois * aligned_height * aligned_width * channels; 167 | 168 | int idx = 0; 169 | for (idx = 0; idx < output_size; ++idx) 170 | { 171 | // (n, c, ph, pw) is an element in the aligned output 172 | int pw = idx % aligned_width; 173 | int ph = (idx / aligned_width) % aligned_height; 174 | int c = (idx / aligned_width / aligned_height) % channels; 175 | int n = idx / aligned_width / aligned_height / channels; 176 | 177 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 178 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 179 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 180 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 181 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 182 | 183 | // Force malformed ROI to be 1x1 184 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 185 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 186 | float bin_size_h = roi_height / (aligned_height - 1.); 187 | float bin_size_w = roi_width / (aligned_width - 1.); 188 | 189 | float h = (float)(ph) * bin_size_h + roi_start_h; 190 | float w = (float)(pw) * bin_size_w + roi_start_w; 191 | 192 | int hstart = fminf(floor(h), height - 2); 193 | int wstart = fminf(floor(w), width - 2); 194 | 195 | int img_start = roi_batch_ind * channels * height * width; 196 | 197 | // bilinear interpolation: accumulate gradients only for in-bounds samples, matching the CUDA kernel 198 | if (!(h < 0 || h >= height || w < 0 || w >= width)) 199 | { 200 | float h_ratio = h - (float)(hstart); 201 | float w_ratio = w - (float)(wstart); 202 | int upleft = img_start + (c * height + hstart) * width + wstart; 203 | int upright = upleft + 1; 204 | int downleft = upleft + width; 205 | int downright = downleft + 1; 206 | 207 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio); 208 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio; 209 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. - w_ratio);
- w_ratio); 210 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio; 211 | } 212 | } 213 | } 214 | 215 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 216 | m.def("forward", &roi_align_forward, "roi_align forward"); 217 | m.def("backward", &roi_align_backward, "roi_align backward"); 218 | } 219 | -------------------------------------------------------------------------------- /untils/roi_align/src/roi_align.h: -------------------------------------------------------------------------------- 1 | #ifndef ROI_ALIGN_H 2 | #define ROI_ALIGN_H 3 | 4 | #include 5 | 6 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /untils/roi_align/src/roi_align_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "roi_align_kernel.h" 5 | 6 | 7 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output) 9 | { 10 | // Grab the input tensor 11 | //float * data_flat = THCudaTensor_data(state, features); 12 | //float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | //float * output_flat = THCudaTensor_data(state, output); 15 | 16 | auto data_flat = features.data(); 17 | auto rois_flat = rois.data(); 18 | auto output_flat = output.data(); 19 | 20 | // Number of ROIs 21 | //int num_rois = THCudaTensor_size(state, rois, 0); 22 | //int size_rois = THCudaTensor_size(state, rois, 1); 23 | auto rois_sz = rois.sizes(); 24 | int num_rois = rois_sz[0]; 25 | int size_rois = rois_sz[1]; 26 | if (size_rois != 5) 27 | { 28 | return 0; 29 | } 30 | 31 | // data height 32 | //int data_height = THCudaTensor_size(state, features, 2); 33 | // data width 34 | //int data_width = THCudaTensor_size(state, features, 3); 35 | // Number of channels 36 | //int num_channels = THCudaTensor_size(state, features, 1); 37 | auto feat_sz = features.sizes(); 38 | int data_height = feat_sz[2]; 39 | int data_width = feat_sz[3]; 40 | int num_channels = feat_sz[1]; 41 | 42 | 43 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 44 | 45 | ROIAlignForwardLaucher( 46 | data_flat, spatial_scale, num_rois, data_height, 47 | data_width, num_channels, aligned_height, 48 | aligned_width, rois_flat, 49 | output_flat, stream); 50 | 51 | return 1; 52 | } 53 | 54 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 55 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad) 56 | { 57 | // Grab the input tensor 58 | //float * top_grad_flat = THCudaTensor_data(state, top_grad); 59 | //float * rois_flat = THCudaTensor_data(state, rois); 60 | 61 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 62 | auto top_grad_flat = top_grad.data(); 63 | auto rois_flat = rois.data(); 64 | auto bottom_grad_flat = bottom_grad.data(); 65 | 66 | // Number of ROIs 67 | //int num_rois = THCudaTensor_size(state, rois, 0); 68 | //int size_rois = THCudaTensor_size(state, rois, 1); 69 | auto rois_sz = rois.sizes(); 70 | int num_rois = rois_sz[0]; 71 | int size_rois = rois_sz[1]; 72 | 73 | if (size_rois != 5) 74 | { 75 | return 0; 76 | } 77 | 78 | 
// batch size 79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0); 80 | // data height 81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2); 82 | // data width 83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3); 84 | // Number of channels 85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1); 86 | auto grad_sz = bottom_grad.sizes(); 87 | int batch_size = grad_sz[0]; 88 | int data_height = grad_sz[2]; 89 | int data_width = grad_sz[3]; 90 | int num_channels = grad_sz[1]; 91 | 92 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 93 | ROIAlignBackwardLaucher( 94 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 95 | data_width, num_channels, aligned_height, 96 | aligned_width, rois_flat, 97 | bottom_grad_flat, stream); 98 | 99 | return 1; 100 | } 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 103 | m.def("forward", &roi_align_forward_cuda, "roi_align forward"); 104 | m.def("backward", &roi_align_backward_cuda, "roi_align backward"); 105 | } 106 | -------------------------------------------------------------------------------- /untils/roi_align/src/roi_align_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef ROI_ALIGN_CUDA_H 2 | #define ROI_ALIGN_CUDA_H 3 | 4 | #include <torch/extension.h> 5 | 6 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale, 7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output); 8 | 9 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale, 10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /untils/roi_align/src/roi_align_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <math.h> 3 | #include <float.h> 4 | #include "roi_align_kernel.h" 5 | 6 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 7 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 8 | i += blockDim.x * gridDim.x) 9 | 10 | 11 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width, 12 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) { 13 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 14 | // (n, c, ph, pw) is an element in the aligned output 15 | // int n = index; 16 | // int pw = n % aligned_width; 17 | // n /= aligned_width; 18 | // int ph = n % aligned_height; 19 | // n /= aligned_height; 20 | // int c = n % channels; 21 | // n /= channels; 22 | 23 | int pw = index % aligned_width; 24 | int ph = (index / aligned_width) % aligned_height; 25 | int c = (index / aligned_width / aligned_height) % channels; 26 | int n = index / aligned_width / aligned_height / channels; 27 | 28 | // bottom_rois += n * 5; 29 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 34 | 35 | // Force malformed ROIs to be 1x1 36 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 37 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 38 | float bin_size_h = roi_height / (aligned_height - 1.); 39 | float bin_size_w = roi_width / (aligned_width - 1.); 40 | 41
| float h = (float)(ph) * bin_size_h + roi_start_h; 42 | float w = (float)(pw) * bin_size_w + roi_start_w; 43 | 44 | int hstart = fminf(floor(h), height - 2); 45 | int wstart = fminf(floor(w), width - 2); 46 | 47 | int img_start = roi_batch_ind * channels * height * width; 48 | 49 | // bilinear interpolation 50 | if (h < 0 || h >= height || w < 0 || w >= width) { 51 | top_data[index] = 0.; 52 | } else { 53 | float h_ratio = h - (float)(hstart); 54 | float w_ratio = w - (float)(wstart); 55 | int upleft = img_start + (c * height + hstart) * width + wstart; 56 | int upright = upleft + 1; 57 | int downleft = upleft + width; 58 | int downright = downleft + 1; 59 | 60 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio) 61 | + bottom_data[upright] * (1. - h_ratio) * w_ratio 62 | + bottom_data[downleft] * h_ratio * (1. - w_ratio) 63 | + bottom_data[downright] * h_ratio * w_ratio; 64 | } 65 | } 66 | } 67 | 68 | 69 | int ROIAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width, 70 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) { 71 | const int kThreadsPerBlock = 1024; 72 | const int output_size = num_rois * aligned_height * aligned_width * channels; 73 | cudaError_t err; 74 | 75 | 76 | ROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 77 | output_size, bottom_data, spatial_scale, height, width, channels, 78 | aligned_height, aligned_width, bottom_rois, top_data); 79 | 80 | err = cudaGetLastError(); 81 | if(cudaSuccess != err) { 82 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 83 | exit( -1 ); 84 | } 85 | 86 | return 1; 87 | } 88 | 89 | 90 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width, 91 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) { 92 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 93 | 94 | // (n, c, ph, pw) is an element in the aligned output 95 | int pw = index % aligned_width; 96 | int ph = (index / aligned_width) % aligned_height; 97 | int c = (index / aligned_width / aligned_height) % channels; 98 | int n = index / aligned_width / aligned_height / channels; 99 | 100 | float roi_batch_ind = bottom_rois[n * 5 + 0]; 101 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale; 102 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale; 103 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale; 104 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale; 105 | /* int roi_start_w = round(bottom_rois[1] * spatial_scale); */ 106 | /* int roi_start_h = round(bottom_rois[2] * spatial_scale); */ 107 | /* int roi_end_w = round(bottom_rois[3] * spatial_scale); */ 108 | /* int roi_end_h = round(bottom_rois[4] * spatial_scale); */ 109 | 110 | // Force malformed ROIs to be 1x1 111 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.); 112 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.); 113 | float bin_size_h = roi_height / (aligned_height - 1.); 114 | float bin_size_w = roi_width / (aligned_width - 1.); 115 | 116 | float h = (float)(ph) * bin_size_h + roi_start_h; 117 | float w = (float)(pw) * bin_size_w + roi_start_w; 118 | 119 | int hstart = fminf(floor(h), height - 2); 120 | int wstart = fminf(floor(w), width - 2); 121 | 122 | int 
img_start = roi_batch_ind * channels * height * width; 123 | 124 | // bilinear interpolation 125 | if (!(h < 0 || h >= height || w < 0 || w >= width)) { 126 | float h_ratio = h - (float)(hstart); 127 | float w_ratio = w - (float)(wstart); 128 | int upleft = img_start + (c * height + hstart) * width + wstart; 129 | int upright = upleft + 1; 130 | int downleft = upleft + width; 131 | int downright = downleft + 1; 132 | 133 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio)); 134 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. - h_ratio) * w_ratio); 135 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio)); 136 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio); 137 | } 138 | } 139 | } 140 | 141 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width, 142 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) { 143 | const int kThreadsPerBlock = 1024; 144 | const int output_size = num_rois * aligned_height * aligned_width * channels; 145 | cudaError_t err; 146 | 147 | ROIAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 148 | output_size, top_diff, spatial_scale, height, width, channels, 149 | aligned_height, aligned_width, bottom_diff, bottom_rois); 150 | 151 | err = cudaGetLastError(); 152 | if(cudaSuccess != err) { 153 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 154 | exit( -1 ); 155 | } 156 | 157 | return 1; 158 | } 159 | -------------------------------------------------------------------------------- /untils/roi_align/src/roi_align_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROI_ALIGN_KERNEL 2 | #define _ROI_ALIGN_KERNEL 3 | 4 | 5 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, 6 | const float spatial_scale, const int height, const int width, 7 | const int channels, const int aligned_height, const int aligned_width, 8 | const float* bottom_rois, float* top_data); 9 | 10 | int ROIAlignForwardLaucher( 11 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 12 | const int width, const int channels, const int aligned_height, 13 | const int aligned_width, const float* bottom_rois, 14 | float* top_data, cudaStream_t stream); 15 | 16 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff, 17 | const float spatial_scale, const int height, const int width, 18 | const int channels, const int aligned_height, const int aligned_width, 19 | float* bottom_diff, const float* bottom_rois); 20 | 21 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 22 | const int height, const int width, const int channels, const int aligned_height, 23 | const int aligned_width, const float* bottom_rois, 24 | float* bottom_diff, cudaStream_t stream); 25 | 26 | #endif 27 | 28 | --------------------------------------------------------------------------------
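For quick reference, here is a minimal usage sketch of the two alignment layers defined above. It is not part of the repository: the import paths simply follow the directory layout, the tensor sizes, RoIs, and `spatial_scale` are illustrative, and it assumes the extensions were built and installed via `make_all.sh` on a CUDA-capable machine (the PyTorch>=1.0 build path compiles only the CUDA kernels, so inputs must live on the GPU).

```Python
# Hypothetical smoke test; all sizes, coordinates, and the scale are arbitrary examples.
import torch
from untils.roi_align.modules.roi_align import RoIAlignAvg
from untils.rod_align.modules.rod_align import RoDAlignAvg

feat = torch.randn(2, 256, 32, 32).cuda()   # NCHW feature map
# Each RoI is (batch_index, x1, y1, x2, y2) in input-image coordinates;
# the C++ side rejects any other layout (size_rois != 5 returns 0).
rois = torch.tensor([[0., 16., 16., 112., 112.],
                     [1.,  0.,  0.,  64.,  96.]]).cuda()

# spatial_scale maps image coordinates onto the feature map,
# e.g. 1/4 for a 128x128 input downsampled to 32x32.
roi_align = RoIAlignAvg(8, 8, spatial_scale=0.25)
rod_align = RoDAlignAvg(8, 8, spatial_scale=0.25)

roi_feat = roi_align(feat, rois)   # (num_rois, 256, 8, 8): content inside each crop
rod_feat = rod_align(feat, rois)   # (num_rois, 256, 8, 8): the discarded surroundings
print(roi_feat.shape, rod_feat.shape)
```

Note that both `*Avg` variants request an (aligned_height+1) x (aligned_width+1) grid from the underlying function and then average-pool with a 2x2 kernel at stride 1, which is why the returned maps are exactly aligned_height x aligned_width.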
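The difference between the two kernels is easy to miss in the flattened C code: RoIAlign lays its sampling grid over the RoI itself, while RoDAlign lays the grid over the whole feature map and zeroes every sample that falls inside the RoI, so it encodes the region discarded by a candidate crop. The following pure-PyTorch sketch mirrors `RODAlignForwardCpu` for a single RoI; the function name and shapes are illustrative additions, not part of the repository.

```Python
import torch

def rod_align_forward_reference(feat, roi, aligned_h, aligned_w, spatial_scale):
    """Sketch of RODAlignForwardCpu for one RoI.
    feat: (C, H, W); roi: (x1, y1, x2, y2) in input-image coordinates."""
    C, H, W = feat.shape
    x1, y1, x2, y2 = (v * spatial_scale for v in roi)
    out = feat.new_zeros(C, aligned_h, aligned_w)
    # The sampling grid spans the WHOLE feature map, not the RoI.
    bin_h = (H - 1.001) / (aligned_h - 1.)
    bin_w = (W - 1.001) / (aligned_w - 1.)
    for ph in range(aligned_h):
        for pw in range(aligned_w):
            h, w = ph * bin_h, pw * bin_w
            if y1 <= h <= y2 and x1 <= w <= x2:
                continue  # samples inside the RoI stay zero (the "discarded" view)
            hs, ws = min(int(h), H - 2), min(int(w), W - 2)
            hr, wr = h - hs, w - ws
            # Plain bilinear interpolation, exactly as in the C/CUDA kernels.
            out[:, ph, pw] = (feat[:, hs, ws] * (1 - hr) * (1 - wr)
                              + feat[:, hs, ws + 1] * (1 - hr) * wr
                              + feat[:, hs + 1, ws] * hr * (1 - wr)
                              + feat[:, hs + 1, ws + 1] * hr * wr)
    return out
```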