├── LICENSE
├── README.md
├── config
│   ├── GAIC_config.py
│   └── GAIC_params.yaml
├── dataset
│   ├── candidate_generation.py
│   └── cropping_dataset.py
├── evaluate
│   ├── demo.py
│   └── test.py
├── networks
│   ├── GAIC_model.py
│   └── __init__.py
├── requirements.txt
├── result_images
│   ├── 211958.jpg
│   ├── 265813.jpg
│   └── 297406.jpg
├── test_images
│   ├── 211958.jpg
│   ├── 265813.jpg
│   └── 297406.jpg
├── train
│   └── train.py
└── untils
    ├── make_all.sh
    ├── rod_align
    │   ├── __init__.py
    │   ├── functions
    │   │   ├── __init__.py
    │   │   └── rod_align.py
    │   ├── make.sh
    │   ├── make_python2.sh
    │   ├── modules
    │   │   ├── __init__.py
    │   │   └── rod_align.py
    │   ├── setup.py
    │   └── src
    │       ├── rod_align.cpp
    │       ├── rod_align.h
    │       ├── rod_align_cuda.cpp
    │       ├── rod_align_cuda.h
    │       ├── rod_align_kernel.cu
    │       └── rod_align_kernel.h
    └── roi_align
        ├── __init__.py
        ├── functions
        │   ├── __init__.py
        │   └── roi_align.py
        ├── make.sh
        ├── make_python2.sh
        ├── modules
        │   ├── __init__.py
        │   └── roi_align.py
        ├── setup.py
        └── src
            ├── roi_align.cpp
            ├── roi_align.h
            ├── roi_align_cuda.cpp
            ├── roi_align_cuda.h
            ├── roi_align_kernel.cu
            └── roi_align_kernel.h
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 bo-zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GAIC-PyTorch
2 | This is an unofficial PyTorch implementation of [Grid Anchor based Image Cropping: A New Benchmark and An Efficient Model](https://arxiv.org/pdf/1909.08989.pdf?ref=https://githubhelp.com), which is the journal version and extension of [Reliable and Efficient Image Cropping: A Grid Anchor based Approach](https://arxiv.org/pdf/1904.04441.pdf).
3 | We provide demo code to produce the best cropping results with different aspect ratios (1:1, 4:3, and 16:9) for arbitrary test images.
4 | Moreover, this code can also generate crops with any user-specified aspect ratio.
5 |
6 |
7 | ![](./result_images/211958.jpg)
8 |
9 |
10 | ![](./result_images/265813.jpg)
11 |
12 |
13 | ![](./result_images/297406.jpg)
14 |
15 |
16 | ### Requirements
17 | - PyTorch>=1.0
18 |
19 | You can install packages using pip according to [``requirements.txt``](./requirements.txt):
20 |
21 | ```Shell
22 | pip install -r requirements.txt
23 | ```
24 |
25 | ### Installation
26 | 1. Get the code. We will refer to the directory that you cloned into as `$PRJ_ROOT`.
27 | ```Shell
28 | git clone https://github.com/bo-zhang-cs/GAIC-Pytorch.git
29 | ```
30 |
31 | 2. Build the RoI&RoDAlign libraries. The RoI&RoDAlign source code is from [[here]](https://github.com/lld533/Grid-Anchor-based-Image-Cropping-Pytorch) and is compatible with PyTorch 1.0 or later. If you use PyTorch 0.4.1, please refer to the [[official implementation]](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch).
32 | ```Shell
33 | cd $PRJ_ROOT/untils
34 | # Change **CUDA_HOME** and **-arch=sm_86** in ``roi_align/make.sh`` and ``rod_align/make.sh`` according to your environment.
35 | # Make sure these bash files (``make_all.sh, roi_align/make.sh, rod_align/make.sh``) are in Unix text file format, e.g. by running ``:set ff=unix`` in Vim.
36 | sudo bash make_all.sh
37 | ```
38 |
39 | ### Running Demo
40 | 1. Download the pretrained models (~98MB) from [[Google Drive]](https://drive.google.com/file/d/1U_8C9oWOBT64LHtxZP1_0Uo9Ndw0QLsJ/view?usp=sharing) [[Baidu Cloud]](https://pan.baidu.com/s/1fmy18FD5_0v6vrab6OZKmQ)(access code: *webf*). By default, we assume the models (``*.pth``) are stored in `$PRJ_ROOT/pretrained_models`.
41 |
42 | 2. Predict the best crops for your own images.
43 | ```Shell
44 | cd $PRJ_ROOT
45 | python evaluate/demo.py --gpu 0 --image_dir $IMAGE_PATH/IMAGE_FOLDER --save_dir $RESULT_FOLDER
46 | # or execute python evaluate/demo.py to predict best crops for the images in the test_images folder.
47 | ```
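
If you prefer calling the model from Python, here is a minimal sketch (run from `$PRJ_ROOT`, assuming the pretrained weights were downloaded to `pretrained_models/` as above) that strings together the helpers used by `evaluate/demo.py`:

```Python
# Minimal sketch: score the grid-anchor candidates of one image and keep the best crop.
import torch
from PIL import Image
from networks.GAIC_model import build_crop_model
from dataset.candidate_generation import generate_anchors
from evaluate.demo import image_preprocessing, predict_best_crop

device = torch.device('cuda:0')
net = build_crop_model(scale='multi', alignsize=9, reddim=32, loadweight=False, model='vgg16')
net.load_state_dict(torch.load('pretrained_models/GAIC-vgg16-reddim32.pth'))
net = net.eval().to(device)

src = Image.open('test_images/211958.jpg').convert('RGB')
im_tensor = image_preprocessing(src).to(device)   # short side ~256, both sides multiples of 32
anchors = generate_anchors(im_tensor.shape[-1], im_tensor.shape[-2])
best_crop, bbox = predict_best_crop(net, im_tensor, anchors, src)
best_crop.save('best_crop.jpg')
```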
48 |
49 | ### Train/Evaluate
50 | 1. Download the [GAICD dataset](https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping). Then set ``GAIC_folder`` in ``config/GAIC_config.py``; you can check the paths by running:
51 | ```Shell
52 | cd $PRJ_ROOT
53 | python dataset/cropping_dataset.py
54 | ```
55 |
56 | 2. Train your model and evaluate the model on the fly.
57 | ```Shell
58 | cd $PRJ_ROOT
59 | python train/train.py --gpu 0 --backbone vgg16
60 | # or just run python train/train.py
61 | ```
62 | The model performance for each epoch is also recorded in a ``*.csv`` file under the generated ``experiments`` folder.
63 | More model parameters and experimental settings can be found in ``config/GAIC_params.yaml``.
64 |
65 | 3. Evaluate the pretrained model and reproduce the quantitative results below.
66 | ```Shell
67 | cd $PRJ_ROOT
68 | python evaluate/test.py
69 | ```
70 |
71 | ### Performance on GAICD Dataset
72 | | Source | Backbone | SRCC↑ | PCC↑ | Acc5↑ | Acc1_5↑ | Acc4_5↑ | Acc10↑ | Acc1_10↑ | Acc4_10↑ |
73 | |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
74 | | Paper | VGG16 | 0.777 | 0.800 | - | 60.5 | 50.2 | - | 77.5 | 70.6 |
75 | | This code | VGG16 | 0.778 | 0.808 | 54.8 | 60.0 | 51.6 | 73.3 | 77.5 | 69.9 |
76 | | Paper | MobileNetV2 | 0.783 | 0.806 | - | **62.5** | 52.5 | - | 78.5 | **72.3**|
77 | | This code | MobileNetV2 | 0.783 | 0.810 | **58.0** | **62.5** | **53.0** | **74.2** | **80.0** | 72.0|
78 | | Paper | ShuffleNetV2 | 0.774 | 0.801 | - | 61.5 | 52.0 | - | 78.5 | 71.3 |
79 | | This code | ShuffleNetV2 | **0.787** | **0.811** | 55.4 | 62.0 | 50.0 | 73.4 | 78.0 | 69.6 |
80 |
81 | ### Citation
82 | ```
83 | @inproceedings{zhang2019deep,
84 | title={Reliable and Efficient Image Cropping: A Grid Anchor based Approach},
85 | author={Zeng, Hui and Li, Lida and Cao, Zisheng and Zhang, Lei},
86 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
87 | year={2019}
88 | }
89 | @article{zeng2020cropping,
90 | title={Grid Anchor based Image Cropping: A New Benchmark and An Efficient Model},
91 | author={Zeng, Hui and Li, Lida and Cao, Zisheng and Zhang, Lei},
92 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
93 | volume={},
94 | number={},
95 | pages={},
96 | year={2020},
97 | publisher={IEEE}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/config/GAIC_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'GAIC_params.yaml')
5 | assert os.path.exists(yaml_path), yaml_path
6 |
7 | def load_yaml_params():
8 | with open(yaml_path, 'r') as yaml_file:
9 | params = yaml.full_load(yaml_file.read())
10 | return params
11 |
12 | def refresh_yaml_params(args):
13 | yaml_params = load_yaml_params()
14 | for arg in vars(args):
15 | # print(arg, type(arg), getattr(args, arg))
16 | assert arg in yaml_params, arg
17 | yaml_params[arg] = getattr(args, arg)
18 |
19 | with open(yaml_path, 'w') as yaml_file:
20 | yaml.dump(yaml_params, yaml_file)
21 |
22 | class Config:
23 | data_root = '../../dataset'
24 | GAIC_folder = os.path.join(data_root, 'GAICD')
25 |
26 | def __init__(self):
27 | self.refresh_params()
28 |
29 | def refresh_params(self):
30 | self.load_params_from_yaml()
31 | self.generate_path()
32 |
33 | def load_params_from_yaml(self):
34 | # add parameters from yaml file
35 | names = self.__dict__
36 | params = load_yaml_params()
37 | for k, v in params.items():
38 | # print(v, type(v))
39 | names[k] = v
40 |
41 | def generate_path(self):
42 | prefix = 'GAIC-{}-re{}dim'.format(self.backbone, self.reddim)
43 | exp_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'experiments')
44 | exp_name = prefix
45 | exp_path = os.path.join(exp_root, prefix)
46 | while os.path.exists(exp_path):
47 | index = os.path.basename(exp_path).split(prefix)[-1].split('repeat')[-1]
48 | try:
49 | index = int(index) + 1
50 | except ValueError:
51 | index = 1
52 | exp_name = prefix + ('_repeat{}'.format(index))
53 | exp_path = os.path.join(exp_root, exp_name)
54 | # print('Experiment name {} \n'.format(os.path.basename(exp_path)))
55 | self.exp_name = exp_name
56 | self.exp_path = exp_path
57 | self.checkpoint_dir = os.path.join(exp_path, 'checkpoints')
58 | self.log_dir = os.path.join(exp_path, 'logs')
59 | self.code_dir = os.path.join(exp_path, 'code')
60 |
61 | def create_path(self):
62 | print('Create experiment directory: ', self.exp_path)
63 | os.makedirs(self.exp_path)
64 | os.makedirs(self.checkpoint_dir)
65 | os.makedirs(self.log_dir)
66 | os.makedirs(self.code_dir)
67 |
68 | cfg = Config()
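
# Usage sketch (kept commented out, mirroring the examples elsewhere in this repo):
# the module-level ``cfg`` singleton exposes every yaml entry as an attribute.
# from config.GAIC_config import cfg
# print(cfg.backbone, cfg.lr, cfg.exp_path)
# cfg.refresh_params()  # re-read GAIC_params.yaml and regenerate the experiment paths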
--------------------------------------------------------------------------------
/config/GAIC_params.yaml:
--------------------------------------------------------------------------------
1 | alignsize: 9
2 | backbone: vgg16
3 | batch_size: 1
4 | data_augmentation: true
5 | display_freq: 200
6 | eval_freq: 1
7 | gpu_id: 0
8 | image_size:
9 | - 256
10 | - 256
11 | keep_aspect_ratio: true
12 | lr: 1.0e-4
13 | lr_decay: 0.1
14 | lr_decay_epoch: [81]
15 | max_epoch: 80
16 | num_workers: 8
17 | reddim: 32
18 | save_freq: 81
19 | weight_decay: 1.0e-4
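# Note on the values above: with max_epoch 80, the lr_decay_epoch milestone of 81
# is never reached (so the lr stays constant), and save_freq 81 only fires at
# epoch 0; the per-metric best checkpoints are saved separately by train.py.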
20 |
--------------------------------------------------------------------------------
/dataset/candidate_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # this code is modified from https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch
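# Grid-anchor scheme: the image is divided into a bins x bins grid; a candidate
# crop's top-left corner must fall in the first third of the grid and its
# bottom-right corner in the last third. Candidates are kept only if they cover
# at least ~half of the image area (area > 0.4999) and their height/width
# aspect ratio lies between 0.5 and 2.0.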
4 | def generate_anchors(im_w, im_h, bins=12):
5 | assert (im_w > 100) and (im_h > 100), (im_w, im_h)
6 | step_h = im_h / bins
7 | step_w = im_w / bins
8 | pdefined_anchors = []
9 | for x1 in range(0, int(bins / 3)):
10 | for y1 in range(0, int(bins / 3)):
11 | for x2 in range(int(bins / 3 * 2), bins):
12 | for y2 in range(int(bins / 3 * 2), bins):
13 | area = (x2 - x1) * (y2 - y1) / float(bins * bins)
14 | aspect_ratio = (y2 - y1) * step_h / ((x2 - x1) * step_w)
15 | if area > 0.4999 and aspect_ratio > 0.5 and aspect_ratio < 2.0:
16 | crop_x1 = int(step_w * (0.5 + x1))
17 | crop_y1 = int(step_h * (0.5 + y1))
18 | crop_x2 = int(step_w * (0.5 + x2))
19 | crop_y2 = int(step_h * (0.5 + y2))
20 | pdefined_anchors.append([crop_x1, crop_y1, crop_x2, crop_y2])
21 | pdefined_anchors = np.array(pdefined_anchors).reshape(-1,4)
22 | # print('image size:({},{}), obtain {} pre-defined anchors.'.format(
23 | # im_w, im_h, pdefined_anchors.shape[0]))
24 | return pdefined_anchors
25 |
26 |
27 | def generate_anchors_aspect_ratio_specific(im_w, im_h, aspect_ratio, bins=20):
28 | assert (im_w > 100) and (im_h > 100), (im_w, im_h)
29 | assert isinstance(aspect_ratio, tuple), \
30 | 'undefined aspect ratio type: {}'.format(aspect_ratio)
31 | assert aspect_ratio[0] >= 1 and aspect_ratio[1] >= 1, \
32 | 'undefined aspect ratio type: {}'.format(aspect_ratio)
33 | w_step, h_step = int(aspect_ratio[0]), int(aspect_ratio[1])
34 |
35 | max_step = int(min(im_w / w_step, im_h / h_step))
36 | # limit the search space by increasing the step size
37 | if max_step > bins:
38 | scale = int(max(im_w / w_step / bins, im_h / h_step / bins))
39 | h_step *= scale
40 | w_step *= scale
41 | max_step = int(min(im_w / w_step, im_h / h_step))
42 | # print('image_size:{}, aspect_ratio: {}, step:{}, max_steps:{}'.format(
43 | # (im_w, im_h), aspect_ratio, (w_step, h_step), max_step))
44 | min_step = int(max_step / 2. - 1)
45 | pdefined_anchors = []
46 | for i in range(min_step, max_step):
47 | out_h = h_step * i
48 | out_w = w_step * i
49 | if out_h < im_h and out_w < im_w and (out_w * out_h > 0.4 * im_w * im_h):
50 | for w_start in range(0, im_w - out_w, w_step):
51 | for h_start in range(0, im_h - out_h, h_step):
52 | x1 = int(w_start)
53 | y1 = int(h_start)
54 | x2 = int(w_start + out_w - 1)
55 | y2 = int(h_start + out_h - 1)
56 | pdefined_anchors.append([x1, y1, x2, y2])
57 | pdefined_anchors = np.array(pdefined_anchors).reshape(-1, 4)
58 | # print('aspect-ratio:{}, image size:({},{}), obtain {} pre-defined anchors.'.format(
59 | # aspect_ratio, im_w, im_h, pdefined_anchors.shape[0]))
60 | return pdefined_anchors
61 |
62 | if __name__ == '__main__':
63 | print(generate_anchors(384, 256))
64 | # print(generate_anchors_aspect_ratio_specific(128, 128, (1, 1), bins=20))
65 | # print(generate_anchors_aspect_ratio_specific(512, 512, (3, 4), bins=20))
66 | # print(generate_anchors_aspect_ratio_specific(512, 512, (16, 9), bins=20))
--------------------------------------------------------------------------------
/dataset/cropping_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image, ImageOps
4 | from torch.utils.data import DataLoader, Dataset
5 | import torchvision.transforms as transforms
6 | import random
7 | import cv2
8 | import json
9 | from config.GAIC_config import cfg
10 |
11 | MOS_MEAN = 2.95
12 | MOS_STD = 0.8
13 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
14 | IMAGE_NET_STD = [0.229, 0.224, 0.225]
15 |
16 | def rescale_crops(boxes, ratio_w, ratio_h):
17 | boxes = np.array(boxes).reshape(-1, 4)
18 | boxes[:, 0] = np.floor(boxes[:, 0] * ratio_w)
19 | boxes[:, 1] = np.floor(boxes[:, 1] * ratio_h)
20 | boxes[:, 2] = np.ceil(boxes[:, 2] * ratio_w)
21 | boxes[:, 3] = np.ceil(boxes[:, 3] * ratio_h)
22 | return boxes.astype(np.float32)
23 |
24 | def is_number(s):
25 | if not isinstance(s, str):
26 | return False
27 | if s.isdigit():
28 | return True
29 | else:
30 | try:
31 | float(s)
32 | return True
33 | except:
34 | return False
35 |
36 | class GAICDataset(Dataset):
37 | def __init__(self, split):
38 | self.split = split
39 | assert self.split in ['train', 'test'], self.split
40 | self.keep_aspect = cfg.keep_aspect_ratio
41 | self.data_dir = cfg.GAIC_folder
42 | assert os.path.exists(self.data_dir), self.data_dir
43 | self.image_dir = os.path.join(self.data_dir, 'images', split)
44 | assert os.path.exists(self.image_dir), self.image_dir
45 | self.image_list = [file for file in os.listdir(self.image_dir) if file.endswith('.jpg')]
46 | # print('GAICD {} set contains {} images'.format(split, len(self.image_list)))
47 | self.anno_dir = os.path.join(self.data_dir, 'annotations')
48 | assert os.path.exists(self.anno_dir), self.anno_dir
49 | self.annos = self.parse_annotations()
50 |
51 | self.image_size = cfg.image_size
52 | self.augmentation = (cfg.data_augmentation and self.split == 'train')
53 | self.PhotometricDistort = transforms.ColorJitter(
54 | brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05)
55 | self.image_transformer = transforms.Compose([
56 | transforms.ToTensor(),
57 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)])
58 |
59 | def parse_annotations(self):
60 | image_annos = dict()
61 | for image_name in self.image_list:
62 | anno_file = os.path.join(self.anno_dir, image_name.replace('.jpg', '.txt'))
63 | assert os.path.exists(anno_file), anno_file
64 | with open(anno_file, 'r') as f:
65 | crops,scores = [],[]
66 | for line in f.readlines():
67 | line = line.strip().split(' ')
68 | values = [s for s in line if is_number(s)]
69 | y1,x1,y2,x2 = [int(s) for s in values[0:4]]
70 | s = float(values[-1])
71 | if s > -2:
72 | crops.append([x1,y1,x2,y2])
73 | scores.append(s)
74 | if len(crops) == 0:
75 | print(image_name, anno_file)
76 | else:
77 | # rank all crops
78 | rank = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
79 | scores = [scores[i] for i in rank]
80 | crops = [crops[i] for i in rank]
81 | image_annos[image_name] = {'crops':crops, 'scores':scores}
82 | return image_annos
83 |
84 | def __len__(self):
85 | return len(self.image_list)
86 |
87 | def __getitem__(self, index):
88 | image_name = self.image_list[index]
89 | image_file = os.path.join(self.image_dir, image_name)
90 | image = Image.open(image_file).convert('RGB')
91 | im_width, im_height = image.size
92 | if self.keep_aspect:
93 | scale = float(cfg.image_size[0]) / min(im_height, im_width)
94 | h = round(im_height * scale / 32.0) * 32
95 | w = round(im_width * scale / 32.0) * 32
96 | else:
97 | h = cfg.image_size[1]
98 | w = cfg.image_size[0]
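        # e.g. with keep_aspect_ratio and a 640x480 source: scale = 256/480,
        # h = round(480*scale/32)*32 = 256, w = round(640*scale/32)*32 = 352;
        # multiples of 32 keep the size divisible through the backbone's pooling stages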
99 | resized_image = image.resize((w, h), Image.ANTIALIAS)
100 | crop = self.annos[image_name]['crops']
101 | rs_width, rs_height = resized_image.size
102 | ratio_w = float(rs_width) / im_width
103 | ratio_h = float(rs_height) / im_height
104 | crop = rescale_crops(crop, ratio_w, ratio_h)
105 | score = np.array(self.annos[image_name]['scores']).reshape((-1)).astype(np.float32)
106 | if self.augmentation:
107 | if random.uniform(0,1) > 0.5:
108 | resized_image = ImageOps.mirror(resized_image)
109 | temp_x1 = crop[:, 0].copy()
110 | crop[:, 0] = rs_width - crop[:, 2]
111 | crop[:, 2] = rs_width - temp_x1
112 | resized_image = self.PhotometricDistort(resized_image)
113 | im = self.image_transformer(resized_image)
114 | return im, crop, score, im_width, im_height, image_file
115 |
116 | if __name__ == '__main__':
117 | GAICD_testset = GAICDataset(split='train')
118 | print('GAICD training set has {} images'.format(len(GAICD_testset)))
119 | dataloader = DataLoader(GAICD_testset, batch_size=1, num_workers=0)
120 | for batch_idx, data in enumerate(dataloader):
121 | im, crops, scores, w, h, file = data
122 | print(im.shape, crops.shape, scores.shape, w.shape, h.shape)
--------------------------------------------------------------------------------
/evaluate/demo.py:
--------------------------------------------------------------------------------
1 | import sys,os
2 | PROJECT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
3 | sys.path.insert(0, PROJECT_PATH)
4 | import numpy as np
5 | import torch
6 | from networks.GAIC_model import build_crop_model
7 | import argparse
8 | import cv2
9 | from PIL import Image
10 | from dataset.candidate_generation import generate_anchors, generate_anchors_aspect_ratio_specific
11 | import torchvision.transforms as transforms
12 | import matplotlib.pyplot as plt
13 | import warnings
14 | warnings.filterwarnings('ignore')
15 |
16 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
17 | IMAGE_NET_STD = [0.229, 0.224, 0.225]
18 | image_transformer = transforms.Compose([
19 | transforms.ToTensor(),
20 | transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)])
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description="Run cropping model on images")
25 | parser.add_argument('--gpu', type=int, dest='gpu_id',
26 | help='gpu_id', default=0)
27 | parser.add_argument('--backbone', type=str, choices=['vgg16', 'mobilenetv2', 'shufflenetv2'],
28 | help='the architecture of backbone network', default='vgg16')
29 | parser.add_argument('--image_dir', type=str, default='test_images',
30 | help='the directory of test images')
31 | parser.add_argument('--save_dir', type=str, default='result_images',
32 | help='the directory of saving resulting images')
33 | args = parser.parse_args()
34 | assert os.path.exists(args.image_dir), args.image_dir
35 | os.makedirs(args.save_dir, exist_ok=True)
36 | return args
37 |
38 | def build_network(backbone):
39 | if backbone in ['vgg16','shufflenetv2']:
40 | reddim = 32
41 | elif backbone == 'mobilenetv2':
42 | reddim = 16
43 | else:
44 | raise Exception('undefined backbone architecture', backbone)
45 | net = build_crop_model(scale='multi', alignsize=9, reddim=reddim,
46 | loadweight=False, model=backbone)
47 | weights_path = 'pretrained_models/GAIC-{}-reddim{}.pth'.format(backbone, reddim)
48 | assert os.path.exists(weights_path), weights_path
49 | print('load pretrained weights from ', weights_path)
50 | net.load_state_dict(torch.load(weights_path))
51 | return net
52 |
53 | def get_image_list(args):
54 | img_list = []
55 | if os.path.isdir(args.image_dir):
56 | for file in os.listdir(args.image_dir):
57 | if file.endswith(('.jpg', '.png', '.jpeg', '.bmp')):
58 | img_list.append(os.path.join(args.image_dir, file))
59 | else:
60 | if args.image_dir.endswith(('.jpg', '.png', '.jpeg', '.bmp')):
61 | img_list.append(args.image_dir)
62 | print('found {} images in total'.format(len(img_list)))
63 | return img_list
64 |
65 | def image_preprocessing(im):
66 | im_width,im_height = im.size
67 | scale = 256. / min(im_height, im_width)
68 | h = round(im_height * scale / 32.0) * 32
69 | w = round(im_width * scale / 32.0) * 32
70 | resized_image = im.resize((w, h), Image.ANTIALIAS)
71 | im_tensor = image_transformer(resized_image).unsqueeze(0)
72 | return im_tensor
73 |
74 | def predict_best_crop(model, im_tensor, anchors, im):
75 | im_width, im_height = im.size
76 | with torch.no_grad():
77 | rois = anchors.astype(np.float32)
78 | rois = torch.from_numpy(rois).unsqueeze(0).to(im_tensor.device)
79 | scores = model(im_tensor, rois)
80 | scores = scores.detach().cpu().numpy().reshape(-1)
81 | pr_idx = np.argmax(scores)
82 | # mapping the coordinates of predefined anchors to source image
83 | rescale_anchors = anchors.astype(np.float32)
84 | rescale_anchors[:,0::2] = rescale_anchors[:,0::2] / im_tensor.shape[-1] * im_width
85 | rescale_anchors[:,1::2] = rescale_anchors[:,1::2] / im_tensor.shape[-2] * im_height
86 | rescale_anchors = rescale_anchors.astype(np.int32)
87 | pr_bbox = rescale_anchors[pr_idx].tolist()
88 | x1, y1, x2, y2 = pr_bbox
89 | pr_crop = im.crop((x1, y1, x2, y2))
90 | # pr_crop = np.asarray(pr_crop)[:,:,::-1] # convert to opencv format
91 | return pr_crop, pr_bbox
92 |
93 |
94 | if __name__ == '__main__':
95 | args = parse_args()
96 | device = torch.device('cuda:{}'.format(args.gpu_id))
97 | net = build_network(args.backbone)
98 | net = net.eval().to(device)
99 | img_list = get_image_list(args)
100 |
101 | for i,img in enumerate(img_list):
102 | im_name = os.path.basename(img)
103 | src = Image.open(img).convert('RGB')
104 | src_tensor = image_preprocessing(src).to(device)
105 | input_w, input_h = src_tensor.shape[-1], src_tensor.shape[-2]
106 | # generate aspect-ratio-agnostic crops
107 | anchors = generate_anchors(input_w, input_h)
108 | best_crop, bbox = predict_best_crop(net, src_tensor, anchors, src)
109 | print('source image:{} {}, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format(
110 | im_name, src.size, anchors.shape[0], bbox, best_crop.size))
111 |
112 | # generate aspect-ratio-specific crops
113 | anchors_1_1 = generate_anchors_aspect_ratio_specific(input_w, input_h, (1,1), bins=30)
114 | crop_1_1, bbox_1_1 = predict_best_crop(net, src_tensor, anchors_1_1, src)
115 | print('aspect_ratio=1:1, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format(
116 | anchors_1_1.shape[0], bbox_1_1, crop_1_1.size))
117 |
118 | anchors_4_3 = generate_anchors_aspect_ratio_specific(input_w, input_h, (4,3), bins=20)
119 | crop_4_3, bbox_4_3 = predict_best_crop(net, src_tensor, anchors_4_3, src)
120 | print('aspect_ratio=4:3, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format(
121 | anchors_4_3.shape[0], bbox_4_3, crop_4_3.size))
122 |
123 | anchors_16_9 = generate_anchors_aspect_ratio_specific(input_w, input_h, (16,9), bins=15)
124 | crop_16_9, bbox_16_9 = predict_best_crop(net, src_tensor, anchors_16_9, src)
125 | print('aspect_ratio=16:9, num_candidates:{}, best crop bbox:{}, crop(w,h):{}'.format(
126 | anchors_16_9.shape[0], bbox_16_9, crop_16_9.size))
127 |
128 | # results visualization
129 | crop_list = [src, best_crop, crop_1_1, crop_4_3, crop_16_9]
130 | title_list = ['source image', 'best crop', '1:1', '4:3', '16:9']
131 | fig_cols = 5
132 | fig_rows = (len(crop_list) + fig_cols - 1) // fig_cols
133 | fig = plt.figure(figsize=(20,5))
134 | for i in range(len(crop_list)):
135 | ax = fig.add_subplot(fig_rows, fig_cols, i+1)
136 | ax.imshow(crop_list[i])
137 | ax.set_axis_off()
138 | ax.set_title(title_list[i])
139 | fig.tight_layout()
140 | result_file = os.path.join(args.save_dir, im_name)
141 | plt.savefig(result_file)
142 | plt.close()
143 | print('Save results to ', result_file)
144 | print()
145 |
146 |
--------------------------------------------------------------------------------
/evaluate/test.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
3 | import numpy as np
4 | import torch
5 | from tqdm import tqdm
6 | import pickle
7 | from scipy.stats import spearmanr, pearsonr
8 | import math
9 | from dataset.cropping_dataset import GAICDataset
10 | from config.GAIC_config import cfg
11 | from networks.GAIC_model import build_crop_model
12 | from thop import profile
13 |
14 | def compute_acc(gt_scores, pr_scores):
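    # Acc_{K/N} metrics: sort crops by ground-truth score; for each k in 0..3, count
    # how many of the top-(k+1) predicted crops score at least as high as the 5th
    # (acc4_5) or 10th (acc4_10) best ground-truth crop, averaged over all images.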
15 | assert (len(gt_scores) == len(pr_scores)), '{} vs. {}'.format(len(gt_scores), len(pr_scores))
16 | sample_cnt = 0
17 | acc4_5 = [0 for i in range(4)]
18 | acc4_10 = [0 for i in range(4)]
19 | for i in range(len(gt_scores)):
20 | gts, preds = gt_scores[i], pr_scores[i]
21 | id_gt = sorted(range(len(gts)), key=lambda j : gts[j], reverse=True)
22 | id_pr = sorted(range(len(preds)), key=lambda j : preds[j], reverse=True)
23 | for k in range(4):
24 | temp_acc4_5 = 0.
25 | temp_acc4_10 = 0.
26 | for j in range(k+1):
27 | if gts[id_pr[j]] >= gts[id_gt[4]]:
28 | temp_acc4_5 += 1.0
29 | if gts[id_pr[j]] >= gts[id_gt[9]]:
30 | temp_acc4_10 += 1.0
31 | acc4_5[k] += (temp_acc4_5 / (k+1.0))
32 | acc4_10[k] += ((temp_acc4_10) / (k+1.0))
33 | sample_cnt += 1
34 | acc4_5 = [round(i / sample_cnt,3) for i in acc4_5]
35 | acc4_10 = [round(i / sample_cnt,3) for i in acc4_10]
36 | # print('acc4_5', acc4_5)
37 | # print('acc4_10', acc4_10)
38 | return acc4_5, acc4_10
39 |
40 | def evaluate_on_GAICD_official(model):
41 | # https://github.com/HuiZeng/Grid-Anchor-based-Image-Cropping-Pytorch
42 | model.eval()
43 | device = next(model.parameters()).device
44 | print('='*5, 'Evaluating on GAICD dataset', '='*5)
45 | count = 0
46 | test_dataset = GAICDataset(split='test')
47 | test_loader = torch.utils.data.DataLoader(
48 | test_dataset, batch_size=1,
49 | shuffle=False, num_workers=cfg.num_workers,
50 | drop_last=False)
51 | acc4_5 = []
52 | acc4_10 = []
53 | wacc4_5 = []
54 | wacc4_10 = []
55 | srcc = []
56 | pcc = []
57 | for n in range(4):
58 | acc4_5.append(0)
59 | acc4_10.append(0)
60 | wacc4_5.append(0)
61 | wacc4_10.append(0)
62 |
63 | with torch.no_grad():
64 | for batch_idx, batch_data in enumerate(tqdm(test_loader)):
65 | im = batch_data[0].to(device)
66 | rois = batch_data[1].to(device)
67 | MOS = batch_data[2].reshape(-1,1)
68 | width = batch_data[3]
69 | height = batch_data[4]
70 | count += im.shape[0]
71 |
72 | out = model(im, rois)
73 | id_MOS = sorted(range(len(MOS)), key=lambda k: MOS[k], reverse=True)
74 | id_out = sorted(range(len(out)), key=lambda k: out[k], reverse=True)
75 |
76 | rank_of_returned_crop = []
77 | for k in range(4):
78 | rank_of_returned_crop.append(id_MOS.index(id_out[k]))
79 |
80 | for k in range(4):
81 | temp_acc_4_5 = 0.0
82 | temp_acc_4_10 = 0.0
83 | for j in range(k + 1):
84 | if MOS[id_out[j]] >= MOS[id_MOS[4]]:
85 | temp_acc_4_5 += 1.0
86 | if MOS[id_out[j]] >= MOS[id_MOS[9]]:
87 | temp_acc_4_10 += 1.0
88 | acc4_5[k] += temp_acc_4_5 / (k + 1.0)
89 | acc4_10[k] += temp_acc_4_10 / (k + 1.0)
90 |
91 | for k in range(4):
92 | temp_wacc_4_5 = 0.0
93 | temp_wacc_4_10 = 0.0
94 | temp_rank_of_returned_crop = rank_of_returned_crop[:(k + 1)]
95 | temp_rank_of_returned_crop.sort()
96 | for j in range(k + 1):
97 | if temp_rank_of_returned_crop[j] <= 4:
98 | temp_wacc_4_5 += 1.0 * math.exp(-0.2 * (temp_rank_of_returned_crop[j] - j))
99 | if temp_rank_of_returned_crop[j] <= 9:
100 | temp_wacc_4_10 += 1.0 * math.exp(-0.1 * (temp_rank_of_returned_crop[j] - j))
101 | wacc4_5[k] += temp_wacc_4_5 / (k + 1.0)
102 | wacc4_10[k] += temp_wacc_4_10 / (k + 1.0)
103 |
104 | MOS_arr = []
105 | out = torch.squeeze(out).cpu().detach().numpy()
106 | for k in range(len(MOS)):
107 | MOS_arr.append(MOS[k].numpy()[0])
108 | srcc.append(spearmanr(MOS_arr, out)[0])
109 | pcc.append(pearsonr(MOS_arr, out)[0])
110 |
111 | for k in range(4):
112 | acc4_5[k] = acc4_5[k] / count
113 | acc4_10[k] = acc4_10[k] / count
114 | wacc4_5[k] = wacc4_5[k] / count
115 | wacc4_10[k] = wacc4_10[k] / count
116 |
117 | avg_srcc = sum(srcc) / count
118 | avg_pcc = sum(pcc) / count
119 | avg_acc5 = sum(acc4_5) / len(acc4_5)
120 | avg_acc10 = sum(acc4_10) / len(acc4_10)
121 |
122 | sys.stdout.write('Acc4_5:[%.3f, %.3f, %.3f, %.3f] Acc4_10:[%.3f, %.3f, %.3f, %.3f]\n' % (
123 | acc4_5[0], acc4_5[1], acc4_5[2], acc4_5[3], acc4_10[0], acc4_10[1], acc4_10[2], acc4_10[3]))
124 | sys.stdout.write('WAcc4_5:[%.3f, %.3f, %.3f, %.3f] WAcc4_10:[%.3f, %.3f, %.3f, %.3f]\n' % (
125 | wacc4_5[0], wacc4_5[1], wacc4_5[2], wacc4_5[3], wacc4_10[0], wacc4_10[1], wacc4_10[2], wacc4_10[3]))
126 | sys.stdout.write('[Avg SRCC: %.3f] [Avg PCC: %.3f] [Acc5: %.3f] [Acc10: %.3f]\n' % (
127 | avg_srcc, avg_pcc, avg_acc5, avg_acc10))
128 | return avg_srcc, avg_pcc, avg_acc5, avg_acc10, acc4_5, acc4_10
129 |
130 | def evaluate_on_GAICD(model):
131 | device = next(model.parameters()).device
132 | model.eval()
133 | print('='*5, 'Evaluating on GAICD dataset', '='*5)
134 | srcc_list = []
135 | pcc_list = []
136 | gt_scores = []
137 | pr_scores = []
138 | count = 0
139 | test_dataset = GAICDataset(split='test')
140 | test_loader = torch.utils.data.DataLoader(
141 | test_dataset, batch_size=1,
142 | shuffle=False, num_workers=cfg.num_workers,
143 | drop_last=False)
144 | test_results = dict()
145 | with torch.no_grad():
146 | for batch_idx, batch_data in enumerate(tqdm(test_loader)):
147 | im = batch_data[0].to(device)
148 | rois = batch_data[1].to(device)
149 | scores = batch_data[2].cpu().numpy().reshape(-1)
150 | width = batch_data[3]
151 | height = batch_data[4]
152 | image_name = batch_data[5][0]
153 | count += im.shape[0]
154 |
155 | pre_scores = model(im, rois)
156 | pre_scores = pre_scores.cpu().detach().numpy().reshape(-1)
157 | srcc_list.append(spearmanr(scores, pre_scores)[0])
158 | pcc_list.append(pearsonr(scores, pre_scores)[0])
159 | gt_scores.append(scores)
160 | pr_scores.append(pre_scores)
161 | test_results[image_name] = pre_scores.tolist()
162 | avg_srcc = sum(srcc_list) / len(srcc_list)
163 | avg_pcc = sum(pcc_list) / len(pcc_list)
164 | acc4_5, acc4_10 = compute_acc(gt_scores, pr_scores)
165 | avg_acc5 = sum(acc4_5) / len(acc4_5)
166 | avg_acc10 = sum(acc4_10) / len(acc4_10)
167 | sys.stdout.write('Acc4_5:[%.3f, %.3f, %.3f, %.3f] Acc4_10:[%.3f, %.3f, %.3f, %.3f]\n' % (
168 | acc4_5[0], acc4_5[1], acc4_5[2], acc4_5[3], acc4_10[0], acc4_10[1], acc4_10[2], acc4_10[3]))
169 | sys.stdout.write('[Avg SRCC: %.3f] [Avg PCC: %.3f] [Acc5: %.3f] [Acc10: %.3f]\n' % (
170 | avg_srcc, avg_pcc, avg_acc5, avg_acc10))
171 | return avg_srcc, avg_pcc, avg_acc5, avg_acc10, acc4_5, acc4_10
172 |
173 |
174 | if __name__ == '__main__':
175 | device = torch.device('cuda:{}'.format(cfg.gpu_id))
176 | torch.cuda.set_device(device)
177 | backbone, reddim = 'vgg16', 32
178 | pretrained_weight = 'pretrained_models/GAIC-{}-reddim{}.pth'.format(backbone, reddim)
179 | model = build_crop_model(scale='multi', alignsize=9, reddim=reddim,
180 | loadweight=False, model=backbone)
181 | model = model.eval().to(device)
182 | print('load pretrained weights from ', pretrained_weight)
183 | model.load_state_dict(torch.load(pretrained_weight), strict=False)
184 | evaluate_on_GAICD_official(model)
185 |
186 | # roi = torch.tensor([[0, 0, 128, 128], [64, 64, 223, 223]]).float()
187 | # roi = roi.unsqueeze(0).to(device)
188 | # img = torch.randn((1, 3, 256, 256)).to(device)
189 | # flops, params = profile(model, inputs=(img, roi))
190 | # print("params: %.2fMB flops: %.2fG" % (params / (1000 ** 2), flops / (1000 ** 3)))
--------------------------------------------------------------------------------
/networks/GAIC_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import torch.nn.functional as F
5 | from untils.roi_align.modules.roi_align import RoIAlignAvg, RoIAlign
6 | from untils.rod_align.modules.rod_align import RoDAlignAvg, RoDAlign
7 | from torchvision.models.mobilenetv2 import mobilenet_v2 as MobileNetV2
8 | from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0 as ShuffleNetV2
9 | import torch.nn.init as init
10 | from thop import profile
11 | import warnings
12 | warnings.filterwarnings('ignore')
13 |
14 | class vgg_base(nn.Module):
15 | def __init__(self, loadweights=True):
16 | super(vgg_base, self).__init__()
17 |
18 | vgg = models.vgg16(pretrained=loadweights)
19 | self.feature3 = nn.Sequential(vgg.features[:23])
20 | self.feature4 = nn.Sequential(vgg.features[23:30])
21 | self.feature5 = nn.Sequential(vgg.features[30:])
22 |
23 | # img = torch.randn((1, 3, 256, 256))
24 | # flops, params = profile(vgg.features[:-1], inputs=(img,))
25 | # print("params: %.2fMB flops: %.2fG" % (params / (1000 ** 2), flops / (1000 ** 3)))
26 | # params: 14.71MB flops: 20.06G
27 |
28 | def forward(self, x):
29 | #return self.feature(x)
30 | f3 = self.feature3(x)
31 | f4 = self.feature4(f3)
32 | f5 = self.feature5(f4)
33 | return f3, f4, f5
34 |
35 | class resnet50_base(nn.Module):
36 | def __init__(self, loadweights=True):
37 | super(resnet50_base, self).__init__()
38 |
39 | resnet50 = models.resnet50(pretrained=loadweights)
40 |
41 | self.feature3 = nn.Sequential(resnet50.conv1, resnet50.bn1,
42 | resnet50.relu,resnet50.maxpool,
43 | resnet50.layer1,resnet50.layer2)
44 | self.feature4 = nn.Sequential(resnet50.layer3)
45 | self.feature5 = nn.Sequential(resnet50.layer4)
46 |
47 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256))
48 |
49 | def forward(self, x):
50 | #return self.feature(x)
51 | f3 = self.feature3(x)
52 | f4 = self.feature4(f3)
53 | f5 = self.feature5(f4)
54 | return f3, f4, f5
55 |
56 |
57 | class mobilenetv2_base(nn.Module):
58 |
59 | def __init__(self, loadweights=True):
60 | super(mobilenetv2_base, self).__init__()
61 |
62 | model = MobileNetV2(pretrained=loadweights)
63 |
64 | self.feature3 = nn.Sequential(model.features[:7])
65 | self.feature4 = nn.Sequential(model.features[7:14])
66 | self.feature5 = nn.Sequential(model.features[14:-1])
67 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256))
68 |
69 | def forward(self, x):
70 | #return self.feature(x)
71 | f3 = self.feature3(x)
72 | f4 = self.feature4(f3)
73 | f5 = self.feature5(f4)
74 | return f3, f4, f5
75 |
76 |
77 | class shufflenetv2_base(nn.Module):
78 |
79 | def __init__(self, loadweights=True):
80 | super(shufflenetv2_base, self).__init__()
81 |
82 | model = ShuffleNetV2(pretrained=loadweights)
83 |
84 | self.feature3 = nn.Sequential(model.conv1, model.maxpool, model.stage2)
85 | self.feature4 = nn.Sequential(model.stage3)
86 | self.feature5 = nn.Sequential(model.stage4)
87 | #flops, params = profile(self.feature, input_size=(1, 3, 256,256))
88 |
89 | def forward(self, x):
90 | #return self.feature(x)
91 | f3 = self.feature3(x)
92 | f4 = self.feature4(f3)
93 | f5 = self.feature5(f4)
94 | return f3, f4, f5
95 |
96 | def fc_layers(reddim=32, alignsize=8):
97 | conv1 = nn.Sequential(nn.Conv2d(reddim, 768, kernel_size=alignsize, padding=0),
98 | nn.ReLU(inplace=True))
99 | conv2 = nn.Sequential(nn.Conv2d(768, 128, kernel_size=1),
100 | nn.ReLU(inplace=True))
101 | conv3 = nn.Conv2d(128, 1, kernel_size=1)
102 | layers = nn.Sequential(conv1, conv2, conv3)
103 | return layers
104 |
105 | class crop_model_single_scale(nn.Module):
106 |
107 | def __init__(self, alignsize = 8, reddim = 8, loadweight = True, model = None):
108 | super(crop_model_single_scale, self).__init__()
109 |
110 | if model == 'shufflenetv2':
111 | self.Feat_ext = shufflenetv2_base(loadweight)
112 | self.DimRed = nn.Conv2d(232, reddim, kernel_size=1, padding=0)
113 | elif model == 'mobilenetv2':
114 | self.Feat_ext = mobilenetv2_base(loadweight)
115 | self.DimRed = nn.Conv2d(96, reddim, kernel_size=1, padding=0)
116 | elif model == 'vgg16':
117 | self.Feat_ext = vgg_base(loadweight)
118 | self.DimRed = nn.Conv2d(512, reddim, kernel_size=1, padding=0)
119 | elif model == 'resnet50':
120 | self.Feat_ext = resnet50_base(loadweight)
121 | self.DimRed = nn.Conv2d(1024, reddim, kernel_size=1, padding=0)
122 | downsample = 4
123 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0/2**downsample)
124 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0/2**downsample)
125 | self.FC_layers = fc_layers(reddim*2, alignsize)
126 |
127 | #flops, params = profile(self.FC_layers, input_size=(1,reddim*2,9,9))
128 |
129 | def forward(self, im_data, boxes):
130 |
131 | f3,base_feat,f5 = self.Feat_ext(im_data)
132 | red_feat = self.DimRed(base_feat)
133 | RoI_feat = self.RoIAlign(red_feat, boxes)
134 | RoD_feat = self.RoDAlign(red_feat, boxes)
135 | final_feat = torch.cat((RoI_feat, RoD_feat), 1)
136 | prediction = self.FC_layers(final_feat)
137 | return prediction
138 |
139 | def _init_weights(self):
140 | print('Initializing weights...')
141 | self.DimRed.apply(weights_init)
142 | self.FC_layers.apply(weights_init)
143 |
144 | class crop_model_multi_scale_shared(nn.Module):
145 |
146 | def __init__(self, alignsize = 8, reddim = 32, loadweight = True, model = None, ):
147 | super(crop_model_multi_scale_shared, self).__init__()
148 | downsample = 4
149 | if model == 'shufflenetv2':
150 | self.Feat_ext = shufflenetv2_base(loadweight)
151 | self.DimRed = nn.Conv2d(812, reddim, kernel_size=1, padding=0)
152 | elif model == 'mobilenetv2':
153 | self.Feat_ext = mobilenetv2_base(loadweight)
154 | self.DimRed = nn.Conv2d(448, reddim, kernel_size=1, padding=0)
155 | elif model == 'vgg16':
156 | self.Feat_ext = vgg_base(loadweight)
157 | self.DimRed = nn.Conv2d(1536, reddim, kernel_size=1, padding=0)
158 | elif model == 'resnet50':
159 | self.Feat_ext = resnet50_base(loadweight)
160 | self.DimRed = nn.Conv2d(3584, reddim, kernel_size=1, padding=0)
161 |
162 | self.downsample2 = nn.UpsamplingBilinear2d(scale_factor=1.0/2.0)
163 | self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2.0)
164 | self.RoIAlign = RoIAlignAvg(alignsize, alignsize, 1.0/2**downsample)
165 | self.RoDAlign = RoDAlignAvg(alignsize, alignsize, 1.0/2**downsample)
166 | self.FC_layers = fc_layers(reddim*2, alignsize)
167 |
168 | def forward(self, im_data, boxes):
169 | # print(im_data.shape, im_data.dtype, im_data.device, boxes.shape, boxes.dtype, boxes.device)
170 | B, N, _ = boxes.shape
171 | if boxes.shape[-1] == 4:
172 | index = torch.arange(B).view(-1, 1).repeat(1, N).reshape(B, N, 1).to(boxes.device)
173 | boxes = torch.cat((index, boxes),dim=-1).contiguous()
174 | if boxes.dim() == 3:
175 | boxes = boxes.view(-1,5)
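        # boxes are now flattened to (B*N, 5) rows of (batch_index, x1, y1, x2, y2)
        # in input-image coordinates, the layout the RoIAlign/RoDAlign ops expect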
176 |
177 | f3,f4,f5 = self.Feat_ext(im_data)
178 | f3 = F.interpolate(f3, size=f4.shape[2:], mode='bilinear', align_corners=True)
179 | f5 = F.interpolate(f5, size=f4.shape[2:], mode='bilinear', align_corners=True)
180 | cat_feat = torch.cat((f3,f4,0.5*f5),1)
181 |
182 | red_feat = self.DimRed(cat_feat)
183 | RoI_feat = self.RoIAlign(red_feat, boxes)
184 | RoD_feat = self.RoDAlign(red_feat, boxes)
185 |
186 | final_feat = torch.cat((RoI_feat, RoD_feat), 1)
187 | prediction = self.FC_layers(final_feat)
188 | return prediction
189 |
190 | def _init_weights(self):
191 | print('Initializing weights...')
192 | self.DimRed.apply(weights_init)
193 | self.FC_layers.apply(weights_init)
194 |
195 | def cropping_rank_loss(pre_score, gt_score):
196 | '''
197 | :param pre_score: predicted crop scores, reshaped to (N,)
198 | :param gt_score: ground-truth crop scores, same number of elements
199 | :return: hinge-style pairwise ranking loss averaged over the N*(N-1)/2 crop pairs
200 | '''
201 | if pre_score.dim() > 1:
202 | pre_score = pre_score.reshape(-1)
203 | if gt_score.dim() > 1:
204 | gt_score = gt_score.reshape(-1)
205 | assert pre_score.shape == gt_score.shape, '{} vs. {}'.format(pre_score.shape, gt_score.shape)
206 | N = pre_score.shape[0]
207 | pair_num = N * (N-1) / 2
208 | pre_diff = pre_score[:,None] - pre_score[None,:]
209 | gt_diff = gt_score[:,None] - gt_score[None,:]
210 | indicat = -1 * torch.sign(gt_diff) * (pre_diff - gt_diff)
211 | diff = torch.maximum(indicat, torch.zeros_like(indicat))
212 | rank_loss= torch.sum(diff) / pair_num
213 | return rank_loss
214 |
215 | def xavier(param):
216 | init.xavier_uniform_(param)
217 |
218 | def weights_init(m):
219 | if isinstance(m, nn.Conv2d):
220 | xavier(m.weight.data)
221 | m.bias.data.zero_()
222 |
223 |
224 | def build_crop_model(scale='single', alignsize=9, reddim=32, loadweight=True, model=None):
225 | if scale=='single':
226 | return crop_model_single_scale(alignsize, reddim, loadweight, model)
227 | elif scale=='multi':
228 | return crop_model_multi_scale_shared(alignsize, reddim, loadweight, model)
229 |
230 |
231 | if __name__ == '__main__':
232 | net = build_crop_model(scale='multi', alignsize=9,
233 | reddim=32, loadweight=True,
234 | model='vgg16')
235 | net = net.eval().cuda()
236 | roi = torch.tensor([[0, 0, 128, 128], [64, 64, 223, 223]]).float()
237 | roi = roi.unsqueeze(0).cuda()
238 | roi = roi.repeat(2,1,1)
239 | img = torch.randn((2, 3, 224, 224)).cuda()
240 | out = net(img, roi)
241 | print(out.shape, out)
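    # sketch: the pairwise ranking loss can be probed with random scores, e.g.
    # gt_score = torch.rand(8).cuda()
    # pre_score = torch.rand(8).cuda()
    # print(cropping_rank_loss(pre_score, gt_score))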
242 |
243 |
244 |
245 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/networks/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.2
2 | numpy==1.19.1
3 | opencv_python==4.5.4.60
4 | Pillow==9.0.1
5 | PyYAML==6.0
6 | scipy==1.5.2
7 | setuptools==52.0.0.post20210125
8 | tensorboardX==2.4.1
9 | thop==0.0.31.post2005241907
10 | torch==1.9.0+cu111
11 | torchvision==0.10.0+cu111
12 | tqdm==4.51.0
13 |
--------------------------------------------------------------------------------
/result_images/211958.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/result_images/211958.jpg
--------------------------------------------------------------------------------
/result_images/265813.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/result_images/265813.jpg
--------------------------------------------------------------------------------
/result_images/297406.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/result_images/297406.jpg
--------------------------------------------------------------------------------
/test_images/211958.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/test_images/211958.jpg
--------------------------------------------------------------------------------
/test_images/265813.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/test_images/265813.jpg
--------------------------------------------------------------------------------
/test_images/297406.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/test_images/297406.jpg
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 | from tensorboardX import SummaryWriter
5 | import torch
6 | import time
7 | import datetime
8 | import csv
9 | import random
10 | import shutil
11 | from networks.GAIC_model import build_crop_model
12 | from dataset.cropping_dataset import GAICDataset
13 | from config.GAIC_config import cfg, refresh_yaml_params
14 | from evaluate.test import evaluate_on_GAICD_official as evaluate_on_GAICD
15 | import argparse
16 | import warnings
17 | warnings.filterwarnings('ignore')
18 |
19 | def create_dataloader():
20 | dataset = GAICDataset(split='train')
21 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size,
22 | shuffle=True, num_workers=cfg.num_workers,
23 | drop_last=False, pin_memory=False)
24 | print('training set has {} samples, {} batches'.format(len(dataset), len(dataloader)))
25 | return dataloader
26 |
27 |
28 | class Trainer:
29 | def __init__(self, model):
30 | self.model = model
31 | self.epoch = 0
32 | self.iters = 0
33 | self.max_epoch = cfg.max_epoch
34 | self.writer = SummaryWriter(log_dir=cfg.log_dir)
35 | self.optimizer, self.lr_scheduler = self.get_optimizer()
36 | self.train_loader = create_dataloader()
37 | self.eval_results = []
38 | self.best_results = {'acc1_5':0., 'acc2_5':0., 'acc3_5':0., 'acc4_5':0., 'acc5':0.,
39 | 'acc1_10':0., 'acc2_10':0., 'acc3_10':0., 'acc4_10':0., 'acc10':0.,
40 | 'srcc':0., 'pcc':0.}
41 | self.criterion = torch.nn.SmoothL1Loss(reduction='mean')
42 | self.contain_BN = False
43 | for name,m in self.model.Feat_ext.named_modules():
44 | if isinstance(m, torch.nn.BatchNorm2d):
45 | self.contain_BN = True
46 | break
47 |
48 | def get_optimizer(self):
49 | optim = torch.optim.Adam(
50 | self.model.parameters(),
51 | lr=cfg.lr
52 | )
53 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
54 | optim, milestones=cfg.lr_decay_epoch, gamma=cfg.lr_decay
55 | )
56 | return optim, lr_scheduler
57 |
58 | def run(self):
59 | print('======== Begin Training =========')
60 | for epoch in range(self.max_epoch):
61 | self.epoch = epoch
62 | self.train()
63 | if epoch % cfg.eval_freq == 0:
64 | self.eval()
65 | self.record_eval_results()
66 | self.lr_scheduler.step()
67 |
68 | def train(self):
69 | self.model.train()
70 | if self.contain_BN:
71 | self.model.Feat_ext.eval()
72 | device = next(self.model.parameters()).device
73 | start = time.time()
74 | batch_loss = 0
75 | total_batch = len(self.train_loader)
76 | total_loss = 0
77 | for batch_idx, batch_data in enumerate(self.train_loader):
78 | self.iters += 1
79 | im = batch_data[0].to(device)
80 | rois = batch_data[1].to(device)
81 | scores = batch_data[2].to(device)
82 | width = batch_data[3].to(device)
83 | height = batch_data[4].to(device)
84 |
85 | random_ID = list(range(0, rois.shape[1]))
86 | random.shuffle(random_ID)
87 | chosen_ID = random_ID[:64]
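            # randomly sample at most 64 candidate crops per image so the
            # RoI/RoD heads see a bounded number of boxes per training step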
88 |
89 | rois = rois[:,chosen_ID]
90 | scores = scores[:, chosen_ID]
91 |
92 | pred_scores = self.model(im, rois)
93 | loss = self.criterion(pred_scores.squeeze(), scores.squeeze())
94 | batch_loss += loss.item()
95 | self.optimizer.zero_grad()
96 | loss.backward()
97 | self.optimizer.step()
98 |
99 | if batch_idx % cfg.display_freq == 0:
100 | avg_loss = batch_loss / (1 + batch_idx)
101 | cur_lr = self.optimizer.param_groups[0]['lr']
102 | self.writer.add_scalar('train/loss', avg_loss, self.iters)
103 | self.writer.add_scalar('train/lr', cur_lr, self.iters)
104 |
105 | time_per_batch = (time.time() - start) / (batch_idx + 1.)
106 | last_batches = (self.max_epoch - self.epoch - 1) * total_batch + (total_batch - batch_idx - 1)
107 | last_time = int(last_batches * time_per_batch)
108 | time_str = str(datetime.timedelta(seconds=last_time))
109 |
110 | print('=== epoch:{}/{}, step:{}/{} | Loss:{:.4f} | lr:{:.6f} | estimated remaining time:{} ==='.format(
111 | self.epoch, self.max_epoch, batch_idx, total_batch, avg_loss, cur_lr, time_str
112 | ))
113 |
114 | def eval(self):
115 | self.model.eval()
116 | avg_srcc, avg_pcc, avg_acc5, avg_acc10, acc4_5, acc4_10 = evaluate_on_GAICD(self.model)
117 | self.eval_results.append([self.epoch, avg_srcc, avg_pcc, avg_acc5, avg_acc10,
118 | acc4_5[0], acc4_5[1], acc4_5[2], acc4_5[3],
119 | acc4_10[0], acc4_10[1], acc4_10[2], acc4_10[3]])
120 | epoch_result = {'srcc': avg_srcc, 'pcc': avg_pcc, 'acc5': avg_acc5, 'acc10': avg_acc10,
121 | 'acc1_5': acc4_5[0], 'acc2_5': acc4_5[1], 'acc3_5': acc4_5[2], 'acc4_5': acc4_5[3],
122 | 'acc1_10': acc4_10[0], 'acc2_10': acc4_10[1], 'acc3_10': acc4_10[2], 'acc4_10': acc4_10[3]}
123 | for m in epoch_result.keys():
124 | update = False
125 | if (epoch_result[m] > self.best_results[m]):
126 | update = True
127 | if update:
128 | self.best_results[m] = epoch_result[m]
129 | checkpoint_path = os.path.join(cfg.checkpoint_dir, 'best-{}.pth'.format(m))
130 | torch.save(self.model.state_dict(), checkpoint_path)
131 | print('Update best {} model, best {}={:.4f}'.format(m, m, self.best_results[m]))
132 | if m in ['srcc', 'acc5']:
133 | self.writer.add_scalar('test/{}'.format(m), epoch_result[m], self.epoch)
134 | self.writer.add_scalar('test/best-{}'.format(m), self.best_results[m], self.epoch)
135 | if self.epoch % cfg.save_freq == 0:
136 | checkpoint_path = os.path.join(cfg.checkpoint_dir, 'epoch-{}.pth'.format(self.epoch))
137 | torch.save(self.model.state_dict(), checkpoint_path)
138 |
139 | def record_eval_results(self):
140 | csv_path = os.path.join(cfg.exp_path, '..', '{}.csv'.format(cfg.exp_name))
141 | header = ['epoch', 'srcc', 'pcc', 'acc5', 'acc10',
142 | 'acc1_5', 'acc2_5', 'acc3_5', 'acc4_5',
143 | 'acc1_10', 'acc2_10', 'acc3_10', 'acc4_10']
144 | # Limit the number of decimal places in the result
145 | limit_results = []
146 | for epoch, result in enumerate(self.eval_results):
147 | limit_results.append([])
148 | for i,r in enumerate(result):
149 | if i == 0: # epoch
150 | limit_results[epoch].append(r)
151 | else:
152 | limit_results[epoch].append(round(r, 3))
153 | # find the best results
154 | rows = [header] + limit_results
155 | metrics = [[] for i in header]
156 | for result in limit_results:
157 | for i, r in enumerate(result):
158 | metrics[i].append(r)
159 | for name, m in zip(header, metrics):
160 | if name == 'epoch':
161 | continue
162 | index = m.index(max(m))
163 | title = 'best {}(epoch-{})'.format(name, index)
164 | row = [l[index] for l in metrics]
165 | row[0] = title
166 | rows.append(row)
167 | with open(csv_path, 'w') as f:
168 | cw = csv.writer(f)
169 | cw.writerows(rows)
170 | print('Save result to ', csv_path)
171 |
172 | def parse_args():
173 | parser = argparse.ArgumentParser(description="Train a GAIC model")
174 | parser.add_argument('--gpu', type=int, dest='gpu_id',
175 | help='gpu_id', default=0)
176 | parser.add_argument('--backbone', type=str, choices=['vgg16', 'mobilenetv2', 'resnet50', 'shufflenetv2'],
177 | help='the architecture of backbone network', default='vgg16')
178 | parser.add_argument('--reddim', type=int, choices=[64, 32, 16, 8],
179 | help='the reduced channel dimension of the feature map', default=32)
180 | parser.add_argument('--alignsize', type=int, choices=[3, 5, 9],
181 | help='RoIAlign and RoDAlign output size', default=9)
182 | parser.add_argument('--num_workers', type=int, help='number of dataloader workers', default=8)
183 | args = parser.parse_args()
184 | refresh_yaml_params(args)
185 |
186 |
187 | if __name__ == '__main__':
188 | parse_args()
189 | cfg.refresh_params()
190 | cfg.create_path()
191 | device = torch.device('cuda:{}'.format(cfg.gpu_id))
192 | torch.cuda.set_device(device)
193 | for file in ['config/GAIC_config.py', 'config/GAIC_params.yaml', 'dataset/cropping_dataset.py',
194 | 'evaluate/test.py', 'networks/GAIC_model.py', 'train/train.py']:
195 | if not os.path.exists(file):
196 | file = os.path.join('..', file)
197 | shutil.copy(file, cfg.code_dir)
198 | print('backup', file)
199 | net = build_crop_model(scale='multi', alignsize=cfg.alignsize, reddim=cfg.reddim,
200 | loadweight=True, model=cfg.backbone)
201 | net = net.to(device)
202 | trainer = Trainer(net)
203 | trainer.run()
204 |
--------------------------------------------------------------------------------
/untils/make_all.sh:
--------------------------------------------------------------------------------
1 | cd ./roi_align
2 | bash make.sh
3 |
4 | cd ../rod_align
5 | bash make.sh
6 |
7 | cd ..
8 |
--------------------------------------------------------------------------------
/untils/rod_align/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/rod_align/__init__.py
--------------------------------------------------------------------------------
/untils/rod_align/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/rod_align/functions/__init__.py
--------------------------------------------------------------------------------
/untils/rod_align/functions/rod_align.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | import rod_align_api
4 |
5 | class RoDAlignFunction(Function):
6 | @staticmethod
7 | def forward(ctx, features, rois, aligned_height, aligned_width, spatial_scale):
8 | batch_size, num_channels, data_height, data_width = features.size()
9 | ctx.save_for_backward(rois,
10 | torch.IntTensor([int(batch_size),
11 | int(num_channels),
12 | int(data_height),
13 | int(data_width),
14 | int(aligned_width),
15 | int(aligned_height)]),
16 | torch.FloatTensor([float(spatial_scale)]))
17 |
18 | num_rois = rois.size(0)
19 |
20 | output = features.new(num_rois,
21 | num_channels,
22 | int(aligned_height),
23 | int(aligned_width)).zero_()
24 |
25 | rod_align_api.forward(int(aligned_height),
26 | int(aligned_width),
27 | float(spatial_scale),
28 | features,
29 | rois, output)
30 |
31 | return output
32 |
33 | @staticmethod
34 | def backward(ctx, grad_output):
35 | rois, core_size, scale = ctx.saved_tensors
36 | spatial_scale = scale[0]
37 |
38 | batch_size, num_channels, data_height, data_width, aligned_width, aligned_height = core_size
39 |
40 | grad_input = rois.new(batch_size,
41 | num_channels,
42 | data_height,
43 | data_width).zero_()
44 |
45 | rod_align_api.backward(int(aligned_height),
46 | int(aligned_width),
47 | float(spatial_scale),
48 | grad_output,
49 | rois,
50 | grad_input)
51 |
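        # gradients are only propagated to the feature map; rois and the
        # size/scale hyper-parameters receive None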
52 | return grad_input, None, None, None, None
53 |
--------------------------------------------------------------------------------
/untils/rod_align/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd src
3 | echo "Compiling rod_align kernels by nvcc..."
4 |
5 | # Specify the architecture of your NV card below.
6 | # -arch=sm_86 targets Ampere cards such as the GeForce RTX 30xx series; -arch=sm_75 is compatible with the following NV GPU cards:
7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, Tesla T4
8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile
9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_86
10 |
11 | cd ../
12 | # Export CUDA_HOME. And build and install the library.
13 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install
14 |
15 |
--------------------------------------------------------------------------------
/untils/rod_align/make_python2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd src
3 | echo "Compiling rod_align kernels by nvcc..."
4 |
5 | # Specify the architecture of your NV card below.
6 | # -arch=sm_75 is compatible with the following NV GPU cards:
7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, Tesla T4
8 | # See more at https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile
9 | nvcc -c -o rod_align_kernel.cu.o rod_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75
10 |
11 | cd ../
12 | # Export CUDA_HOME, then build and install the library.
13 | export CUDA_HOME=/usr/local/cuda-10.0 && python setup.py install
14 |
15 |
--------------------------------------------------------------------------------
/untils/rod_align/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/rod_align/modules/__init__.py
--------------------------------------------------------------------------------
/untils/rod_align/modules/rod_align.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from torch.nn.functional import avg_pool2d, max_pool2d
3 | from ..functions.rod_align import RoDAlignFunction
4 |
5 |
6 | class RoDAlign(Module):
7 | def __init__(self, aligned_height, aligned_width, spatial_scale):
8 | super(RoDAlign, self).__init__()
9 |
10 | self.aligned_width = int(aligned_width)
11 | self.aligned_height = int(aligned_height)
12 | self.spatial_scale = float(spatial_scale)
13 |
14 | def forward(self, features, rois):
15 | return RoDAlignFunction.apply(features,
16 | rois,
17 | self.aligned_height,
18 | self.aligned_width,
19 | self.spatial_scale)
20 |
21 | class RoDAlignAvg(Module):
22 | def __init__(self, aligned_height, aligned_width, spatial_scale):
23 | super(RoDAlignAvg, self).__init__()
24 |
25 | self.aligned_width = int(aligned_width)
26 | self.aligned_height = int(aligned_height)
27 | self.spatial_scale = float(spatial_scale)
28 |
29 | def forward(self, features, rois):
30 | x = RoDAlignFunction.apply(features,
31 | rois,
32 | self.aligned_height+1,
33 | self.aligned_width+1,
34 | self.spatial_scale)
35 | return avg_pool2d(x, kernel_size=2, stride=1)
36 |
37 | class RoDAlignMax(Module):
38 | def __init__(self, aligned_height, aligned_width, spatial_scale):
39 | super(RoDAlignMax, self).__init__()
40 |
41 | self.aligned_width = int(aligned_width)
42 | self.aligned_height = int(aligned_height)
43 | self.spatial_scale = float(spatial_scale)
44 |
45 | def forward(self, features, rois):
46 | x = RoDAlignFunction.apply(features,
47 | rois,
48 | self.aligned_height+1,
49 | self.aligned_width+1,
50 | self.spatial_scale)
51 | return max_pool2d(x, kernel_size=2, stride=1)
52 |
--------------------------------------------------------------------------------
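`RoDAlignAvg` and `RoDAlignMax` request an `(aligned_height+1) x (aligned_width+1)` sampling grid and then pool with `kernel_size=2, stride=1`, which blends each 2x2 neighbourhood of samples back down to the requested output size. A quick size check of that trick:

```Python
# Pooling a (H+1, W+1) map with kernel_size=2 and stride=1 yields (H, W),
# so RoDAlignAvg(9, 9, ...) really returns a 9x9 output per RoI.
import torch
from torch.nn.functional import avg_pool2d

x = torch.randn(1, 8, 10, 10)   # stands in for a (9+1) x (9+1) aligned map
y = avg_pool2d(x, kernel_size=2, stride=1)
print(y.shape)                  # torch.Size([1, 8, 9, 9])
```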
/untils/rod_align/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from pkg_resources import parse_version
5 |
6 | min_version = parse_version('1.0.0')
7 | current_version = parse_version(torch.__version__)
8 |
9 |
10 | if current_version < min_version: #PyTorch before 1.0
11 | from torch.utils.ffi import create_extension
12 |
13 |     sources = ['src/rod_align.c']
14 |     headers = ['src/rod_align.h']
15 | extra_objects = []
16 |
17 | defines = []
18 | with_cuda = False
19 |
20 | this_file = os.path.dirname(os.path.realpath(__file__))
21 | print(this_file)
22 |
23 | if torch.cuda.is_available():
24 | print('Including CUDA code.')
25 | sources += ['src/rod_align_cuda.c']
26 | headers += ['src/rod_align_cuda.h']
27 | defines += [('WITH_CUDA', None)]
28 | with_cuda = True
29 |
30 | extra_objects = ['src/rod_align_kernel.cu.o']
31 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
32 |
33 | ffi = create_extension(
34 | '_ext.rod_align',
35 | headers=headers,
36 | sources=sources,
37 | define_macros=defines,
38 | relative_to=__file__,
39 | with_cuda=with_cuda,
40 | extra_objects=extra_objects
41 | )
42 |
43 | if __name__ == '__main__':
44 | ffi.build()
45 | else: # PyTorch 1.0 or later
46 | from setuptools import setup
47 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
48 |
49 | print('Including CUDA code.')
50 |
51 | current_dir = os.path.dirname(os.path.realpath(__file__))
52 |
53 | setup(
54 | name='rod_align_api',
55 | ext_modules=[
56 | CUDAExtension(
57 | name='rod_align_api',
58 | sources=['src/rod_align_cuda.cpp', 'src/rod_align_kernel.cu'],
59 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True)
60 | )
61 | ],
62 | cmdclass={
63 | 'build_ext': BuildExtension
64 | })
65 |
66 |
--------------------------------------------------------------------------------
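On PyTorch 1.0 or later, the same sources can alternatively be JIT-compiled at import time instead of installed via `setup.py`. An untested sketch using `torch.utils.cpp_extension.load`, with paths assuming the repository layout shown above:

```Python
# JIT-compile the rod_align extension on first import (sketch; assumes
# nvcc and the CUDA headers are available on this machine).
from torch.utils.cpp_extension import load

rod_align_api = load(
    name='rod_align_api',
    sources=['untils/rod_align/src/rod_align_cuda.cpp',
             'untils/rod_align/src/rod_align_kernel.cu'],
    verbose=True,
)
```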
/untils/rod_align/src/rod_align.cpp:
--------------------------------------------------------------------------------
1 | #include "rod_align.h"
2 |
3 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
4 | const int height, const int width, const int channels,
5 | const int aligned_height, const int aligned_width, const float * bottom_rois,
6 | float* top_data);
7 |
8 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
9 | const int height, const int width, const int channels,
10 | const int aligned_height, const int aligned_width, const float * bottom_rois,
11 |                          float* bottom_diff);
12 |
13 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale,
14 | torch::Tensor features, torch::Tensor rois, torch::Tensor output)
15 | {
16 | //Grab the input tensor
17 | //float * data_flat = THFloatTensor_data(features);
18 | //float * rois_flat = THFloatTensor_data(rois);
19 | auto data_flat = features.data();
20 | auto rois_flat = rois.data();
21 |
22 | //float * output_flat = THFloatTensor_data(output);
23 | auto output_flat = output.data();
24 |
25 | // Number of ROIs
26 | //int num_rois = THFloatTensor_size(rois, 0);
27 | //int size_rois = THFloatTensor_size(rois, 1);
28 | auto rois_sz = rois.sizes();
29 | int num_rois = rois_sz[0];
30 | int size_rois = rois_sz[1];
31 |
32 | if (size_rois != 5)
33 | {
34 | return 0;
35 | }
36 |
37 | // data height
38 | //int data_height = THFloatTensor_size(features, 2);
39 | // data width
40 | //int data_width = THFloatTensor_size(features, 3);
41 | // Number of channels
42 | //int num_channels = THFloatTensor_size(features, 1);
43 | auto feat_sz = features.sizes();
44 | int data_height = feat_sz[2];
45 | int data_width = feat_sz[3];
46 | int num_channels = feat_sz[1];
47 |
48 |     // do RODAlignForward
49 | RODAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels,
50 | aligned_height, aligned_width, rois_flat, output_flat);
51 |
52 | return 1;
53 | }
54 |
55 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale,
56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad)
57 | {
58 | //Grab the input tensor
59 | //float * top_grad_flat = THFloatTensor_data(top_grad);
60 | //float * rois_flat = THFloatTensor_data(rois);
61 |
62 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad);
63 |
64 | auto top_grad_flat = top_grad.data();
65 | auto rois_flat = rois.data();
66 | auto bottom_grad_flat = bottom_grad.data();
67 |
68 |
69 | // Number of ROIs
70 | //int num_rois = THFloatTensor_size(rois, 0);
71 | //int size_rois = THFloatTensor_size(rois, 1);
72 |
73 | auto rois_sz = rois.sizes();
74 | int num_rois = rois_sz[0];
75 | int size_rois = rois_sz[1];
76 |
77 | if (size_rois != 5)
78 | {
79 | return 0;
80 | }
81 |
82 | // batch size
83 | // int batch_size = THFloatTensor_size(bottom_grad, 0);
84 | // data height
85 | //int data_height = THFloatTensor_size(bottom_grad, 2);
86 | // data width
87 | //int data_width = THFloatTensor_size(bottom_grad, 3);
88 | // Number of channels
89 | //int num_channels = THFloatTensor_size(bottom_grad, 1);
90 | auto grad_sz = bottom_grad.sizes();
91 | int data_height = grad_sz[2];
92 | int data_width = grad_sz[3];
93 | int num_channels = grad_sz[1];
94 |
95 |     // do RODAlignBackward
96 | RODAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height,
97 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat);
98 |
99 | return 1;
100 | }
101 |
102 | void RODAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
103 | const int height, const int width, const int channels,
104 | const int aligned_height, const int aligned_width, const float * bottom_rois,
105 | float* top_data)
106 | {
107 | const int output_size = num_rois * aligned_height * aligned_width * channels;
108 |
109 | int idx = 0;
110 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.);
111 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.);
112 | for (idx = 0; idx < output_size; ++idx)
113 | {
114 | // (n, c, ph, pw) is an element in the aligned output
115 | int pw = idx % aligned_width;
116 | int ph = (idx / aligned_width) % aligned_height;
117 | int c = (idx / aligned_width / aligned_height) % channels;
118 | int n = idx / aligned_width / aligned_height / channels;
119 |
120 | float roi_batch_ind = bottom_rois[n * 5 + 0];
121 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
122 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
123 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
124 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
125 |
126 |
127 | float h = (float)(ph) * bin_size_h;
128 | float w = (float)(pw) * bin_size_w;
129 |
130 | int hstart = fminf(floor(h), height - 2);
131 | int wstart = fminf(floor(w), width - 2);
132 |
133 | int img_start = roi_batch_ind * channels * height * width;
134 |
135 | // bilinear interpolation
136 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){
137 | top_data[idx] = 0.;
138 | } else {
139 | float h_ratio = h - (float)(hstart);
140 | float w_ratio = w - (float)(wstart);
141 | int upleft = img_start + (c * height + hstart) * width + wstart;
142 | int upright = upleft + 1;
143 | int downleft = upleft + width;
144 | int downright = downleft + 1;
145 |
146 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio)
147 | + bottom_data[upright] * (1. - h_ratio) * w_ratio
148 | + bottom_data[downleft] * h_ratio * (1. - w_ratio)
149 | + bottom_data[downright] * h_ratio * w_ratio;
150 | }
151 | }
152 | }
153 |
154 | void RODAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
155 | const int height, const int width, const int channels,
156 | const int aligned_height, const int aligned_width, const float * bottom_rois,
157 | float* bottom_diff)
158 | {
159 | const int output_size = num_rois * aligned_height * aligned_width * channels;
160 |
161 | int idx = 0;
162 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.);
163 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.);
164 | for (idx = 0; idx < output_size; ++idx)
165 | {
166 | // (n, c, ph, pw) is an element in the aligned output
167 | int pw = idx % aligned_width;
168 | int ph = (idx / aligned_width) % aligned_height;
169 | int c = (idx / aligned_width / aligned_height) % channels;
170 | int n = idx / aligned_width / aligned_height / channels;
171 |
172 | float roi_batch_ind = bottom_rois[n * 5 + 0];
173 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
174 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
175 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
176 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
177 |
178 | float h = (float)(ph) * bin_size_h;
179 | float w = (float)(pw) * bin_size_w;
180 |
181 | int hstart = fminf(floor(h), height - 2);
182 | int wstart = fminf(floor(w), width - 2);
183 |
184 | int img_start = roi_batch_ind * channels * height * width;
185 |
186 | // bilinear interpolation
187 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) {
188 | float h_ratio = h - (float)(hstart);
189 | float w_ratio = w - (float)(wstart);
190 | int upleft = img_start + (c * height + hstart) * width + wstart;
191 | int upright = upleft + 1;
192 | int downleft = upleft + width;
193 | int downright = downleft + 1;
194 |
195 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio);
196 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio;
197 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. - w_ratio);
198 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio;
199 | }
200 | }
201 | }
202 |
203 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
204 | m.def("forward", &rod_align_forward, "rod_align forward");
205 | m.def("backward", &rod_align_backward, "rod_align backward");
206 | }
207 |
--------------------------------------------------------------------------------
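The CPU reference above makes the RoD ("region of discard") semantics explicit: the sampling grid always spans the whole feature map, and samples that fall inside the RoI are zeroed, so the output encodes the surroundings that a crop would throw away. A pure-NumPy sketch of the forward pass for a single channel, mirroring `RODAlignForwardCpu`:

```Python
# Reference sketch of the RoD forward pass for one channel: the grid
# spans the whole map; positions inside the RoI are zeroed.
import numpy as np

def rod_align_channel(feat, roi, ah, aw, scale):
    H, W = feat.shape
    x1, y1, x2, y2 = [v * scale for v in roi]          # RoI in feature coords
    bin_h = (H - 1.001) / (ah - 1.0)
    bin_w = (W - 1.001) / (aw - 1.0)
    out = np.zeros((ah, aw), dtype=feat.dtype)
    for ph in range(ah):
        for pw in range(aw):
            h, w = ph * bin_h, pw * bin_w
            if y1 <= h <= y2 and x1 <= w <= x2:
                continue                               # inside the RoI -> 0
            hs = min(int(np.floor(h)), H - 2)          # hstart
            ws = min(int(np.floor(w)), W - 2)          # wstart
            hr, wr = h - hs, w - ws                    # bilinear weights
            out[ph, pw] = (feat[hs, ws] * (1 - hr) * (1 - wr)
                           + feat[hs, ws + 1] * (1 - hr) * wr
                           + feat[hs + 1, ws] * hr * (1 - wr)
                           + feat[hs + 1, ws + 1] * hr * wr)
    return out

feat = np.arange(64, dtype=np.float32).reshape(8, 8)
print(rod_align_channel(feat, roi=(2, 2, 5, 5), ah=4, aw=4, scale=1.0))
```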
/untils/rod_align/src/rod_align.h:
--------------------------------------------------------------------------------
1 | #ifndef ROD_ALIGN_H
2 | #define ROD_ALIGN_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int rod_align_forward(int aligned_height, int aligned_width, float spatial_scale,
7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output);
8 |
9 | int rod_align_backward(int aligned_height, int aligned_width, float spatial_scale,
10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/untils/rod_align/src/rod_align_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include "rod_align_kernel.h"
4 | #include "rod_align_cuda.h"
5 |
6 |
7 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale,
8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output)
9 | {
10 | // Grab the input tensor
11 | //float * data_flat = THCudaTensor_data(state, features);
12 | //float * rois_flat = THCudaTensor_data(state, rois);
13 |
14 | //float * output_flat = THCudaTensor_data(state, output);
15 |
16 | auto data_flat = features.data();
17 | auto rois_flat = rois.data();
18 | auto output_flat = output.data();
19 |
20 | // Number of ROIs
21 | //int num_rois = THCudaTensor_size(state, rois, 0);
22 | //int size_rois = THCudaTensor_size(state, rois, 1);
23 |
24 | auto rois_sz = rois.sizes();
25 | int num_rois = rois_sz[0];
26 | int size_rois = rois_sz[1];
27 |
28 | if (size_rois != 5)
29 | {
30 | return 0;
31 | }
32 |
33 | // data height
34 | //int data_height = THCudaTensor_size(state, features, 2);
35 | // data width
36 | //int data_width = THCudaTensor_size(state, features, 3);
37 | // Number of channels
38 | //int num_channels = THCudaTensor_size(state, features, 1);
39 | auto feat_sz = features.sizes();
40 | int data_height = feat_sz[2];
41 | int data_width = feat_sz[3];
42 | int num_channels = feat_sz[1];
43 |
44 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
45 |
46 | RODAlignForwardLaucher(
47 | data_flat, spatial_scale, num_rois, data_height,
48 | data_width, num_channels, aligned_height,
49 | aligned_width, rois_flat,
50 | output_flat, stream);
51 |
52 | return 1;
53 | }
54 |
55 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale,
56 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad)
57 | {
58 | // Grab the input tensor
59 | //float * top_grad_flat = THCudaTensor_data(state, top_grad);
60 | //float * rois_flat = THCudaTensor_data(state, rois);
61 |
62 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad);
63 | auto top_grad_flat = top_grad.data();
64 | auto rois_flat = rois.data();
65 | auto bottom_grad_flat = bottom_grad.data();
66 |
67 | // Number of ROIs
68 | //int num_rois = THCudaTensor_size(state, rois, 0);
69 | //int size_rois = THCudaTensor_size(state, rois, 1);
70 | auto rois_sz = rois.sizes();
71 | int num_rois = rois_sz[0];
72 | int size_rois = rois_sz[1];
73 | if (size_rois != 5)
74 | {
75 | return 0;
76 | }
77 |
78 | // batch size
79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0);
80 | // data height
81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2);
82 | // data width
83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3);
84 | // Number of channels
85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1);
86 |
87 | auto grad_sz = bottom_grad.sizes();
88 | int batch_size = grad_sz[0];
89 | int data_height = grad_sz[2];
90 | int data_width = grad_sz[3];
91 | int num_channels = grad_sz[1];
92 |
93 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
94 | RODAlignBackwardLaucher(
95 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height,
96 | data_width, num_channels, aligned_height,
97 | aligned_width, rois_flat,
98 | bottom_grad_flat, stream);
99 |
100 | return 1;
101 | }
102 |
103 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
104 | m.def("forward", &rod_align_forward_cuda, "rod_align forward");
105 | m.def("backward", &rod_align_backward_cuda, "rod_align backward");
106 | }
107 |
--------------------------------------------------------------------------------
/untils/rod_align/src/rod_align_cuda.h:
--------------------------------------------------------------------------------
1 | #ifndef ROD_ALIGN_CUDA_H
2 | #define ROD_ALIGN_CUDA_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int rod_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale,
7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output);
8 |
9 | int rod_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale,
10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/untils/rod_align/src/rod_align_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include "rod_align_kernel.h"
3 |
4 | #define CUDA_1D_KERNEL_LOOP(i, n) \
5 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
6 | i += blockDim.x * gridDim.x)
7 |
8 |
9 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width,
10 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) {
11 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.);
12 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.);
13 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
14 | // (n, c, ph, pw) is an element in the aligned output
15 | // int n = index;
16 | // int pw = n % aligned_width;
17 | // n /= aligned_width;
18 | // int ph = n % aligned_height;
19 | // n /= aligned_height;
20 | // int c = n % channels;
21 | // n /= channels;
22 |
23 | int pw = index % aligned_width;
24 | int ph = (index / aligned_width) % aligned_height;
25 | int c = (index / aligned_width / aligned_height) % channels;
26 | int n = index / aligned_width / aligned_height / channels;
27 |
28 | // bottom_rois += n * 5;
29 | float roi_batch_ind = bottom_rois[n * 5 + 0];
30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
34 |
35 |
36 | float h = (float)(ph) * bin_size_h;
37 | float w = (float)(pw) * bin_size_w;
38 |
39 | int hstart = fminf(floor(h), height - 2);
40 | int wstart = fminf(floor(w), width - 2);
41 |
42 | int img_start = roi_batch_ind * channels * height * width;
43 |
44 | // bilinear interpolation
45 | if (h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w){
46 | top_data[index] = 0.;
47 | } else {
48 | float h_ratio = h - (float)(hstart);
49 | float w_ratio = w - (float)(wstart);
50 | int upleft = img_start + (c * height + hstart) * width + wstart;
51 | int upright = upleft + 1;
52 | int downleft = upleft + width;
53 | int downright = downleft + 1;
54 |
55 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio)
56 | + bottom_data[upright] * (1. - h_ratio) * w_ratio
57 | + bottom_data[downleft] * h_ratio * (1. - w_ratio)
58 | + bottom_data[downright] * h_ratio * w_ratio;
59 | }
60 | }
61 | }
62 |
63 |
64 | int RODAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width,
65 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) {
66 | const int kThreadsPerBlock = 1024;
67 | const int output_size = num_rois * aligned_height * aligned_width * channels;
68 | cudaError_t err;
69 |
70 |
71 | RODAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
72 | output_size, bottom_data, spatial_scale, height, width, channels,
73 | aligned_height, aligned_width, bottom_rois, top_data);
74 |
75 | err = cudaGetLastError();
76 | if(cudaSuccess != err) {
77 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
78 | exit( -1 );
79 | }
80 |
81 | return 1;
82 | }
83 |
84 |
85 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width,
86 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) {
87 | float bin_size_h = (float)(height - 1.001) / (aligned_height - 1.);
88 | float bin_size_w = (float)(width - 1.001) / (aligned_width - 1.);
89 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
90 |
91 | // (n, c, ph, pw) is an element in the aligned output
92 | int pw = index % aligned_width;
93 | int ph = (index / aligned_width) % aligned_height;
94 | int c = (index / aligned_width / aligned_height) % channels;
95 | int n = index / aligned_width / aligned_height / channels;
96 |
97 | float roi_batch_ind = bottom_rois[n * 5 + 0];
98 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
99 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
100 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
101 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
102 |
103 |
104 | float h = (float)(ph) * bin_size_h;
105 | float w = (float)(pw) * bin_size_w;
106 |
107 | int hstart = fminf(floor(h), height - 2);
108 | int wstart = fminf(floor(w), width - 2);
109 |
110 | int img_start = roi_batch_ind * channels * height * width;
111 |
112 | // bilinear interpolation
113 | if (!(h >= roi_start_h && h <= roi_end_h && w >= roi_start_w && w <= roi_end_w)) {
114 | float h_ratio = h - (float)(hstart);
115 | float w_ratio = w - (float)(wstart);
116 | int upleft = img_start + (c * height + hstart) * width + wstart;
117 | int upright = upleft + 1;
118 | int downleft = upleft + width;
119 | int downright = downleft + 1;
120 |
121 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio));
122 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. - h_ratio) * w_ratio);
123 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio));
124 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio);
125 | }
126 | }
127 | }
128 |
129 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width,
130 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) {
131 | const int kThreadsPerBlock = 1024;
132 | const int output_size = num_rois * aligned_height * aligned_width * channels;
133 | cudaError_t err;
134 |
135 | RODAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
136 | output_size, top_diff, spatial_scale, height, width, channels,
137 | aligned_height, aligned_width, bottom_diff, bottom_rois);
138 |
139 | err = cudaGetLastError();
140 | if(cudaSuccess != err) {
141 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
142 | exit( -1 );
143 | }
144 |
145 | return 1;
146 | }
147 |
--------------------------------------------------------------------------------
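The kernels flatten the `(n, c, ph, pw)` output into a single index, launch `ceil(output_size / 1024)` blocks of 1024 threads, and rely on the `CUDA_1D_KERNEL_LOOP` grid-stride loop so any launch size covers all elements. The launch arithmetic and index decode, replayed in Python:

```Python
# Launch configuration and flat-index decode used by the kernels above.
num_rois, channels, ah, aw = 3, 8, 9, 9
output_size = num_rois * channels * ah * aw
threads = 1024
blocks = (output_size + threads - 1) // threads   # ceil division

index = 1234                                      # an arbitrary flat index
pw = index % aw
ph = (index // aw) % ah
c = (index // (aw * ah)) % channels
n = index // (aw * ah * channels)
print(blocks, (n, c, ph, pw))
```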
/untils/rod_align/src/rod_align_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _ROD_ALIGN_KERNEL
2 | #define _ROD_ALIGN_KERNEL
3 |
4 | #include <cuda_runtime.h>
5 |
6 | __global__ void RODAlignForward(const int nthreads, const float* bottom_data,
7 | const float spatial_scale, const int height, const int width,
8 | const int channels, const int aligned_height, const int aligned_width,
9 | const float* bottom_rois, float* top_data);
10 |
11 | int RODAlignForwardLaucher(
12 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height,
13 | const int width, const int channels, const int aligned_height,
14 | const int aligned_width, const float* bottom_rois,
15 | float* top_data, cudaStream_t stream);
16 |
17 | __global__ void RODAlignBackward(const int nthreads, const float* top_diff,
18 | const float spatial_scale, const int height, const int width,
19 | const int channels, const int aligned_height, const int aligned_width,
20 | float* bottom_diff, const float* bottom_rois);
21 |
22 | int RODAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois,
23 | const int height, const int width, const int channels, const int aligned_height,
24 | const int aligned_width, const float* bottom_rois,
25 | float* bottom_diff, cudaStream_t stream);
26 |
27 | #endif
28 |
29 |
--------------------------------------------------------------------------------
/untils/roi_align/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/roi_align/__init__.py
--------------------------------------------------------------------------------
/untils/roi_align/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/roi_align/functions/__init__.py
--------------------------------------------------------------------------------
/untils/roi_align/functions/roi_align.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | import roi_align_api
4 |
5 | class RoIAlignFunction(Function):
6 | @staticmethod
7 | def forward(ctx, features, rois, aligned_height, aligned_width, spatial_scale):
8 | batch_size, num_channels, data_height, data_width = features.size()
9 | ctx.save_for_backward(rois,
10 | torch.IntTensor([int(batch_size),
11 | int(num_channels),
12 | int(data_height),
13 | int(data_width),
14 | int(aligned_height),
15 | int(aligned_width)]),
16 | torch.FloatTensor([float(spatial_scale)]))
17 |
18 | num_rois = rois.size(0)
19 |
20 | output = features.new(num_rois,
21 | num_channels,
22 | int(aligned_height),
23 | int(aligned_width)).zero_()
24 |
25 | roi_align_api.forward(int(aligned_height),
26 | int(aligned_width),
27 | float(spatial_scale),
28 | features,
29 | rois,
30 | output)
31 | return output
32 |
33 | @staticmethod
34 | def backward(ctx, grad_output):
35 | rois, core_size, scale = ctx.saved_tensors
36 |
37 | batch_size, num_channels, data_height, data_width, aligned_height, aligned_width = core_size
38 | spatial_scale = scale[0]
39 |
40 | grad_input = rois.new(batch_size,
41 | num_channels,
42 | data_height,
43 | data_width).zero_()
44 |
45 | roi_align_api.backward(int(aligned_height),
46 | int(aligned_width),
47 | float(spatial_scale),
48 | grad_output,
49 | rois,
50 | grad_input)
51 |
52 | return grad_input, None, None, None, None
53 |
--------------------------------------------------------------------------------
/untils/roi_align/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd src
3 | echo "Compiling roi_align kernels with nvcc..."
4 |
5 | # Specify the architecture of your NV card below.
6 | # -arch=sm_75 matches Turing GPUs (GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000,
7 | # Quadro RTX 6000, Quadro RTX 5000, Tesla T4); the command below uses -arch=sm_86 for Ampere GPUs (e.g. the RTX 30 series).
8 | # See more https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile
9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_86
10 |
11 | cd ../
12 | # Export CUDA_HOME. Build and install the library.
13 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install
14 |
--------------------------------------------------------------------------------
/untils/roi_align/make_python2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd src
3 | echo "Compiling roi_align kernels with nvcc..."
4 |
5 | # Specify the architecture of your NV card below.
6 | # -arch=sm_75 is compatible with the following NV GPU cards:
7 | # GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000 and Tesla T4.
8 | # See more https://raw.githubusercontent.com/stereolabs/zed-yolo/master/libdarknet/Makefile
9 | nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_75
10 |
11 | cd ../
12 | # Export CUDA_HOME. Build and install the library.
13 | export CUDA_HOME=/usr/local/cuda-10.0 && python setup.py install
14 |
--------------------------------------------------------------------------------
/untils/roi_align/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bo-zhang-cs/GAIC-Pytorch/877d6a325943c0a098324dc6d4a975ff4097c3f6/untils/roi_align/modules/__init__.py
--------------------------------------------------------------------------------
/untils/roi_align/modules/roi_align.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from torch.nn.functional import avg_pool2d, max_pool2d
3 | from ..functions.roi_align import RoIAlignFunction
4 |
5 |
6 | class RoIAlign(Module):
7 | def __init__(self, aligned_height, aligned_width, spatial_scale):
8 | super(RoIAlign, self).__init__()
9 |
10 | self.aligned_width = int(aligned_width)
11 | self.aligned_height = int(aligned_height)
12 | self.spatial_scale = float(spatial_scale)
13 |
14 | def forward(self, features, rois):
15 | return RoIAlignFunction.apply(features,
16 | rois,
17 | self.aligned_height,
18 | self.aligned_width,
19 | self.spatial_scale)
20 |
21 | class RoIAlignAvg(Module):
22 | def __init__(self, aligned_height, aligned_width, spatial_scale):
23 | super(RoIAlignAvg, self).__init__()
24 |
25 | self.aligned_width = int(aligned_width)
26 | self.aligned_height = int(aligned_height)
27 | self.spatial_scale = float(spatial_scale)
28 |
29 | def forward(self, features, rois):
30 | x = RoIAlignFunction.apply(features,
31 | rois,
32 | self.aligned_height+1,
33 | self.aligned_width+1,
34 | self.spatial_scale)
35 | return avg_pool2d(x, kernel_size=2, stride=1)
36 |
37 | class RoIAlignMax(Module):
38 | def __init__(self, aligned_height, aligned_width, spatial_scale):
39 | super(RoIAlignMax, self).__init__()
40 |
41 | self.aligned_width = int(aligned_width)
42 | self.aligned_height = int(aligned_height)
43 | self.spatial_scale = float(spatial_scale)
44 |
45 | def forward(self, features, rois):
46 | x = RoIAlignFunction.apply(features,
47 | rois,
48 | self.aligned_height+1,
49 | self.aligned_width+1,
50 | self.spatial_scale)
51 | return max_pool2d(x, kernel_size=2, stride=1)
52 |
--------------------------------------------------------------------------------
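The modules mirror the RoD versions, but the underlying sampling grids differ: RoIAlign spaces its samples across the RoI itself (`h = ph * bin_size_h + roi_start_h`, with bin sizes derived from the RoI extent), while RoDAlign spaces them across the whole feature map. A small sketch contrasting the two row grids for a 4x4 output:

```Python
# Row sample locations for RoIAlign vs RoDAlign on a 16x16 feature map.
# Bin sizes follow the kernels above: the RoI extent (+1, as in the CPU
# code) for RoIAlign, the full map for RoDAlign.
H, ah = 16, 4
y1, y2 = 4.0, 11.0                                 # RoI rows in feature coords

roi_rows = [ph * (y2 - y1 + 1.0) / (ah - 1) + y1 for ph in range(ah)]
rod_rows = [ph * (H - 1.001) / (ah - 1) for ph in range(ah)]
print('RoIAlign rows:', roi_rows)                  # spread over the RoI
print('RoDAlign rows:', rod_rows)                  # spread over the whole map
```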
/untils/roi_align/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from pkg_resources import parse_version
5 |
6 | min_version = parse_version('1.0.0')
7 | current_version = parse_version(torch.__version__)
8 |
9 |
10 | if current_version < min_version: #PyTorch before 1.0
11 | from torch.utils.ffi import create_extension
12 |
13 | sources = ['src/roi_align.c']
14 | headers = ['src/roi_align.h']
15 | extra_objects = []
16 | #sources = []
17 | #headers = []
18 | defines = []
19 | with_cuda = False
20 |
21 | this_file = os.path.dirname(os.path.realpath(__file__))
22 | print(this_file)
23 |
24 | if torch.cuda.is_available():
25 | print('Including CUDA code.')
26 | sources += ['src/roi_align_cuda.c']
27 | headers += ['src/roi_align_cuda.h']
28 | defines += [('WITH_CUDA', None)]
29 | with_cuda = True
30 |
31 | extra_objects = ['src/roi_align_kernel.cu.o']
32 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
33 |
34 | ffi = create_extension(
35 | '_ext.roi_align',
36 | headers=headers,
37 | sources=sources,
38 | define_macros=defines,
39 | relative_to=__file__,
40 | with_cuda=with_cuda,
41 | extra_objects=extra_objects
42 | )
43 |
44 | if __name__ == '__main__':
45 | ffi.build()
46 | else: # PyTorch 1.0 or later
47 | from setuptools import setup
48 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
49 |
50 | print('Including CUDA code.')
51 | current_dir = os.path.dirname(os.path.realpath(__file__))
52 | #cuda_include = '/usr/local/cuda-10.0/include'
53 |
54 | #GPU version
55 | setup(
56 | name='roi_align_api',
57 | ext_modules=[
58 | CUDAExtension(
59 | name='roi_align_api',
60 | sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu'],
61 | include_dirs=[current_dir]+torch.utils.cpp_extension.include_paths(cuda=True)
62 | )
63 | ],
64 | cmdclass={
65 | 'build_ext': BuildExtension
66 | })
67 |
--------------------------------------------------------------------------------
/untils/roi_align/src/roi_align.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <math.h>
3 | #include <stdio.h>
4 | #include "roi_align.h"
5 |
6 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
7 | const int height, const int width, const int channels,
8 | const int aligned_height, const int aligned_width, const float * bottom_rois,
9 | float* top_data);
10 |
11 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
12 | const int height, const int width, const int channels,
13 | const int aligned_height, const int aligned_width, const float * bottom_rois,
14 |                         float* bottom_diff);
15 |
16 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale,
17 | torch::Tensor features, torch::Tensor rois, torch::Tensor output)
18 | {
19 | //Grab the input tensor
20 | //float * data_flat = THFloatTensor_data(features);
21 | //float * rois_flat = THFloatTensor_data(rois);
22 | auto data_flat = features.data();
23 | auto rois_flat = rois.data();
24 |
25 | //float * output_flat = THFloatTensor_data(output);
26 | auto output_flat = output.data();
27 |
28 | // Number of ROIs
29 | //int num_rois = THFloatTensor_size(rois, 0);
30 | //int size_rois = THFloatTensor_size(rois, 1);
31 |
32 | auto rois_sz = rois.sizes();
33 | int num_rois = rois_sz[0];
34 | int size_rois = rois_sz[1];
35 |
36 | if (size_rois != 5)
37 | {
38 | return 0;
39 | }
40 |
41 | // data height
42 | //int data_height = THFloatTensor_size(features, 2);
43 | // data width
44 | //int data_width = THFloatTensor_size(features, 3);
45 | // Number of channels
46 | //int num_channels = THFloatTensor_size(features, 1);
47 | auto feat_sz = features.sizes();
48 | int data_height = feat_sz[2];
49 | int data_width = feat_sz[3];
50 | int num_channels = feat_sz[1];
51 |
52 | // do ROIAlignForward
53 | ROIAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels,
54 | aligned_height, aligned_width, rois_flat, output_flat);
55 |
56 | return 1;
57 | }
58 |
59 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale,
60 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad)
61 | {
62 | //Grab the input tensor
63 | //float * top_grad_flat = THFloatTensor_data(top_grad);
64 | //float * rois_flat = THFloatTensor_data(rois);
65 |
66 | //float * bottom_grad_flat = THFloatTensor_data(bottom_grad);
67 | auto top_grad_flat = top_grad.data();
68 | auto rois_flat = rois.data();
69 | auto bottom_grad_flat = bottom_grad.data();
70 |
71 | // Number of ROIs
72 | //int num_rois = THFloatTensor_size(rois, 0);
73 | //int size_rois = THFloatTensor_size(rois, 1);
74 | auto rois_sz = rois.sizes();
75 | int num_rois = rois_sz[0];
76 | int size_rois = rois_sz[1];
77 | if (size_rois != 5)
78 | {
79 | return 0;
80 | }
81 |
82 | // batch size
83 | // int batch_size = THFloatTensor_size(bottom_grad, 0);
84 | // data height
85 | //int data_height = THFloatTensor_size(bottom_grad, 2);
86 | // data width
87 | //int data_width = THFloatTensor_size(bottom_grad, 3);
88 | // Number of channels
89 | //int num_channels = THFloatTensor_size(bottom_grad, 1);
90 |
91 | auto grad_sz = bottom_grad.sizes();
92 | int data_height = grad_sz[2];
93 | int data_width = grad_sz[3];
94 | int num_channels = grad_sz[1];
95 |
96 | // do ROIAlignBackward
97 | ROIAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height,
98 | data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat);
99 |
100 | return 1;
101 | }
102 |
103 | void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
104 | const int height, const int width, const int channels,
105 | const int aligned_height, const int aligned_width, const float * bottom_rois,
106 | float* top_data)
107 | {
108 | const int output_size = num_rois * aligned_height * aligned_width * channels;
109 |
110 | int idx = 0;
111 | for (idx = 0; idx < output_size; ++idx)
112 | {
113 | // (n, c, ph, pw) is an element in the aligned output
114 | int pw = idx % aligned_width;
115 | int ph = (idx / aligned_width) % aligned_height;
116 | int c = (idx / aligned_width / aligned_height) % channels;
117 | int n = idx / aligned_width / aligned_height / channels;
118 |
119 | float roi_batch_ind = bottom_rois[n * 5 + 0];
120 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
121 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
122 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
123 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
124 |
125 | // Force malformed ROI to be 1x1
126 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
127 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
128 | float bin_size_h = roi_height / (aligned_height - 1.);
129 | float bin_size_w = roi_width / (aligned_width - 1.);
130 |
131 | float h = (float)(ph) * bin_size_h + roi_start_h;
132 | float w = (float)(pw) * bin_size_w + roi_start_w;
133 |
134 | int hstart = fminf(floor(h), height - 2);
135 | int wstart = fminf(floor(w), width - 2);
136 |
137 | int img_start = roi_batch_ind * channels * height * width;
138 |
139 | // bilinear interpolation
140 | if (h < 0 || h >= height || w < 0 || w >= width)
141 | {
142 | top_data[idx] = 0.;
143 | }
144 | else
145 | {
146 | float h_ratio = h - (float)(hstart);
147 | float w_ratio = w - (float)(wstart);
148 | int upleft = img_start + (c * height + hstart) * width + wstart;
149 | int upright = upleft + 1;
150 | int downleft = upleft + width;
151 | int downright = downleft + 1;
152 |
153 | top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio)
154 | + bottom_data[upright] * (1. - h_ratio) * w_ratio
155 | + bottom_data[downleft] * h_ratio * (1. - w_ratio)
156 | + bottom_data[downright] * h_ratio * w_ratio;
157 | }
158 | }
159 | }
160 |
161 | void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
162 | const int height, const int width, const int channels,
163 | const int aligned_height, const int aligned_width, const float * bottom_rois,
164 | float* bottom_diff)
165 | {
166 | const int output_size = num_rois * aligned_height * aligned_width * channels;
167 |
168 | int idx = 0;
169 | for (idx = 0; idx < output_size; ++idx)
170 | {
171 | // (n, c, ph, pw) is an element in the aligned output
172 | int pw = idx % aligned_width;
173 | int ph = (idx / aligned_width) % aligned_height;
174 | int c = (idx / aligned_width / aligned_height) % channels;
175 | int n = idx / aligned_width / aligned_height / channels;
176 |
177 | float roi_batch_ind = bottom_rois[n * 5 + 0];
178 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
179 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
180 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
181 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
182 |
183 | // Force malformed ROI to be 1x1
184 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
185 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
186 | float bin_size_h = roi_height / (aligned_height - 1.);
187 | float bin_size_w = roi_width / (aligned_width - 1.);
188 |
189 | float h = (float)(ph) * bin_size_h + roi_start_h;
190 | float w = (float)(pw) * bin_size_w + roi_start_w;
191 |
192 | int hstart = fminf(floor(h), height - 2);
193 | int wstart = fminf(floor(w), width - 2);
194 |
195 | int img_start = roi_batch_ind * channels * height * width;
196 |
197 | // bilinear interpolation
198 |         if (!(h < 0 || h >= height || w < 0 || w >= width))
199 | {
200 | float h_ratio = h - (float)(hstart);
201 | float w_ratio = w - (float)(wstart);
202 | int upleft = img_start + (c * height + hstart) * width + wstart;
203 | int upright = upleft + 1;
204 | int downleft = upleft + width;
205 | int downright = downleft + 1;
206 |
207 | bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio);
208 | bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio;
209 | bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. - w_ratio);
210 | bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio;
211 | }
212 | }
213 | }
214 |
215 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
216 | m.def("forward", &roi_align_forward, "roi_align forward");
217 | m.def("backward", &roi_align_backward, "roi_align backward");
218 | }
219 |
--------------------------------------------------------------------------------
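In the backward passes, each upstream gradient is scattered back to the same four input cells with the same bilinear weights used in the forward pass; the CPU path can use plain `+=` because it is single-threaded, while the CUDA kernels use `atomicAdd` since many threads may hit the same cell. The scatter for one sample point, in NumPy:

```Python
# Backward scatter sketch: the four bilinear weights sum to 1, so the
# upstream gradient is fully redistributed to the neighbouring cells.
import numpy as np

H, W = 5, 5
bottom_diff = np.zeros((H, W), dtype=np.float32)
h, w, g = 2.3, 1.7, 1.0                      # sample location, upstream grad
hs, ws = int(np.floor(h)), int(np.floor(w))  # hstart, wstart
hr, wr = h - hs, w - ws                      # h_ratio, w_ratio
bottom_diff[hs,     ws]     += g * (1 - hr) * (1 - wr)
bottom_diff[hs,     ws + 1] += g * (1 - hr) * wr
bottom_diff[hs + 1, ws]     += g * hr * (1 - wr)
bottom_diff[hs + 1, ws + 1] += g * hr * wr
print(bottom_diff.sum())                     # 1.0: weights sum to the grad
```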
/untils/roi_align/src/roi_align.h:
--------------------------------------------------------------------------------
1 | #ifndef ROI_ALIGN_H
2 | #define ROI_ALIGN_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale,
7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output);
8 |
9 | int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale,
10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/untils/roi_align/src/roi_align_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <math.h>
4 | #include "roi_align_kernel.h"
5 |
6 |
7 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale,
8 | torch::Tensor features, torch::Tensor rois, torch::Tensor output)
9 | {
10 | // Grab the input tensor
11 | //float * data_flat = THCudaTensor_data(state, features);
12 | //float * rois_flat = THCudaTensor_data(state, rois);
13 |
14 | //float * output_flat = THCudaTensor_data(state, output);
15 |
16 | auto data_flat = features.data();
17 | auto rois_flat = rois.data();
18 | auto output_flat = output.data();
19 |
20 | // Number of ROIs
21 | //int num_rois = THCudaTensor_size(state, rois, 0);
22 | //int size_rois = THCudaTensor_size(state, rois, 1);
23 | auto rois_sz = rois.sizes();
24 | int num_rois = rois_sz[0];
25 | int size_rois = rois_sz[1];
26 | if (size_rois != 5)
27 | {
28 | return 0;
29 | }
30 |
31 | // data height
32 | //int data_height = THCudaTensor_size(state, features, 2);
33 | // data width
34 | //int data_width = THCudaTensor_size(state, features, 3);
35 | // Number of channels
36 | //int num_channels = THCudaTensor_size(state, features, 1);
37 | auto feat_sz = features.sizes();
38 | int data_height = feat_sz[2];
39 | int data_width = feat_sz[3];
40 | int num_channels = feat_sz[1];
41 |
42 |
43 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
44 |
45 | ROIAlignForwardLaucher(
46 | data_flat, spatial_scale, num_rois, data_height,
47 | data_width, num_channels, aligned_height,
48 | aligned_width, rois_flat,
49 | output_flat, stream);
50 |
51 | return 1;
52 | }
53 |
54 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale,
55 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad)
56 | {
57 | // Grab the input tensor
58 | //float * top_grad_flat = THCudaTensor_data(state, top_grad);
59 | //float * rois_flat = THCudaTensor_data(state, rois);
60 |
61 | //float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad);
62 | auto top_grad_flat = top_grad.data();
63 | auto rois_flat = rois.data();
64 | auto bottom_grad_flat = bottom_grad.data();
65 |
66 | // Number of ROIs
67 | //int num_rois = THCudaTensor_size(state, rois, 0);
68 | //int size_rois = THCudaTensor_size(state, rois, 1);
69 | auto rois_sz = rois.sizes();
70 | int num_rois = rois_sz[0];
71 | int size_rois = rois_sz[1];
72 |
73 | if (size_rois != 5)
74 | {
75 | return 0;
76 | }
77 |
78 | // batch size
79 | //int batch_size = THCudaTensor_size(state, bottom_grad, 0);
80 | // data height
81 | //int data_height = THCudaTensor_size(state, bottom_grad, 2);
82 | // data width
83 | //int data_width = THCudaTensor_size(state, bottom_grad, 3);
84 | // Number of channels
85 | //int num_channels = THCudaTensor_size(state, bottom_grad, 1);
86 | auto grad_sz = bottom_grad.sizes();
87 | int batch_size = grad_sz[0];
88 | int data_height = grad_sz[2];
89 | int data_width = grad_sz[3];
90 | int num_channels = grad_sz[1];
91 |
92 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
93 | ROIAlignBackwardLaucher(
94 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height,
95 | data_width, num_channels, aligned_height,
96 | aligned_width, rois_flat,
97 | bottom_grad_flat, stream);
98 |
99 | return 1;
100 | }
101 |
102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
103 | m.def("forward", &roi_align_forward_cuda, "roi_align forward");
104 | m.def("backward", &roi_align_backward_cuda, "roi_align backward");
105 | }
106 |
--------------------------------------------------------------------------------
/untils/roi_align/src/roi_align_cuda.h:
--------------------------------------------------------------------------------
1 | #ifndef ROI_ALIGN_CUDA_H
2 | #define ROI_ALIGN_CUDA_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int roi_align_forward_cuda(int aligned_height, int aligned_width, float spatial_scale,
7 | torch::Tensor features, torch::Tensor rois, torch::Tensor output);
8 |
9 | int roi_align_backward_cuda(int aligned_height, int aligned_width, float spatial_scale,
10 | torch::Tensor top_grad, torch::Tensor rois, torch::Tensor bottom_grad);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/untils/roi_align/src/roi_align_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <float.h>
4 | #include "roi_align_kernel.h"
5 |
6 | #define CUDA_1D_KERNEL_LOOP(i, n) \
7 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
8 | i += blockDim.x * gridDim.x)
9 |
10 |
11 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data, const float spatial_scale, const int height, const int width,
12 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data) {
13 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
14 | // (n, c, ph, pw) is an element in the aligned output
15 | // int n = index;
16 | // int pw = n % aligned_width;
17 | // n /= aligned_width;
18 | // int ph = n % aligned_height;
19 | // n /= aligned_height;
20 | // int c = n % channels;
21 | // n /= channels;
22 |
23 | int pw = index % aligned_width;
24 | int ph = (index / aligned_width) % aligned_height;
25 | int c = (index / aligned_width / aligned_height) % channels;
26 | int n = index / aligned_width / aligned_height / channels;
27 |
28 | // bottom_rois += n * 5;
29 | float roi_batch_ind = bottom_rois[n * 5 + 0];
30 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
31 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
32 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
33 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
34 |
35 | // Force malformed ROIs to be 1x1
36 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
37 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
38 | float bin_size_h = roi_height / (aligned_height - 1.);
39 | float bin_size_w = roi_width / (aligned_width - 1.);
40 |
41 | float h = (float)(ph) * bin_size_h + roi_start_h;
42 | float w = (float)(pw) * bin_size_w + roi_start_w;
43 |
44 | int hstart = fminf(floor(h), height - 2);
45 | int wstart = fminf(floor(w), width - 2);
46 |
47 | int img_start = roi_batch_ind * channels * height * width;
48 |
49 | // bilinear interpolation
50 | if (h < 0 || h >= height || w < 0 || w >= width) {
51 | top_data[index] = 0.;
52 | } else {
53 | float h_ratio = h - (float)(hstart);
54 | float w_ratio = w - (float)(wstart);
55 | int upleft = img_start + (c * height + hstart) * width + wstart;
56 | int upright = upleft + 1;
57 | int downleft = upleft + width;
58 | int downright = downleft + 1;
59 |
60 | top_data[index] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio)
61 | + bottom_data[upright] * (1. - h_ratio) * w_ratio
62 | + bottom_data[downleft] * h_ratio * (1. - w_ratio)
63 | + bottom_data[downright] * h_ratio * w_ratio;
64 | }
65 | }
66 | }
67 |
68 |
69 | int ROIAlignForwardLaucher(const float* bottom_data, const float spatial_scale, const int num_rois, const int height, const int width,
70 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* top_data, cudaStream_t stream) {
71 | const int kThreadsPerBlock = 1024;
72 | const int output_size = num_rois * aligned_height * aligned_width * channels;
73 | cudaError_t err;
74 |
75 |
76 | ROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
77 | output_size, bottom_data, spatial_scale, height, width, channels,
78 | aligned_height, aligned_width, bottom_rois, top_data);
79 |
80 | err = cudaGetLastError();
81 | if(cudaSuccess != err) {
82 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
83 | exit( -1 );
84 | }
85 |
86 | return 1;
87 | }
88 |
89 |
90 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff, const float spatial_scale, const int height, const int width,
91 | const int channels, const int aligned_height, const int aligned_width, float* bottom_diff, const float* bottom_rois) {
92 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
93 |
94 | // (n, c, ph, pw) is an element in the aligned output
95 | int pw = index % aligned_width;
96 | int ph = (index / aligned_width) % aligned_height;
97 | int c = (index / aligned_width / aligned_height) % channels;
98 | int n = index / aligned_width / aligned_height / channels;
99 |
100 | float roi_batch_ind = bottom_rois[n * 5 + 0];
101 | float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
102 | float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
103 | float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
104 | float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;
105 | /* int roi_start_w = round(bottom_rois[1] * spatial_scale); */
106 | /* int roi_start_h = round(bottom_rois[2] * spatial_scale); */
107 | /* int roi_end_w = round(bottom_rois[3] * spatial_scale); */
108 | /* int roi_end_h = round(bottom_rois[4] * spatial_scale); */
109 |
110 | // Force malformed ROIs to be 1x1
111 | float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
112 | float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
113 | float bin_size_h = roi_height / (aligned_height - 1.);
114 | float bin_size_w = roi_width / (aligned_width - 1.);
115 |
116 | float h = (float)(ph) * bin_size_h + roi_start_h;
117 | float w = (float)(pw) * bin_size_w + roi_start_w;
118 |
119 | int hstart = fminf(floor(h), height - 2);
120 | int wstart = fminf(floor(w), width - 2);
121 |
122 | int img_start = roi_batch_ind * channels * height * width;
123 |
124 | // bilinear interpolation
125 | if (!(h < 0 || h >= height || w < 0 || w >= width)) {
126 | float h_ratio = h - (float)(hstart);
127 | float w_ratio = w - (float)(wstart);
128 | int upleft = img_start + (c * height + hstart) * width + wstart;
129 | int upright = upleft + 1;
130 | int downleft = upleft + width;
131 | int downright = downleft + 1;
132 |
133 | atomicAdd(bottom_diff + upleft, top_diff[index] * (1. - h_ratio) * (1 - w_ratio));
134 | atomicAdd(bottom_diff + upright, top_diff[index] * (1. - h_ratio) * w_ratio);
135 | atomicAdd(bottom_diff + downleft, top_diff[index] * h_ratio * (1 - w_ratio));
136 | atomicAdd(bottom_diff + downright, top_diff[index] * h_ratio * w_ratio);
137 | }
138 | }
139 | }
140 |
141 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, const int height, const int width,
142 | const int channels, const int aligned_height, const int aligned_width, const float* bottom_rois, float* bottom_diff, cudaStream_t stream) {
143 | const int kThreadsPerBlock = 1024;
144 | const int output_size = num_rois * aligned_height * aligned_width * channels;
145 | cudaError_t err;
146 |
147 | ROIAlignBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
148 | output_size, top_diff, spatial_scale, height, width, channels,
149 | aligned_height, aligned_width, bottom_diff, bottom_rois);
150 |
151 | err = cudaGetLastError();
152 | if(cudaSuccess != err) {
153 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
154 | exit( -1 );
155 | }
156 |
157 | return 1;
158 | }
159 |
--------------------------------------------------------------------------------
/untils/roi_align/src/roi_align_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _ROI_ALIGN_KERNEL
2 | #define _ROI_ALIGN_KERNEL
3 |
4 |
5 | __global__ void ROIAlignForward(const int nthreads, const float* bottom_data,
6 | const float spatial_scale, const int height, const int width,
7 | const int channels, const int aligned_height, const int aligned_width,
8 | const float* bottom_rois, float* top_data);
9 |
10 | int ROIAlignForwardLaucher(
11 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height,
12 | const int width, const int channels, const int aligned_height,
13 | const int aligned_width, const float* bottom_rois,
14 | float* top_data, cudaStream_t stream);
15 |
16 | __global__ void ROIAlignBackward(const int nthreads, const float* top_diff,
17 | const float spatial_scale, const int height, const int width,
18 | const int channels, const int aligned_height, const int aligned_width,
19 | float* bottom_diff, const float* bottom_rois);
20 |
21 | int ROIAlignBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois,
22 | const int height, const int width, const int channels, const int aligned_height,
23 | const int aligned_width, const float* bottom_rois,
24 | float* bottom_diff, cudaStream_t stream);
25 |
26 | #endif
27 |
28 |
--------------------------------------------------------------------------------