├── README.md
├── bib
├── detection_filter.py
├── eval.py
├── inplace_abn
│   ├── __init__.py
│   ├── bn.py
│   ├── functions.py
│   └── src
│       ├── common.h
│       ├── inplace_abn.cpp
│       ├── inplace_abn.h
│       ├── inplace_abn_cpu.cpp
│       └── inplace_abn_cuda.cu
├── metrics
│   └── iou.py
├── networks
│   ├── __init__.py
│   ├── deeplabv3.py
│   ├── inits.py
│   └── nets.py
├── requirements.txt
└── utils
    ├── __init__.py
    └── parallel.py
/README.md:
--------------------------------------------------------------------------------
1 | # Anchor diffusion VOS
2 |
3 |
4 | This repository contains code for the paper
5 |
6 | **Anchor Diffusion for Unsupervised Video Object Segmentation**
7 | Zhao Yang\*, [Qiang Wang](http://www.robots.ox.ac.uk/~qwang/)\*, [Luca Bertinetto](http://www.robots.ox.ac.uk/~luca), [Weiming Hu](https://scholar.google.com/citations?user=Wl4tl4QAAAAJ&hl=en), [Song Bai](http://songbai.site/), [Philip H.S. Torr](http://www.robots.ox.ac.uk/~tvg/)
8 | **ICCV 2019** | **[PDF](https://arxiv.org/abs/1910.10895)** | **[BibTex](bib)**
9 |
10 | ## Setup
11 | Code tested on Ubuntu 16.04 with Python 3.7, PyTorch 0.4.1, and CUDA 9.2.
12 |
13 | * Clone the repository and change to the new directory.
14 | ```shell
15 | git clone https://github.com/yz93/anchor-diff-VOS.git && cd anchor-diff-VOS
16 | ```
17 | * Save the working directory to an environment variable for reference.
18 | ```shell
19 | export AnchorDiff=$PWD
20 | ```
21 | * Set up a new conda environment.
22 | * For installing PyTorch 0.4.1 with different versions of CUDA, see [here](https://pytorch.org/get-started/previous-versions/#via-conda).
23 | ```shell
24 | conda create -n anchordiff python=3.7 pytorch=0.4.1 cuda92 -c pytorch
25 | source activate anchordiff
26 | pip install -r requirements.txt
27 | ```
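* (Optional) Sanity-check the installation; the one-liner below only assumes the `anchordiff` environment is active.
```shell
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# should print 0.4.1, and True if a GPU is visible
```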
28 |
29 |
30 | ## Data preparation
31 | * Download the data set.
32 | ```shell
33 | cd $AnchorDiff
34 | wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
35 | unzip DAVIS-2017-trainval-480p.zip -d data
36 | ```
37 | * Download [pre-trained weights](https://drive.google.com/file/d/1A6ozn2FT2Ef1dn7HGNgUb1WboWFN2NON/view?usp=sharing) (1.5G) to $AnchorDiff.
38 | ```shell
39 | cd $AnchorDiff
40 | unzip snapshots.zip -d snapshots
41 | ```
42 | * (Optional: skip this step if you do not intend to apply the instance pruning described in the paper.) Download the detection results that we computed with [ExtremeNet](https://github.com/xingyizhou/ExtremeNet),
43 | and generate the pruning masks.
44 | ```shell
45 | cd $AnchorDiff
46 | wget www.robots.ox.ac.uk/~yz/detection.zip
47 | unzip detection.zip
48 | python detection_filter.py
49 | ```
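* `detection_filter.py` writes one binary pruning mask per frame to `inst_prune/<video>/<frame>.png`; sequences for which the detector reports too few boxes are skipped. `eval.py` picks these masks up automatically when run with `--inst-prune True`. To see which sequences received masks:
```shell
ls inst_prune
```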
50 |
51 | ## Evaluation on [DAVIS 2016](https://davischallenge.org/davis2016/code.html)
52 | * Examples for evaluating mean IoU on the validation set with options,
53 | * *save-mask* (default 'True') for saving the predicted masks,
54 | * *ms-mirror* (default 'False') for multiple-scale and mirrored input (slow),
55 | * *inst-prune* (default 'False') for instance pruning,
56 | * *model* (default 'ad') specifying models in Table 1 of the paper,
57 | * *eval-sal* (default 'False') for computing saliency measures, MAE and F-score.
58 | ```shell
59 | cd $AnchorDiff
60 | python eval.py
61 | python eval.py --ms-mirror True --inst-prune True --eval-sal True
62 | ```
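* To test quickly on a single sequence, `eval.py` also accepts a `--video` option that restricts inference to one validation video; a minimal example (using `blackswan`, one of the DAVIS 2016 validation sequences):
```shell
cd $AnchorDiff
python eval.py --model ad --video blackswan
```
With the default *save-mask* setting, the predicted masks are written under `./pred_masks/<model>/<single|ms_mirror>[_prune]/<video>/`.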
63 | * Use the [benchmark tool](https://github.com/davisvideochallenge/davis-matlab) to evaluate the saved masks under more metrics.
64 | * [Pre-computed results](https://www.robots.ox.ac.uk/~yz/val_results.zip)
65 |
66 | ## License
67 | The [MIT License](https://choosealicense.com/licenses/mit/).
68 |
--------------------------------------------------------------------------------
/bib:
--------------------------------------------------------------------------------
1 | @inproceedings{yang2019anchor,
2 | title={Anchor Diffusion for Unsupervised Video Object Segmentation},
3 | author={Yang, Zhao and Wang, Qiang and Bertinetto, Luca and Bai, Song and Hu, Weiming and Torr, Philip H.S.},
4 | booktitle={ICCV},
5 | year={2019}
6 | }
7 |
--------------------------------------------------------------------------------
/detection_filter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import cv2
4 | import numpy as np
5 | import pickle
6 | import shutil
7 |
8 |
9 | def _IoU(rect1, rect2):
10 | def inter(rect1, rect2):
11 | x1 = max(rect1[0], rect2[0])
12 | y1 = max(rect1[1], rect2[1])
13 | x2 = min(rect1[2], rect2[2])
14 | y2 = min(rect1[3], rect2[3])
15 | return max(x2 - x1 + 1, 0) * max(y2 - y1 + 1, 0) * 1.
16 |
17 | def area(rect):
18 | x1, y1, x2, y2 = rect
19 | return (x2 - x1 + 1) * (y2 - y1 + 1)
20 |
21 | ii = inter(rect1, rect2)
22 | iou = ii / (area(rect1) + area(rect2) - ii)
23 | return iou
24 |
25 |
26 | def vis_mask(img, mask, col, alpha=0.4, show_border=True, border_thick=2):
27 | """Visualizes a single binary mask."""
28 |
29 | img = img.astype(np.float32)
30 | idx = np.nonzero(mask)
31 |
32 | img[idx[0], idx[1], :] *= 1.0 - alpha
33 | img[idx[0], idx[1], :] += alpha * col
34 | _WHITE = (255, 255, 255)
35 | if show_border:
36 | contours, _ = cv2.findContours(mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
37 | cv2.drawContours(img, contours, -1, _WHITE, border_thick, cv2.LINE_AA)
38 |
39 | return img.astype(np.uint8)
40 |
41 |
42 | def colormap(rgb=False):
43 | color_list = np.array(
44 | [
45 | 0.000, 0.447, 0.741,
46 | 0.850, 0.325, 0.098,
47 | 0.929, 0.694, 0.125,
48 | 0.494, 0.184, 0.556,
49 | 0.466, 0.674, 0.188,
50 | 0.301, 0.745, 0.933,
51 | 0.635, 0.078, 0.184,
52 | 0.300, 0.300, 0.300,
53 | 0.600, 0.600, 0.600,
54 | 1.000, 0.000, 0.000,
55 | 1.000, 0.500, 0.000,
56 | 0.749, 0.749, 0.000,
57 | 0.000, 1.000, 0.000,
58 | 0.000, 0.000, 1.000,
59 | 0.667, 0.000, 1.000,
60 | 0.333, 0.333, 0.000,
61 | 0.333, 0.667, 0.000,
62 | 0.333, 1.000, 0.000,
63 | 0.667, 0.333, 0.000,
64 | 0.667, 0.667, 0.000,
65 | 0.667, 1.000, 0.000,
66 | 1.000, 0.333, 0.000,
67 | 1.000, 0.667, 0.000,
68 | 1.000, 1.000, 0.000,
69 | 0.000, 0.333, 0.500,
70 | 0.000, 0.667, 0.500,
71 | 0.000, 1.000, 0.500,
72 | 0.333, 0.000, 0.500,
73 | 0.333, 0.333, 0.500,
74 | 0.333, 0.667, 0.500,
75 | 0.333, 1.000, 0.500,
76 | 0.667, 0.000, 0.500,
77 | 0.667, 0.333, 0.500,
78 | 0.667, 0.667, 0.500,
79 | 0.667, 1.000, 0.500,
80 | 1.000, 0.000, 0.500,
81 | 1.000, 0.333, 0.500,
82 | 1.000, 0.667, 0.500,
83 | 1.000, 1.000, 0.500,
84 | 0.000, 0.333, 1.000,
85 | 0.000, 0.667, 1.000,
86 | 0.000, 1.000, 1.000,
87 | 0.333, 0.000, 1.000,
88 | 0.333, 0.333, 1.000,
89 | 0.333, 0.667, 1.000,
90 | 0.333, 1.000, 1.000,
91 | 0.667, 0.000, 1.000,
92 | 0.667, 0.333, 1.000,
93 | 0.667, 0.667, 1.000,
94 | 0.667, 1.000, 1.000,
95 | 1.000, 0.000, 1.000,
96 | 1.000, 0.333, 1.000,
97 | 1.000, 0.667, 1.000,
98 | 0.167, 0.000, 0.000,
99 | 0.333, 0.000, 0.000,
100 | 0.500, 0.000, 0.000,
101 | 0.667, 0.000, 0.000,
102 | 0.833, 0.000, 0.000,
103 | 1.000, 0.000, 0.000,
104 | 0.000, 0.167, 0.000,
105 | 0.000, 0.333, 0.000,
106 | 0.000, 0.500, 0.000,
107 | 0.000, 0.667, 0.000,
108 | 0.000, 0.833, 0.000,
109 | 0.000, 1.000, 0.000,
110 | 0.000, 0.000, 0.167,
111 | 0.000, 0.000, 0.333,
112 | 0.000, 0.000, 0.500,
113 | 0.000, 0.000, 0.667,
114 | 0.000, 0.000, 0.833,
115 | 0.000, 0.000, 1.000,
116 | 0.000, 0.000, 0.000,
117 | 0.143, 0.143, 0.143,
118 | 0.286, 0.286, 0.286,
119 | 0.429, 0.429, 0.429,
120 | 0.571, 0.571, 0.571,
121 | 0.714, 0.714, 0.714,
122 | 0.857, 0.857, 0.857,
123 | 1.000, 1.000, 1.000
124 | ]
125 | ).astype(np.float32)
126 | color_list = color_list.reshape((-1, 3)) * 255
127 | if not rgb:
128 | color_list = color_list[:, ::-1]
129 | return color_list
130 |
131 |
132 | videos = [i_id.strip() for i_id in open(os.path.join('./data/DAVIS/', 'ImageSets', '2016', 'val.txt'))]
133 | train_videos = [i_id.strip() for i_id in open(os.path.join('./data/DAVIS/', 'ImageSets', '2016', 'train.txt'))]
134 | frame_count = []
135 | for video in train_videos:
136 | img_files = sorted(
137 | glob.glob(os.path.join('./data/DAVIS/', 'JPEGImages', '480p', video, '*.jpg')))
138 | frame_count.append(len(img_files))
139 | mean_frame_count = np.mean(frame_count)
140 |
141 | out_dir = './inst_prune'
142 | if os.path.exists(out_dir):
143 | shutil.rmtree(out_dir)
144 |
145 | for vid, video in enumerate(videos):
146 | def load_obj(name):
147 | with open('detection/' + name + '.pkl', 'rb') as f:
148 | return pickle.load(f)
149 | if not os.path.exists('detection/' + video + '.pkl'):
150 | print('no detection on:', video)
151 | continue
152 | detect_res = load_obj(video)
153 | frame_len = len(detect_res)
154 | bboxes_all = []
155 | for frame_info in detect_res:
156 | for instance in detect_res[frame_info]:
157 | bboxes_all.append(instance['bbox'])
158 |
159 | mean_bboxes = len(bboxes_all)/frame_len
160 | first_remove = -1
161 | if mean_bboxes > 3:
162 | color_list = colormap(rgb=True)
163 | size_bboxes = [(bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) for bbox in bboxes_all]
164 | size_bboxes = sorted(size_bboxes)
165 | size_bboxes_target = size_bboxes[-frame_len]
166 |
167 | img_files = sorted(
168 | glob.glob(os.path.join('./data/DAVIS/', 'JPEGImages', '480p', video, '*.jpg')))
169 |
170 | for f, img_file in enumerate(img_files):
171 | im = cv2.imread(img_file, cv2.IMREAD_COLOR)
172 | frame_bboxes = []
173 | frame_masks = []
174 | for id, instance in enumerate(detect_res[f]):
175 | frame_bboxes.append(instance['bbox'])
176 | frame_masks.append(instance['mask'])
177 | bbox = instance['bbox']
178 | mask = instance['mask']
179 | cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 3)
180 | im = vis_mask(im, mask, color_list[id % len(color_list), :3], alpha=0.4)
181 |
182 | size_bboxes = [(bbox[3]-bbox[1])*(bbox[2]-bbox[0])for bbox in frame_bboxes]
183 | score_bboxes = [bbox[4] for bbox in frame_bboxes]
184 | size_bboxes_ind = np.argsort(size_bboxes)
185 | size_bboxes = sorted(size_bboxes)
186 |
187 | target_box = frame_bboxes[size_bboxes_ind[-1]]
188 |
189 | im_mask = np.ones((im.shape[0], im.shape[1]))
190 |
191 | for bbox, mask in zip(frame_bboxes, frame_masks):
192 | static_object_count = 0
193 | if (bbox[3]-bbox[1])*(bbox[2]-bbox[0]) > 47000 or (bbox[3]-bbox[1])*(bbox[2]-bbox[0]) == size_bboxes[-1]:
194 | continue
195 | for i in range(len(bboxes_all)):
196 | if _IoU(bbox[:4], bboxes_all[i][:4]) > 0.6:
197 | static_object_count += 1
198 |
199 | if static_object_count > 0.4 * mean_frame_count:
200 | im_mask = im_mask * (1 - mask)
201 | cv2.putText(im, 'static', (bbox[0], bbox[1]), 2, 2, (0, 255, 0))
202 |
203 | if len(size_bboxes) > 1:
204 | if size_bboxes[-1] > 10000 and size_bboxes[-1] > 2*size_bboxes[-2] and \
205 | size_bboxes[-1] > size_bboxes_target and \
206 | (target_box[-1] == 0 or target_box[-1] == 2) and \
207 | (target_box[2]-target_box[0])/(target_box[3]-target_box[1]) < 3:
208 | suppress_small = True
209 | if first_remove == -1:
210 | first_remove = f
211 | if first_remove > 20:
212 | break
213 | else:
214 | suppress_small = False
215 |
216 | if suppress_small:
217 | for bbox, mask in zip(frame_bboxes, frame_masks):
218 | cx = (bbox[3]+bbox[1])/2
219 | cy = (bbox[2]+bbox[0])/2
220 | cx0 = (target_box[3] + target_box[1]) / 2
221 | cy0 = (target_box[2] + target_box[0]) / 2
222 | d_dist = abs(cx-cx0)+ abs(cy-cy0)
223 | if ((bbox[3]-bbox[1])*(bbox[2]-bbox[0]) < size_bboxes[-1]//3 or
224 | ((cy < 300 or cy > 600) and d_dist > 200)) and \
225 | _IoU(target_box[:4], bbox[:4]) <= 0.1 and \
226 | bbox[-1] == frame_bboxes[size_bboxes_ind[-1]][-1]:
227 | # print(cx)
228 |
229 | # cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 5)
230 | # im_mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = 0
231 | im_mask = im_mask*(1 - mask)
232 |
233 | result_dir = os.path.join(out_dir, video)
234 | if not os.path.exists(result_dir):
235 | os.makedirs(result_dir)
236 |
237 | cv2.imwrite(os.path.join(result_dir, img_file.split('/')[-1].split('.')[0] + '.png'), im_mask*255)
238 | im_mask = (im_mask*255).astype(np.uint8)
239 | im = cv2.vconcat((cv2.cvtColor(im_mask.copy(), cv2.COLOR_GRAY2BGR), im))
240 | im = cv2.resize(im, dsize=None, fx=0.5, fy=0.5)
241 | # cv2.imshow('mask', im_mask*255)
242 | # cv2.imshow(video, im)
243 | # cv2.waitKey(1)
244 | cv2.destroyAllWindows()
245 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import torch.backends.cudnn as cudnn
7 | import os
8 | import os.path as osp
9 | import glob
10 | import cv2
11 | from PIL import Image
12 | import pickle
13 | from networks.deeplabv3 import ResNetDeepLabv3
14 | from networks.nets import AnchorDiffNet, ConcatNet, InterFrameNet, IntraFrameNet
15 |
16 | import timeit
17 | from metrics.iou import get_iou
18 |
19 | from sklearn.metrics import precision_recall_curve
20 |
21 |
22 | start = timeit.default_timer()
23 |
24 | IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
25 | BACKBONE = 'ResNet101'
26 | BN_TYPE = 'sync'
27 | DATA_DIRECTORY = './data/DAVIS/'
28 | EMBEDDING_SIZE = 128
29 | THRESHOLD = 0.5 # the threshold over raw scores (not the output of a sigmoid function)
30 | PYRAMID_POOLING = 'deeplabv3'
31 | VISUALIZE = False # False True
32 | MS_MIRROR = False # False True
33 | INSTANCE_PRUNING = False # False True
34 | PARENT_DIR_WEIGHTS = './snapshots/'
35 | SAVE_MASK = True # False True
36 | SAVE_MASK_DIR = './pred_masks/'
37 | EVAL_SAL = False # False True
38 | SAVE_HEATMAP_DIR = './pred_heatmaps/'
39 | # MODEL = 'base'
40 | # MODEL = 'concat'
41 | # MODEL = 'intra'
42 | # MODEL = 'inter'
43 | MODEL = 'ad'
44 |
45 |
46 | def str2bool(v):
47 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
48 | return True
49 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
50 | return False
51 | else:
52 | raise argparse.ArgumentTypeError('Boolean value expected.')
53 |
54 |
55 | def bool2str(v):
56 | if v:
57 | return 'True'
58 | else:
59 | return 'False'
60 |
61 |
62 | def get_arguments():
63 | parser = argparse.ArgumentParser(description="Anchor Diffusion VOS Test")
64 | parser.add_argument("--backbone", type=str, default=BACKBONE,
65 | help="Feature encoder.")
66 | parser.add_argument("--bn-type", type=str, default=BN_TYPE,
67 | help="BatchNorm MODE, [old/sync].")
68 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
69 | help="Path to the data directory.")
70 | parser.add_argument("--embedding-size", type=int, default=EMBEDDING_SIZE,
71 | help="Number of dimensions along the channel axis of pixel embeddings.")
72 | parser.add_argument("--eval-sal", type=str2bool, default=bool2str(EVAL_SAL),
73 | help="Whether to report MAE and F-score.")
74 | parser.add_argument("--inst-prune", type=str2bool, default=bool2str(INSTANCE_PRUNING),
75 | help="Whether to post-process the results with instance pruning")
76 | parser.add_argument("--ms-mirror", type=str2bool, default=bool2str(MS_MIRROR),
77 | help="Whether to mirror and re-scale the input image.")
78 | parser.add_argument("--model", type=str, default=MODEL, help="Overall models.")
79 | parser.add_argument("--pyramid-pooling", type=str, default=PYRAMID_POOLING,
80 | help="Pyramid pooling methods.")
81 | parser.add_argument("--parent-dir-weights", type=str, default=PARENT_DIR_WEIGHTS,
82 | help="Parent directory of pre-trained weights")
83 | parser.add_argument("--save-heatmap-dir", type=str, default=SAVE_HEATMAP_DIR,
84 | help="Path to save the outputs of sigmoid for MAE and F-score evaluation.")
85 | parser.add_argument("--save-mask", type=str2bool, default=bool2str(SAVE_MASK),
86 | help="Whether to save the predicted masks.")
87 | parser.add_argument("--save-mask-dir", type=str, default=SAVE_MASK_DIR,
88 | help="Path to save the predicted masks.")
89 | parser.add_argument("--threshold", type=float, default=THRESHOLD,
90 | help="Threshold on the raw scores (the logits before sigmoid normalization).")
91 | parser.add_argument("--visualize", type=str2bool, default=bool2str(VISUALIZE),
92 | help="Whether to visualize the predicted masks during inference.")
93 | parser.add_argument("--video", type=str, default='',
94 | help="If non-empty, then run inference on the specified video.")
95 | return parser.parse_args()
96 |
97 |
98 | args = get_arguments()
99 |
100 |
101 | def main():
102 | cudnn.enabled = True
103 |
104 | if args.model == 'base':
105 | model = ResNetDeepLabv3(backbone=args.backbone)
106 | elif args.model == 'intra':
107 | model = IntraFrameNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling,
108 | embedding=args.embedding_size, batch_mode='sync')
109 | elif args.model == 'inter':
110 | model = InterFrameNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling,
111 | embedding=args.embedding_size, batch_mode='sync')
112 | elif args.model == 'concat':
113 | model = ConcatNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling,
114 | embedding=args.embedding_size, batch_mode='sync')
115 | elif args.model == 'ad':
116 | model = AnchorDiffNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling,
117 | embedding=args.embedding_size, batch_mode='sync')
118 |
119 | model.load_state_dict(torch.load(osp.join(args.parent_dir_weights, args.model+'.pth')))
120 | model.eval()
121 | model.float()
122 | model.cuda()
123 |
124 | with torch.no_grad():
125 | video_mean_iou_list = []
126 | model.eval()
127 | videos = [i_id.strip() for i_id in open(osp.join(args.data_dir, 'ImageSets', '2016', 'val.txt'))]
128 | if args.video and args.video in videos:
129 | videos = [args.video]
130 |
131 | for vid, video in enumerate(videos, start=1):
132 | curr_video_iou_list = []
133 | img_files = sorted(glob.glob(osp.join(args.data_dir, 'JPEGImages', '480p', video, '*.jpg')))
134 | ann_files = sorted(glob.glob(osp.join(args.data_dir, 'Annotations', '480p', video, '*.png')))
135 |
136 | if args.ms_mirror:
137 | resize_shape = [(857*0.75, 481*0.75), (857, 481), (857*1.5, 481*1.5)]
138 | resize_shape = [(int((s[0]-1)//8*8+1), int((s[1]-1)//8*8+1)) for s in resize_shape]
139 | mirror = True
140 | else:
141 | resize_shape = [(857, 481)]
142 | mirror = False
143 |
144 | reference_img = []
145 | for s in resize_shape:
146 | reference_img.append((np.asarray(cv2.resize(cv2.imread(img_files[0], cv2.IMREAD_COLOR), s),
147 | np.float32) - IMG_MEAN).transpose((2, 0, 1)))
148 | if mirror:
149 | for r in range(len(reference_img)):
150 | reference_img.append(reference_img[r][:, :, ::-1].copy())
151 | reference_img = [torch.from_numpy(np.expand_dims(r, axis=0)).cuda() for r in reference_img]
152 | reference_mask = np.array(Image.open(ann_files[0])) > 0
153 | reference_mask = torch.from_numpy(np.expand_dims(np.expand_dims(reference_mask.astype(np.float32),
154 | axis=0), axis=0)).cuda()
155 | H, W = reference_mask.size(2), reference_mask.size(3)
156 |
157 | if args.visualize:
158 | colors = np.random.randint(128, 255, size=(1, 3), dtype="uint8")
159 | colors = np.vstack([[0, 0, 0], colors]).astype("uint8")
160 |
161 | last_mask_num = 0
162 | last_mask = None
163 | last_mask_final = None
164 | kernel1 = np.ones((15, 15), np.uint8)
165 | kernel2 = np.ones((101, 101), np.uint8)
166 | kernel3 = np.ones((31, 31), np.uint8)
167 | predictions_all = []
168 | gt_all = []
169 |
170 | for f, (img_file, ann_file) in enumerate(zip(img_files, ann_files)):
171 | current_img = []
172 | for s in resize_shape:
173 | current_img.append((np.asarray(cv2.resize(
174 | cv2.imread(img_file, cv2.IMREAD_COLOR), s),
175 | np.float32) - IMG_MEAN).transpose((2, 0, 1)))
176 |
177 | if mirror:
178 | for c in range(len(current_img)):
179 | current_img.append(current_img[c][:, :, ::-1].copy())
180 |
181 | current_img = [torch.from_numpy(np.expand_dims(c, axis=0)).cuda() for c in current_img]
182 |
183 | current_mask = np.array(Image.open(ann_file)) > 0
184 | current_mask = torch.from_numpy(np.expand_dims(np.expand_dims(current_mask.astype(np.float32), axis=0), axis=0)).cuda()
185 |
186 | if args.model in ['base']:
187 | predictions = [model(cur) for ref, cur in zip(reference_img, current_img)]
188 | predictions = [F.interpolate(input=p[0], size=(H, W), mode='bilinear', align_corners=True) for p in predictions]
189 | elif args.model in ['intra']:
190 | predictions = [model(cur) for ref, cur in zip(reference_img, current_img)]
191 | predictions = [F.interpolate(input=p, size=(H, W), mode='bilinear', align_corners=True) for p in predictions]
192 | elif args.model in ['inter', 'concat', 'ad']:
193 | predictions = [model(ref, cur) for ref, cur in zip(reference_img, current_img)]
194 | predictions = [F.interpolate(input=p, size=(H, W), mode='bilinear', align_corners=True) for p in predictions]
195 |
196 | if mirror:
197 | for r in range(len(predictions)//2, len(predictions)):
198 | predictions[r] = torch.flip(predictions[r], [3])
199 | predictions = torch.mean(torch.stack(predictions, dim=0), 0)
200 |
201 | predictions_all.append(predictions.sigmoid().data.cpu().numpy()[0, 0].copy())
202 | gt_all.append(current_mask.data.cpu().numpy()[0, 0].astype(np.uint8).copy())
203 |
204 | if args.inst_prune:
205 | result_dir = os.path.join('inst_prune', video)
206 | if os.path.exists(os.path.join(result_dir, img_file.split('/')[-1].split('.')[0] + '.png')):
207 | detection_mask = np.array(
208 | Image.open(os.path.join(result_dir, img_file.split('/')[-1].split('.')[0] + '.png'))) > 0
209 | detection_mask = torch.from_numpy(
210 | np.expand_dims(np.expand_dims(detection_mask.astype(np.float32), axis=0), axis=0)).cuda()
211 | predictions = predictions * detection_mask
212 |
213 | process_now = (predictions > args.threshold).data.cpu().numpy().astype(np.uint8)[0, 0]
214 | if 100000 > process_now.sum() > 40000:
215 | last_mask_numpy = (predictions > args.threshold).data.cpu().numpy().astype(np.uint8)[0, 0]
216 | last_mask_numpy = cv2.morphologyEx(last_mask_numpy, cv2.MORPH_OPEN, kernel1)
217 | dilation = cv2.dilate(last_mask_numpy, kernel3, iterations=1)
218 | contours, _ = cv2.findContours(dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
219 | cnt_area = [cv2.contourArea(cnt) for cnt in contours]
220 | if len(contours) > 1:
221 | contour = contours[np.argmax(cnt_area)]
222 | polygon = contour.reshape(-1, 2)
223 | x, y, w, h = cv2.boundingRect(polygon)
224 | x0, y0 = x, y
225 | x1 = x + w
226 | y1 = y + h
227 | mask_rect = torch.from_numpy(np.zeros_like(dilation).astype(np.float32)).cuda()
228 | mask_rect[y0:y1, x0:x1] = 1
229 | mask_rect = mask_rect.unsqueeze(0).unsqueeze(0)
230 | if np.max(cnt_area) > 30000:
231 | if last_mask_final is None or get_iou(last_mask_final, mask_rect, thresh=args.threshold) > 0.3:
232 | predictions = predictions * mask_rect
233 | last_mask_final = predictions.clone()
234 |
235 | if 100000 > last_mask_num > 5000:
236 | last_mask_numpy = (last_mask > args.threshold).data.cpu().numpy().astype(np.uint8)[0, 0]
237 | last_mask_numpy = cv2.morphologyEx(last_mask_numpy, cv2.MORPH_OPEN, kernel1)
238 | dilation = cv2.dilate(last_mask_numpy, kernel2, iterations=1)
239 | dilation = torch.from_numpy(dilation.astype(np.float32)).cuda()
240 |
241 | last_mask = predictions.clone()
242 | last_mask_num = (predictions > args.threshold).sum()
243 |
244 | predictions = predictions*dilation
245 | else:
246 | last_mask = predictions.clone()
247 | last_mask_num = (predictions > args.threshold).sum()
248 |
249 | iou_temp = get_iou(predictions, current_mask, thresh=args.threshold)
250 | if 0 < f < (len(ann_files)-1):
251 | curr_video_iou_list.append(iou_temp)
252 |
253 | if args.visualize:
254 | mask = colors[predictions.squeeze() > args.threshold]
255 | output = ((0.4 * cv2.imread(img_file)) + (0.6 * mask)).astype("uint8")
256 | cv2.putText(output, "%.3f" % (iou_temp.item()),
257 | (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
258 |
259 | cv2.imshow(video, output)
260 | cv2.waitKey(1)
261 |
262 | suffix = args.ms_mirror*'ms_mirror'+(not args.ms_mirror)*'single'+args.inst_prune*'_prune'
263 | visual_path = osp.join('visualization', args.model + '_' + suffix, img_file.split('/')[-2])
264 | if not osp.exists(visual_path):
265 | os.makedirs(visual_path)
266 | cv2.imwrite(osp.join(visual_path, ann_file.split('/')[-1]), output)
267 |
268 | if args.save_mask:
269 | suffix = args.ms_mirror*'ms_mirror'+(not args.ms_mirror)*'single'+args.inst_prune*'_prune'
270 | if not osp.exists(osp.join(args.save_mask_dir, args.model, suffix, video)):
271 | os.makedirs(osp.join(args.save_mask_dir, args.model, suffix, video))
272 | cv2.imwrite(osp.join(args.save_mask_dir, args.model, suffix, video, ann_file.split('/')[-1]),
273 | (predictions.squeeze() > args.threshold).cpu().numpy())
274 |
275 | cv2.destroyAllWindows()
276 | video_mean_iou_list.append(sum(curr_video_iou_list)/len(curr_video_iou_list))
277 | print('{} {} {}'.format(vid, video, video_mean_iou_list[-1]))
278 |
279 | if args.eval_sal:
280 | if not osp.exists(args.save_heatmap_dir):
281 | os.makedirs(args.save_heatmap_dir)
282 | with open(args.save_heatmap_dir + video + '.pkl', 'wb') as f:
283 | pickle.dump({'pred': np.array(predictions_all), 'gt': np.array(gt_all)}, f, pickle.HIGHEST_PROTOCOL)
284 |
285 | mean_iou = sum(video_mean_iou_list)/len(video_mean_iou_list)
286 | print('mean_iou {}'.format(mean_iou))
287 | end = timeit.default_timer()
288 | print(end-start, 'seconds')
289 | # ==========================
290 | if args.eval_sal:
291 | pkl_files = glob.glob(args.save_heatmap_dir + '*.pkl')
292 | heatmap_gt = []
293 | heatmap_pred = []
294 | for i, pkl_file in enumerate(pkl_files):
295 | with open(pkl_file, 'rb') as f:
296 | info = pickle.load(f)
297 | heatmap_gt.append(np.array(info['gt'][1:-1]).flatten())
298 | heatmap_pred.append(np.array(info['pred'][1:-1]).flatten())
299 | heatmap_gt = np.hstack(heatmap_gt).flatten()
300 | heatmap_pred = np.hstack(heatmap_pred).flatten()
301 | precision, recall, _ = precision_recall_curve(heatmap_gt, heatmap_pred)
302 | Fmax = 2 * (precision * recall) / (precision + recall)
303 | print('MAE', np.mean(abs(heatmap_pred - heatmap_gt)))
304 | print('F_max', Fmax.max())
305 |
306 | n_sample = len(precision)//1000
307 | import scipy.io
308 | scipy.io.savemat('davis.mat', {'recall': recall[0::n_sample], 'precision': precision[0::n_sample]})
309 |
310 |
311 | if __name__ == '__main__':
312 | main()
313 |
--------------------------------------------------------------------------------
/inplace_abn/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3 | from .misc import GlobalAvgPool2d
4 | from .residual import IdentityResidualBlock
5 | from .dense import DenseModule
6 |
--------------------------------------------------------------------------------
/inplace_abn/bn.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as functional
5 |
6 | try:
7 | from queue import Queue
8 | except ImportError:
9 | from Queue import Queue
10 |
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(os.path.join(BASE_DIR, '../src'))
14 | from functions import *
15 |
16 |
17 | class ABN(nn.Module):
18 | """Activated Batch Normalization
19 |
20 | This gathers a `BatchNorm2d` and an activation function in a single module
21 | """
22 |
23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
24 | """Creates an Activated Batch Normalization module
25 |
26 | Parameters
27 | ----------
28 | num_features : int
29 | Number of feature channels in the input and output.
30 | eps : float
31 | Small constant to prevent numerical issues.
32 | momentum : float
33 | Momentum factor used when computing the running statistics.
34 | affine : bool
35 | If `True` apply learned scale and shift transformation after normalization.
36 | activation : str
37 | Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
38 | slope : float
39 | Negative slope for the `leaky_relu` activation.
40 | """
41 | super(ABN, self).__init__()
42 | self.num_features = num_features
43 | self.affine = affine
44 | self.eps = eps
45 | self.momentum = momentum
46 | self.activation = activation
47 | self.slope = slope
48 | if self.affine:
49 | self.weight = nn.Parameter(torch.ones(num_features))
50 | self.bias = nn.Parameter(torch.zeros(num_features))
51 | else:
52 | self.register_parameter('weight', None)
53 | self.register_parameter('bias', None)
54 | self.register_buffer('running_mean', torch.zeros(num_features))
55 | self.register_buffer('running_var', torch.ones(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.constant_(self.running_mean, 0)
60 | nn.init.constant_(self.running_var, 1)
61 | if self.affine:
62 | nn.init.constant_(self.weight, 1)
63 | nn.init.constant_(self.bias, 0)
64 |
65 | def forward(self, x):
66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
67 | self.training, self.momentum, self.eps)
68 |
69 | if self.activation == ACT_RELU:
70 | return functional.relu(x, inplace=True)
71 | elif self.activation == ACT_LEAKY_RELU:
72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
73 | elif self.activation == ACT_ELU:
74 | return functional.elu(x, inplace=True)
75 | else:
76 | return x
77 |
78 | def __repr__(self):
79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
80 | ' affine={affine}, activation={activation}'
81 | if self.activation == "leaky_relu":
82 | rep += ', slope={slope})'
83 | else:
84 | rep += ')'
85 | return rep.format(name=self.__class__.__name__, **self.__dict__)
86 |
87 |
88 | class InPlaceABN(ABN):
89 | """InPlace Activated Batch Normalization"""
90 |
91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
92 | """Creates an InPlace Activated Batch Normalization module
93 |
94 | Parameters
95 | ----------
96 | num_features : int
97 | Number of feature channels in the input and output.
98 | eps : float
99 | Small constant to prevent numerical issues.
100 | momentum : float
101 | Momentum factor used when computing the running statistics.
102 | affine : bool
103 | If `True` apply learned scale and shift transformation after normalization.
104 | activation : str
105 | Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
106 | slope : float
107 | Negative slope for the `leaky_relu` activation.
108 | """
109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
110 |
111 | def forward(self, x):
112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
113 | self.training, self.momentum, self.eps, self.activation, self.slope)
114 |
115 |
116 | class InPlaceABNSync(ABN):
117 | """InPlace Activated Batch Normalization with cross-GPU synchronization
118 |
119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`.
120 | """
121 |
122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
123 | slope=0.01):
124 | """Creates a synchronized, InPlace Activated Batch Normalization module
125 |
126 | Parameters
127 | ----------
128 | num_features : int
129 | Number of feature channels in the input and output.
130 | devices : list of int or None
131 | IDs of the GPUs that will run the replicas of this module.
132 | eps : float
133 | Small constant to prevent numerical issues.
134 | momentum : float
135 | Momentum factor used when computing the running statistics.
136 | affine : bool
137 | If `True` apply learned scale and shift transformation after normalization.
138 | activation : str
139 | Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
140 | slope : float
141 | Negative slope for the `leaky_relu` activation.
142 | """
143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope)
144 | self.devices = devices if devices else list(range(torch.cuda.device_count()))
145 |
146 | # Initialize queues
147 | self.worker_ids = self.devices[1:]
148 | self.master_queue = Queue(len(self.worker_ids))
149 | self.worker_queues = [Queue(1) for _ in self.worker_ids]
150 |
151 | def forward(self, x):
152 | if x.get_device() == self.devices[0]:
153 | # Master mode
154 | extra = {
155 | "is_master": True,
156 | "master_queue": self.master_queue,
157 | "worker_queues": self.worker_queues,
158 | "worker_ids": self.worker_ids
159 | }
160 | else:
161 | # Worker mode
162 | extra = {
163 | "is_master": False,
164 | "master_queue": self.master_queue,
165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
166 | }
167 |
168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope)
170 |
171 | def __repr__(self):
172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
173 | ' affine={affine}, devices={devices}, activation={activation}'
174 | if self.activation == "leaky_relu":
175 | rep += ', slope={slope})'
176 | else:
177 | rep += ')'
178 | return rep.format(name=self.__class__.__name__, **self.__dict__)
179 |
--------------------------------------------------------------------------------
/inplace_abn/functions.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import torch.autograd as autograd
4 | import torch.cuda.comm as comm
5 | from torch.autograd.function import once_differentiable
6 | from torch.utils.cpp_extension import load
7 |
8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
9 | _backend = load(name="inplace_abn",
10 | extra_cflags=["-O3"],
11 | sources=[path.join(_src_path, f) for f in [
12 | "inplace_abn.cpp",
13 | "inplace_abn_cpu.cpp",
14 | "inplace_abn_cuda.cu"
15 | ]],
16 | extra_cuda_cflags=["--expt-extended-lambda"])
17 |
18 | # Activation names
19 | ACT_RELU = "relu"
20 | ACT_LEAKY_RELU = "leaky_relu"
21 | ACT_ELU = "elu"
22 | ACT_NONE = "none"
23 |
24 |
25 | def _check(fn, *args, **kwargs):
26 | success = fn(*args, **kwargs)
27 | if not success:
28 | raise RuntimeError("CUDA Error encountered in {}".format(fn))
29 |
30 |
31 | def _broadcast_shape(x):
32 | out_size = []
33 | for i, s in enumerate(x.size()):
34 | if i != 1:
35 | out_size.append(1)
36 | else:
37 | out_size.append(s)
38 | return out_size
39 |
40 |
41 | def _reduce(x):
42 | if len(x.size()) == 2:
43 | return x.sum(dim=0)
44 | else:
45 | n, c = x.size()[0:2]
46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0)
47 |
48 |
49 | def _count_samples(x):
50 | count = 1
51 | for i, s in enumerate(x.size()):
52 | if i != 1:
53 | count *= s
54 | return count
55 |
56 |
57 | def _act_forward(ctx, x):
58 | if ctx.activation == ACT_LEAKY_RELU:
59 | _backend.leaky_relu_forward(x, ctx.slope)
60 | elif ctx.activation == ACT_ELU:
61 | _backend.elu_forward(x)
62 | elif ctx.activation == ACT_NONE:
63 | pass
64 |
65 |
66 | def _act_backward(ctx, x, dx):
67 | if ctx.activation == ACT_LEAKY_RELU:
68 | _backend.leaky_relu_backward(x, dx, ctx.slope)
69 | elif ctx.activation == ACT_ELU:
70 | _backend.elu_backward(x, dx)
71 | elif ctx.activation == ACT_NONE:
72 | pass
73 |
74 |
75 | class InPlaceABN(autograd.Function):
76 | @staticmethod
77 | def forward(ctx, x, weight, bias, running_mean, running_var,
78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
79 | # Save context
80 | ctx.training = training
81 | ctx.momentum = momentum
82 | ctx.eps = eps
83 | ctx.activation = activation
84 | ctx.slope = slope
85 | ctx.affine = weight is not None and bias is not None
86 |
87 | # Prepare inputs
88 | count = _count_samples(x)
89 | x = x.contiguous()
90 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
91 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
92 |
93 | if ctx.training:
94 | mean, var = _backend.mean_var(x)
95 |
96 | # Update running stats
97 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
98 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
99 |
100 | # Mark in-place modified tensors
101 | ctx.mark_dirty(x, running_mean, running_var)
102 | else:
103 | mean, var = running_mean.contiguous(), running_var.contiguous()
104 | ctx.mark_dirty(x)
105 |
106 | # BN forward + activation
107 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
108 | _act_forward(ctx, x)
109 |
110 | # Output
111 | ctx.var = var
112 | ctx.save_for_backward(x, var, weight, bias)
113 | return x
114 |
115 | @staticmethod
116 | @once_differentiable
117 | def backward(ctx, dz):
118 | z, var, weight, bias = ctx.saved_tensors
119 | dz = dz.contiguous()
120 |
121 | # Undo activation
122 | _act_backward(ctx, z, dz)
123 |
124 | if ctx.training:
125 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
126 | else:
127 | # TODO: implement simplified CUDA backward for inference mode
128 | edz = dz.new_zeros(dz.size(1))
129 | eydz = dz.new_zeros(dz.size(1))
130 |
131 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
132 | dweight = dweight if ctx.affine else None
133 | dbias = dbias if ctx.affine else None
134 |
135 | return dx, dweight, dbias, None, None, None, None, None, None, None
136 |
137 |
138 | class InPlaceABNSync(autograd.Function):
139 | @classmethod
140 | def forward(cls, ctx, x, weight, bias, running_mean, running_var,
141 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
142 | # Save context
143 | cls._parse_extra(ctx, extra)
144 | ctx.training = training
145 | ctx.momentum = momentum
146 | ctx.eps = eps
147 | ctx.activation = activation
148 | ctx.slope = slope
149 | ctx.affine = weight is not None and bias is not None
150 |
151 | # Prepare inputs
152 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
153 | x = x.contiguous()
154 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
155 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
156 |
157 | if ctx.training:
158 | mean, var = _backend.mean_var(x)
159 |
160 | if ctx.is_master:
161 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
162 | for _ in range(ctx.master_queue.maxsize):
163 | mean_w, var_w = ctx.master_queue.get()
164 | ctx.master_queue.task_done()
165 | means.append(mean_w.unsqueeze(0))
166 | vars.append(var_w.unsqueeze(0))
167 |
168 | means = comm.gather(means)
169 | vars = comm.gather(vars)
170 |
171 | mean = means.mean(0)
172 | var = (vars + (mean - means) ** 2).mean(0)
173 |
174 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
175 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
176 | queue.put(ts)
177 | else:
178 | ctx.master_queue.put((mean, var))
179 | mean, var = ctx.worker_queue.get()
180 | ctx.worker_queue.task_done()
181 |
182 | # Update running stats
183 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
184 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
185 |
186 | # Mark in-place modified tensors
187 | ctx.mark_dirty(x, running_mean, running_var)
188 | else:
189 | mean, var = running_mean.contiguous(), running_var.contiguous()
190 | ctx.mark_dirty(x)
191 |
192 | # BN forward + activation
193 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194 | _act_forward(ctx, x)
195 |
196 | # Output
197 | ctx.var = var
198 | ctx.save_for_backward(x, var, weight, bias)
199 | return x
200 |
201 | @staticmethod
202 | @once_differentiable
203 | def backward(ctx, dz):
204 | z, var, weight, bias = ctx.saved_tensors
205 | dz = dz.contiguous()
206 |
207 | # Undo activation
208 | _act_backward(ctx, z, dz)
209 |
210 | if ctx.training:
211 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212 |
213 | if ctx.is_master:
214 | edzs, eydzs = [edz], [eydz]
215 | for _ in range(len(ctx.worker_queues)):
216 | edz_w, eydz_w = ctx.master_queue.get()
217 | ctx.master_queue.task_done()
218 | edzs.append(edz_w)
219 | eydzs.append(eydz_w)
220 |
221 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
222 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)
223 |
224 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
225 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
226 | queue.put(ts)
227 | else:
228 | ctx.master_queue.put((edz, eydz))
229 | edz, eydz = ctx.worker_queue.get()
230 | ctx.worker_queue.task_done()
231 | else:
232 | edz = dz.new_zeros(dz.size(1))
233 | eydz = dz.new_zeros(dz.size(1))
234 |
235 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
236 | dweight = dweight if ctx.affine else None
237 | dbias = dbias if ctx.affine else None
238 |
239 | return dx, dweight, dbias, None, None, None, None, None, None, None, None
240 |
241 | @staticmethod
242 | def _parse_extra(ctx, extra):
243 | ctx.is_master = extra["is_master"]
244 | if ctx.is_master:
245 | ctx.master_queue = extra["master_queue"]
246 | ctx.worker_queues = extra["worker_queues"]
247 | ctx.worker_ids = extra["worker_ids"]
248 | else:
249 | ctx.master_queue = extra["master_queue"]
250 | ctx.worker_queue = extra["worker_queue"]
251 |
252 |
253 | inplace_abn = InPlaceABN.apply
254 | inplace_abn_sync = InPlaceABNSync.apply
255 |
256 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
257 |
--------------------------------------------------------------------------------
/inplace_abn/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 | * General settings
7 | */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 |
11 | template <typename T>
12 | struct Pair {
13 | T v1, v2;
14 | __device__ Pair() {}
15 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
16 | __device__ Pair(T v) : v1(v), v2(v) {}
17 | __device__ Pair(int v) : v1(v), v2(v) {}
18 | __device__ Pair &operator+=(const Pair &a) {
19 | v1 += a.v1;
20 | v2 += a.v2;
21 | return *this;
22 | }
23 | };
24 |
25 | /*
26 | * Utility functions
27 | */
28 | template <typename T>
29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
30 | unsigned int mask = 0xffffffff) {
31 | #if CUDART_VERSION >= 9000
32 | return __shfl_xor_sync(mask, value, laneMask, width);
33 | #else
34 | return __shfl_xor(value, laneMask, width);
35 | #endif
36 | }
37 |
38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
39 |
40 | static int getNumThreads(int nElem) {
41 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
42 | for (int i = 0; i != 5; ++i) {
43 | if (nElem <= threadSizes[i]) {
44 | return threadSizes[i];
45 | }
46 | }
47 | return MAX_BLOCK_SIZE;
48 | }
49 |
50 | template <typename T>
51 | static __device__ __forceinline__ T warpSum(T val) {
52 | #if __CUDA_ARCH__ >= 300
53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
54 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
55 | }
56 | #else
57 | __shared__ T values[MAX_BLOCK_SIZE];
58 | values[threadIdx.x] = val;
59 | __threadfence_block();
60 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
61 | for (int i = 1; i < WARP_SIZE; i++) {
62 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
63 | }
64 | #endif
65 | return val;
66 | }
67 |
68 | template <typename T>
69 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
70 | value.v1 = warpSum(value.v1);
71 | value.v2 = warpSum(value.v2);
72 | return value;
73 | }
74 |
75 | template <typename T, typename Op>
76 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
77 | T sum = (T)0;
78 | for (int batch = 0; batch < N; ++batch) {
79 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
80 | sum += op(batch, plane, x);
81 | }
82 | }
83 |
84 | // sum over NumThreads within a warp
85 | sum = warpSum(sum);
86 |
87 | // 'transpose', and reduce within warp again
88 | __shared__ T shared[32];
89 | __syncthreads();
90 | if (threadIdx.x % WARP_SIZE == 0) {
91 | shared[threadIdx.x / WARP_SIZE] = sum;
92 | }
93 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
94 | // zero out the other entries in shared
95 | shared[threadIdx.x] = (T)0;
96 | }
97 | __syncthreads();
98 | if (threadIdx.x / WARP_SIZE == 0) {
99 | sum = warpSum(shared[threadIdx.x]);
100 | if (threadIdx.x == 0) {
101 | shared[0] = sum;
102 | }
103 | }
104 | __syncthreads();
105 |
106 | // Everyone picks it up, should be broadcast into the whole gradInput
107 | return shared[0];
108 | }
--------------------------------------------------------------------------------
/inplace_abn/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 | if (x.is_cuda()) {
9 | return mean_var_cuda(x);
10 | } else {
11 | return mean_var_cpu(x);
12 | }
13 | }
14 |
15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps) {
17 | if (x.is_cuda()) {
18 | return forward_cuda(x, mean, var, weight, bias, affine, eps);
19 | } else {
20 | return forward_cpu(x, mean, var, weight, bias, affine, eps);
21 | }
22 | }
23 |
24 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
25 | bool affine, float eps) {
26 | if (z.is_cuda()) {
27 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
28 | } else {
29 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
30 | }
31 | }
32 |
33 | std::vector<at::Tensor> backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
34 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
35 | if (z.is_cuda()) {
36 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
37 | } else {
38 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
39 | }
40 | }
41 |
42 | void leaky_relu_forward(at::Tensor z, float slope) {
43 | at::leaky_relu_(z, slope);
44 | }
45 |
46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
47 | if (z.is_cuda()) {
48 | return leaky_relu_backward_cuda(z, dz, slope);
49 | } else {
50 | return leaky_relu_backward_cpu(z, dz, slope);
51 | }
52 | }
53 |
54 | void elu_forward(at::Tensor z) {
55 | at::elu_(z);
56 | }
57 |
58 | void elu_backward(at::Tensor z, at::Tensor dz) {
59 | if (z.is_cuda()) {
60 | return elu_backward_cuda(z, dz);
61 | } else {
62 | return elu_backward_cpu(z, dz);
63 | }
64 | }
65 |
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 | m.def("mean_var", &mean_var, "Mean and variance computation");
68 | m.def("forward", &forward, "In-place forward computation");
69 | m.def("edz_eydz", &edz_eydz, "First part of backward computation");
70 | m.def("backward", &backward, "Second part of backward computation");
71 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
72 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
73 | m.def("elu_forward", &elu_forward, "Elu forward computation");
74 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
75 | }
--------------------------------------------------------------------------------
/inplace_abn/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include <vector>
6 |
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 |
10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
11 | bool affine, float eps);
12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
13 | bool affine, float eps);
14 |
15 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps);
17 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
18 | bool affine, float eps);
19 |
20 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
21 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
22 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
23 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
24 |
25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
27 |
28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
--------------------------------------------------------------------------------
/inplace_abn/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | at::Tensor reduce_sum(at::Tensor x) {
8 | if (x.ndimension() == 2) {
9 | return x.sum(0);
10 | } else {
11 | auto x_view = x.view({x.size(0), x.size(1), -1});
12 | return x_view.sum(-1).sum(0);
13 | }
14 | }
15 |
16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
17 | if (x.ndimension() == 2) {
18 | return v;
19 | } else {
20 | std::vector<int64_t> broadcast_size = {1, -1};
21 | for (int64_t i = 2; i < x.ndimension(); ++i)
22 | broadcast_size.push_back(1);
23 |
24 | return v.view(broadcast_size);
25 | }
26 | }
27 |
28 | int64_t count(at::Tensor x) {
29 | int64_t count = x.size(0);
30 | for (int64_t i = 2; i < x.ndimension(); ++i)
31 | count *= x.size(i);
32 |
33 | return count;
34 | }
35 |
36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
37 | if (affine) {
38 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
39 | } else {
40 | return z;
41 | }
42 | }
43 |
44 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
45 | auto num = count(x);
46 | auto mean = reduce_sum(x) / num;
47 | auto diff = x - broadcast_to(mean, x);
48 | auto var = reduce_sum(diff.pow(2)) / num;
49 |
50 | return {mean, var};
51 | }
52 |
53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
54 | bool affine, float eps) {
55 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
56 | auto mul = at::rsqrt(var + eps) * gamma;
57 |
58 | x.sub_(broadcast_to(mean, x));
59 | x.mul_(broadcast_to(mul, x));
60 | if (affine) x.add_(broadcast_to(bias, x));
61 |
62 | return x;
63 | }
64 |
65 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
66 | bool affine, float eps) {
67 | auto edz = reduce_sum(dz);
68 | auto y = invert_affine(z, weight, bias, affine, eps);
69 | auto eydz = reduce_sum(y * dz);
70 |
71 | return {edz, eydz};
72 | }
73 |
74 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
75 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
76 | auto y = invert_affine(z, weight, bias, affine, eps);
77 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
78 |
79 | auto num = count(z);
80 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
81 |
82 | auto dweight = at::empty(z.type(), {0});
83 | auto dbias = at::empty(z.type(), {0});
84 | if (affine) {
85 | dweight = eydz * at::sign(weight);
86 | dbias = edz;
87 | }
88 |
89 | return {dx, dweight, dbias};
90 | }
91 |
92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
93 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
94 | int64_t count = z.numel();
95 | auto *_z = z.data<scalar_t>();
96 | auto *_dz = dz.data<scalar_t>();
97 |
98 | for (int64_t i = 0; i < count; ++i) {
99 | if (_z[i] < 0) {
100 | _z[i] *= 1 / slope;
101 | _dz[i] *= slope;
102 | }
103 | }
104 | }));
105 | }
106 |
107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
108 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
109 | int64_t count = z.numel();
110 | auto *_z = z.data<scalar_t>();
111 | auto *_dz = dz.data<scalar_t>();
112 |
113 | for (int64_t i = 0; i < count; ++i) {
114 | if (_z[i] < 0) {
115 | _z[i] = log1p(_z[i]);
116 | _dz[i] *= (_z[i] + 1.f);
117 | }
118 | }
119 | }));
120 | }
--------------------------------------------------------------------------------
/inplace_abn/src/inplace_abn_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <thrust/device_ptr.h>
4 | #include <thrust/transform.h>
5 |
6 | #include <vector>
7 |
8 | #include "common.h"
9 | #include "inplace_abn.h"
10 |
11 | // Checks
12 | #ifndef AT_CHECK
13 | #define AT_CHECK AT_ASSERT
14 | #endif
15 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
16 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
18 |
19 | // Utilities
20 | void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
21 | num = x.size(0);
22 | chn = x.size(1);
23 | sp = 1;
24 | for (int64_t i = 2; i < x.ndimension(); ++i)
25 | sp *= x.size(i);
26 | }
27 |
28 | // Operations for reduce
29 | template <typename T>
30 | struct SumOp {
31 | __device__ SumOp(const T *t, int c, int s)
32 | : tensor(t), chn(c), sp(s) {}
33 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
34 | return tensor[(batch * chn + plane) * sp + n];
35 | }
36 | const T *tensor;
37 | const int chn;
38 | const int sp;
39 | };
40 |
41 | template <typename T>
42 | struct VarOp {
43 | __device__ VarOp(T m, const T *t, int c, int s)
44 | : mean(m), tensor(t), chn(c), sp(s) {}
45 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
46 | T val = tensor[(batch * chn + plane) * sp + n];
47 | return (val - mean) * (val - mean);
48 | }
49 | const T mean;
50 | const T *tensor;
51 | const int chn;
52 | const int sp;
53 | };
54 |
55 | template <typename T>
56 | struct GradOp {
57 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
58 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
59 | __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
60 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
61 | T _dz = dz[(batch * chn + plane) * sp + n];
62 | return Pair<T>(_dz, _y * _dz);
63 | }
64 | const T weight;
65 | const T bias;
66 | const T *z;
67 | const T *dz;
68 | const int chn;
69 | const int sp;
70 | };
71 |
72 | /***********
73 | * mean_var
74 | ***********/
75 |
76 | template <typename T>
77 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
78 | int plane = blockIdx.x;
79 | T norm = T(1) / T(num * sp);
80 |
81 | T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, chn, sp) * norm;
82 | __syncthreads();
83 | T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, chn, sp) * norm;
84 |
85 | if (threadIdx.x == 0) {
86 | mean[plane] = _mean;
87 | var[plane] = _var;
88 | }
89 | }
90 |
91 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
92 | CHECK_INPUT(x);
93 |
94 | // Extract dimensions
95 | int64_t num, chn, sp;
96 | get_dims(x, num, chn, sp);
97 |
98 | // Prepare output tensors
99 | auto mean = at::empty(x.type(), {chn});
100 | auto var = at::empty(x.type(), {chn});
101 |
102 | // Run kernel
103 | dim3 blocks(chn);
104 | dim3 threads(getNumThreads(sp));
105 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
106 | mean_var_kernel<scalar_t><<<blocks, threads>>>(
107 | x.data<scalar_t>(),
108 | mean.data<scalar_t>(),
109 | var.data<scalar_t>(),
110 | num, chn, sp);
111 | }));
112 |
113 | return {mean, var};
114 | }
115 |
116 | /**********
117 | * forward
118 | **********/
119 |
120 | template <typename T>
121 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
122 | bool affine, float eps, int num, int chn, int sp) {
123 | int plane = blockIdx.x;
124 |
125 | T _mean = mean[plane];
126 | T _var = var[plane];
127 | T _weight = affine ? abs(weight[plane]) + eps : T(1);
128 | T _bias = affine ? bias[plane] : T(0);
129 |
130 | T mul = rsqrt(_var + eps) * _weight;
131 |
132 | for (int batch = 0; batch < num; ++batch) {
133 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
134 | T _x = x[(batch * chn + plane) * sp + n];
135 | T _y = (_x - _mean) * mul + _bias;
136 |
137 | x[(batch * chn + plane) * sp + n] = _y;
138 | }
139 | }
140 | }
141 |
142 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
143 | bool affine, float eps) {
144 | CHECK_INPUT(x);
145 | CHECK_INPUT(mean);
146 | CHECK_INPUT(var);
147 | CHECK_INPUT(weight);
148 | CHECK_INPUT(bias);
149 |
150 | // Extract dimensions
151 | int64_t num, chn, sp;
152 | get_dims(x, num, chn, sp);
153 |
154 | // Run kernel
155 | dim3 blocks(chn);
156 | dim3 threads(getNumThreads(sp));
157 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
158 | forward_kernel<scalar_t><<<blocks, threads>>>(
159 | x.data<scalar_t>(),
160 | mean.data<scalar_t>(),
161 | var.data<scalar_t>(),
162 | weight.data<scalar_t>(),
163 | bias.data<scalar_t>(),
164 | affine, eps, num, chn, sp);
165 | }));
166 |
167 | return x;
168 | }
169 |
170 | /***********
171 | * edz_eydz
172 | ***********/
173 |
174 | template<typename T>
175 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
176 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
177 | int plane = blockIdx.x;
178 |
179 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
180 | T _bias = affine ? bias[plane] : 0.f;
181 |
182 | Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, chn, sp);
183 | __syncthreads();
184 |
185 | if (threadIdx.x == 0) {
186 | edz[plane] = res.v1;
187 | eydz[plane] = res.v2;
188 | }
189 | }
190 |
191 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
192 | bool affine, float eps) {
193 | CHECK_INPUT(z);
194 | CHECK_INPUT(dz);
195 | CHECK_INPUT(weight);
196 | CHECK_INPUT(bias);
197 |
198 | // Extract dimensions
199 | int64_t num, chn, sp;
200 | get_dims(z, num, chn, sp);
201 |
202 | auto edz = at::empty(z.type(), {chn});
203 | auto eydz = at::empty(z.type(), {chn});
204 |
205 | // Run kernel
206 | dim3 blocks(chn);
207 | dim3 threads(getNumThreads(sp));
208 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
209 | edz_eydz_kernel<scalar_t><<<blocks, threads>>>(
210 | z.data<scalar_t>(),
211 | dz.data<scalar_t>(),
212 | weight.data<scalar_t>(),
213 | bias.data<scalar_t>(),
214 | edz.data<scalar_t>(),
215 | eydz.data<scalar_t>(),
216 | affine, eps, num, chn, sp);
217 | }));
218 |
219 | return {edz, eydz};
220 | }
221 |
222 | /***********
223 | * backward
224 | ***********/
225 |
226 | template<typename T>
227 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
228 | const T *eydz, T *dx, T *dweight, T *dbias,
229 | bool affine, float eps, int num, int chn, int sp) {
230 | int plane = blockIdx.x;
231 |
232 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
233 | T _bias = affine ? bias[plane] : 0.f;
234 | T _var = var[plane];
235 | T _edz = edz[plane];
236 | T _eydz = eydz[plane];
237 |
238 | T _mul = _weight * rsqrt(_var + eps);
239 | T count = T(num * sp);
240 |
241 | for (int batch = 0; batch < num; ++batch) {
242 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
243 | T _dz = dz[(batch * chn + plane) * sp + n];
244 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
245 |
246 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
247 | }
248 | }
249 |
250 | if (threadIdx.x == 0) {
251 | if (affine) {
252 | dweight[plane] = weight[plane] > 0 ? _eydz : -_eydz;
253 | dbias[plane] = _edz;
254 | }
255 | }
256 | }
257 |
258 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
259 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
260 | CHECK_INPUT(z);
261 | CHECK_INPUT(dz);
262 | CHECK_INPUT(var);
263 | CHECK_INPUT(weight);
264 | CHECK_INPUT(bias);
265 | CHECK_INPUT(edz);
266 | CHECK_INPUT(eydz);
267 |
268 | // Extract dimensions
269 | int64_t num, chn, sp;
270 | get_dims(z, num, chn, sp);
271 |
272 | auto dx = at::zeros_like(z);
273 | auto dweight = at::zeros_like(weight);
274 | auto dbias = at::zeros_like(bias);
275 |
276 | // Run kernel
277 | dim3 blocks(chn);
278 | dim3 threads(getNumThreads(sp));
279 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
280 | backward_kernel<scalar_t><<<blocks, threads>>>(
281 | z.data<scalar_t>(),
282 | dz.data<scalar_t>(),
283 | var.data<scalar_t>(),
284 | weight.data<scalar_t>(),
285 | bias.data<scalar_t>(),
286 | edz.data<scalar_t>(),
287 | eydz.data<scalar_t>(),
288 | dx.data<scalar_t>(),
289 | dweight.data<scalar_t>(),
290 | dbias.data<scalar_t>(),
291 | affine, eps, num, chn, sp);
292 | }));
293 |
294 | return {dx, dweight, dbias};
295 | }
296 |
297 | /**************
298 | * activations
299 | **************/
300 |
301 | template<typename T>
302 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
303 | // Create thrust pointers
304 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
305 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
306 |
307 | thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
308 | [slope] __device__ (const T& dz) { return dz * slope; },
309 | [] __device__ (const T& z) { return z < 0; });
310 | thrust::transform_if(th_z, th_z + count, th_z,
311 | [slope] __device__ (const T& z) { return z / slope; },
312 | [] __device__ (const T& z) { return z < 0; });
313 | }
314 |
315 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
316 | CHECK_INPUT(z);
317 | CHECK_INPUT(dz);
318 |
319 | int64_t count = z.numel();
320 |
321 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
322 | leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
323 | }));
324 | }
325 |
326 | template<typename T>
327 | inline void elu_backward_impl(T *z, T *dz, int64_t count) {
328 | // Create thrust pointers
329 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
330 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
331 |
332 | thrust::transform_if(th_dz, th_dz + count, th_z, th_z, th_dz,
333 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
334 | [] __device__ (const T& z) { return z < 0; });
335 | thrust::transform_if(th_z, th_z + count, th_z,
336 | [] __device__ (const T& z) { return log1p(z); },
337 | [] __device__ (const T& z) { return z < 0; });
338 | }
339 |
340 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
341 | CHECK_INPUT(z);
342 | CHECK_INPUT(dz);
343 |
344 | int64_t count = z.numel();
345 |
346 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
347 | elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
348 | }));
349 | }
350 |
--------------------------------------------------------------------------------
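The per-channel statistics that `mean_var_cuda` produces are the mean and the biased variance taken over the batch and spatial dimensions. The sketch below is only an illustrative cross-check in plain PyTorch, not part of the extension.
```python
import torch

def mean_var_reference(x):
    """Reference for mean_var_cuda on an NCHW tensor (assumed layout)."""
    c = x.size(1)
    flat = x.transpose(0, 1).contiguous().view(c, -1)  # (C, N*H*W)
    # The kernel divides by num * sp, i.e. it computes the biased (population) variance.
    return flat.mean(dim=1), flat.var(dim=1, unbiased=False)
```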
/metrics/iou.py:
--------------------------------------------------------------------------------
1 | def get_iou(preds, labels, thresh=0.5):
2 | preds, labels = preds.squeeze(), labels.squeeze()
3 | preds = preds > thresh
4 | mask_sum = (preds == 1) + (labels > 0)
5 | intersection = (mask_sum == 2).sum().float()
6 | union = (mask_sum > 0).sum().float()
7 |
8 | if union > 0:
9 | return intersection / union
10 |
11 | return 1.
12 |
--------------------------------------------------------------------------------
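A minimal usage sketch for `get_iou` with hypothetical prediction and ground-truth tensors. Under PyTorch 0.4 the comparison masks are uint8, so the sums inside `get_iou` take the intended values 0/1/2.
```python
import torch
from metrics.iou import get_iou

prob = torch.rand(1, 1, 480, 854)            # hypothetical foreground probabilities
gt = (torch.rand(1, 1, 480, 854) > 0.5)      # hypothetical binary ground truth
iou = get_iou(prob, gt.float(), thresh=0.5)  # intersection / union of the thresholded masks
print(float(iou))
```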
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yz93/anchor-diff-VOS/b6ee4fcc9eb1b85b514215badea9d10c158d73c0/networks/__init__.py
--------------------------------------------------------------------------------
/networks/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import functional as F
3 | import torch
4 | affine_par = True
5 | import functools
6 | import os
7 | import sys
8 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9 | sys.path.append(BASE_DIR)
10 | sys.path.append(os.path.join(BASE_DIR, '../inplace_abn'))
11 | from bn import InPlaceABNSync
12 |
13 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | "3x3 convolution with padding"
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=1, bias=False)
19 |
20 |
21 | class Bottleneck(nn.Module):
22 | expansion = 4
23 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, first_dilation=1, multi_grid=1):
24 | super(Bottleneck, self).__init__()
25 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
26 | self.bn1 = BatchNorm2d(planes)
27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
28 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
29 | self.bn2 = BatchNorm2d(planes)
30 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
31 | self.bn3 = BatchNorm2d(planes * 4)
32 | self.relu = nn.ReLU(inplace=False)
33 | self.relu_inplace = nn.ReLU(inplace=True)
34 | self.downsample = downsample
35 | self.dilation = dilation
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv3(out)
50 | out = self.bn3(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 |
55 | out = out + residual
56 | out = self.relu_inplace(out)
57 |
58 | return out
59 |
60 | class ASPPModule(nn.Module):
61 | """
62 | Reference:
63 | Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
64 | """
65 | # output stride 16: (6, 12, 18)
66 | # output stride 8: (12, 24, 36)
67 | def __init__(self, features, hidden_features=512, out_features=512, dilations=(12, 24, 36)):
68 | super(ASPPModule, self).__init__()
69 | self.conv1 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=1, bias=False),
70 | InPlaceABNSync(hidden_features))
71 | self.conv2 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
72 | InPlaceABNSync(hidden_features))
73 | self.conv3 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
74 | InPlaceABNSync(hidden_features))
75 | self.conv4 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
76 | InPlaceABNSync(hidden_features))
77 | self.image_pooling = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),
78 | nn.Conv2d(features, hidden_features, kernel_size=1, bias=False),
79 | InPlaceABNSync(hidden_features))
80 |
81 | self.conv_bn_dropout = nn.Sequential(
82 | nn.Conv2d(hidden_features * 5, out_features, kernel_size=1, bias=False),
83 | InPlaceABNSync(out_features),
84 | nn.Dropout2d(0.1)
85 | )
86 |
87 | def forward(self, x):
88 | _, _, h, w = x.size()
89 |
90 | feat1 = self.conv1(x)
91 | feat2 = self.conv2(x)
92 | feat3 = self.conv3(x)
93 | feat4 = self.conv4(x)
94 |
95 | img_feat = F.interpolate(self.image_pooling(x), size=(h, w), mode='bilinear', align_corners=True)
96 | concat_feat = torch.cat((feat1, feat2, feat3, feat4, img_feat), 1)
97 |
98 | out = self.conv_bn_dropout(concat_feat)
99 | return out
100 |
101 | class ResNet(nn.Module):
102 | def __init__(self, block, layers, num_classes):
103 | self.inplanes = 128
104 | super(ResNet, self).__init__()
105 | self.conv1 = conv3x3(3, 64, stride=2)
106 | self.bn1 = BatchNorm2d(64)
107 | self.relu1 = nn.ReLU(inplace=False)
108 | self.conv2 = conv3x3(64, 64)
109 | self.bn2 = BatchNorm2d(64)
110 | self.relu2 = nn.ReLU(inplace=False)
111 | self.conv3 = conv3x3(64, 128)
112 | self.bn3 = BatchNorm2d(128)
113 | self.relu3 = nn.ReLU(inplace=False)
114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
115 | self.relu = nn.ReLU(inplace=False)
116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # ceil_mode pooling supersedes the definition above
117 | self.layer1 = self._make_layer(block, 64, layers[0])
118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,2,4))
121 |
122 | self.head = nn.Sequential(ASPPModule(2048),
123 | nn.Conv2d(512, num_classes, kernel_size=1)
124 | )
125 |
126 | self.dsn = nn.Sequential(
127 | nn.Conv2d(1024, 512, kernel_size=3, padding=1, bias=False),
128 | InPlaceABNSync(512),
129 | nn.Dropout2d(0.1),
130 | nn.Conv2d(512, num_classes, kernel_size=1)
131 | )
132 |
133 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
134 | downsample = None
135 | if stride != 1 or self.inplanes != planes * block.expansion:
136 | downsample = nn.Sequential(
137 | nn.Conv2d(self.inplanes, planes * block.expansion,
138 | kernel_size=1, stride=stride, bias=False),
139 | BatchNorm2d(planes * block.expansion, affine=affine_par))
140 |
141 | layers = []
142 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
143 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
144 | self.inplanes = planes * block.expansion
145 | for i in range(1, blocks):
146 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))
147 |
148 | return nn.Sequential(*layers)
149 |
150 | def forward(self, x):
151 | x = self.relu1(self.bn1(self.conv1(x)))
152 | x = self.relu2(self.bn2(self.conv2(x)))
153 | x = self.relu3(self.bn3(self.conv3(x)))
154 | x = self.maxpool(x)
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x_dsn = self.dsn(x)
159 | x = self.layer4(x)
160 | x = self.head(x)
161 | return [x, x_dsn]
162 |
163 |
164 | def ResNetDeepLabv3(backbone='ResNet50', num_classes=1, batch_mode='sync'):
165 | global BatchNorm2d
166 | if batch_mode=='sync':
167 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
168 | elif batch_mode=='old':
169 | BatchNorm2d = torch.nn.BatchNorm2d
170 |
171 | if backbone=='ResNet50':
172 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
173 | elif backbone=='ResNet101':
174 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
175 | else:
176 | raise RuntimeError('unknown backbone type')
177 | return model
178 |
179 |
--------------------------------------------------------------------------------
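A quick sanity check of the backbone, assuming the inplace_abn extension is built and a CUDA device is available (InPlaceABNSync has no CPU path); the input size is arbitrary.
```python
import torch
from networks.deeplabv3 import ResNetDeepLabv3

model = ResNetDeepLabv3(backbone='ResNet50', num_classes=128).cuda().eval()
x = torch.randn(1, 3, 480, 854).cuda()
with torch.no_grad():
    out, out_dsn = model(x)  # main head and auxiliary (dsn) head
# Both heads keep an output stride of 8, so the spatial size is roughly 1/8 of the input.
print(out.shape, out_dsn.shape)
```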
/networks/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 | bound = 1.0 / math.sqrt(size)
6 | if tensor is not None:
7 | tensor.data.uniform_(-bound, bound)
8 |
9 |
10 | def kaiming_uniform(tensor, fan, a):
11 | bound = math.sqrt(6 / ((1 + a**2) * fan))
12 | if tensor is not None:
13 | tensor.data.uniform_(-bound, bound)
14 |
15 |
16 | def glorot(tensor):
17 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
18 | if tensor is not None:
19 | tensor.data.uniform_(-stdv, stdv)
20 |
21 |
22 | def zeros(tensor):
23 | if tensor is not None:
24 | tensor.data.fill_(0)
25 |
26 |
27 | def ones(tensor):
28 | if tensor is not None:
29 | tensor.data.fill_(1)
30 |
31 |
32 | def reset(nn):
33 | def _reset(item):
34 | if hasattr(item, 'reset_parameters'):
35 | item.reset_parameters()
36 |
37 | if nn is not None:
38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0:
39 | for item in nn.children():
40 | _reset(item)
41 | else:
42 | _reset(nn)
43 |
--------------------------------------------------------------------------------
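These helpers initialize parameter tensors in place; a small hypothetical use with a linear layer:
```python
import torch
from torch import nn
from networks.inits import glorot, zeros

lin = nn.Linear(128, 128)
glorot(lin.weight)  # uniform in ±sqrt(6 / (in_features + out_features))
zeros(lin.bias)
```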
/networks/nets.py:
--------------------------------------------------------------------------------
1 | from networks.deeplabv3 import ResNetDeepLabv3
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import torch
7 | import os
8 | import sys
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(os.path.join(BASE_DIR, '../inplace_abn'))
13 | from bn import InPlaceABNSync
14 |
15 |
16 | class AnchorDiffNet(nn.Module):
17 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'):
18 | super(AnchorDiffNet, self).__init__()
19 | if pyramid_pooling == 'deeplabv3':
20 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode)
21 | elif pyramid_pooling == 'pspnet':
22 | raise RuntimeError('Pooling module not implemented')
23 | else:
24 | raise RuntimeError('Unknown pyramid pooling module')
25 | self.cls = nn.Sequential(
26 | nn.Conv2d(3 * embedding, embedding, kernel_size=1, stride=1, padding=0),
27 | InPlaceABNSync(embedding),
28 | nn.Dropout2d(0.10),
29 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0)
30 | )
31 |
32 | def forward(self, reference, current):
33 | ref_features = self.features(reference)[0]
34 | curr_features = self.features(current)[0]
35 | batch, channel, h, w = curr_features.shape
36 | M = h * w
37 | ref_features = ref_features.view(batch, channel, M).permute(0, 2, 1)
38 | curr_features = curr_features.view(batch, channel, M)
39 |
40 | p_0 = torch.matmul(ref_features, curr_features)
41 | p_0 = F.softmax((channel ** -.5) * p_0, dim=-1)
42 | p_1 = torch.matmul(curr_features.permute(0, 2, 1), curr_features)
43 | p_1 = F.softmax((channel ** -.5) * p_1, dim=-1)
44 | feats_0 = torch.matmul(p_0, curr_features.permute(0, 2, 1)).permute(0, 2, 1)
45 | feats_1 = torch.matmul(p_1, curr_features.permute(0, 2, 1)).permute(0, 2, 1)
46 | x = torch.cat([feats_0, feats_1, curr_features], dim=1).view(batch, 3 * channel, h, w)
47 | pred = self.cls(x)
48 |
49 | return pred
50 |
51 |
52 | class ConcatNet(nn.Module):
53 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'):
54 | super(ConcatNet, self).__init__()
55 | if pyramid_pooling == 'deeplabv3':
56 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode)
57 | elif pyramid_pooling == 'pspnet':
58 | raise RuntimeError('Pooling module not implemented')
59 | else:
60 | raise RuntimeError('Unknown pyramid pooling module')
61 | self.cls = nn.Sequential(
62 | nn.Conv2d(2*embedding, embedding, kernel_size=1, stride=1, padding=0),
63 | InPlaceABNSync(embedding),
64 | nn.Dropout2d(0.10),
65 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0)
66 | )
67 |
68 | def forward(self, reference, current):
69 | ref_features = self.features(reference)[0]
70 | curr_features = self.features(current)[0]
71 | x = torch.cat([ref_features, curr_features], dim=1)
72 | pred = self.cls(x)
73 |
74 | return pred
75 |
76 |
77 | class InterFrameNet(nn.Module):
78 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'):
79 | super(InterFrameNet, self).__init__()
80 | if pyramid_pooling == 'deeplabv3':
81 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode)
82 | elif pyramid_pooling == 'pspnet':
83 | raise RuntimeError('Pooling module not implemented')
84 | else:
85 | raise RuntimeError('Unknown pyramid pooling module')
86 | self.cls = nn.Sequential(
87 | nn.Conv2d(2 * embedding, embedding, kernel_size=1, stride=1, padding=0),
88 | InPlaceABNSync(embedding),
89 | nn.Dropout2d(0.10),
90 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0)
91 | )
92 |
93 | def forward(self, reference, current):
94 | ref_features = self.features(reference)[0]
95 | curr_features = self.features(current)[0]
96 | batch, channel, h, w = curr_features.shape
97 | M = h * w
98 | ref_features = ref_features.view(batch, channel, M).permute(0, 2, 1)
99 | curr_features = curr_features.view(batch, channel, M)
100 |
101 | p_0 = torch.matmul(ref_features, curr_features)
102 | p_0 = F.softmax((channel ** -.5) * p_0, dim=-1)
103 | feats_0 = torch.matmul(p_0, curr_features.permute(0, 2, 1)).permute(0, 2, 1)
104 | x = torch.cat([feats_0, curr_features], dim=1).view(batch, 2 * channel, h, w)
105 | pred = self.cls(x)
106 |
107 | return pred
108 |
109 |
110 | class IntraFrameNet(nn.Module):
111 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'):
112 | super(IntraFrameNet, self).__init__()
113 | if pyramid_pooling == 'deeplabv3':
114 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode)
115 | elif pyramid_pooling == 'pspnet':
116 | raise RuntimeError('Pooling module not implemented')
117 | else:
118 | raise RuntimeError('Unknown pyramid pooling module')
119 | self.cls = nn.Sequential(
120 | nn.Conv2d(2*embedding, embedding, kernel_size=1, stride=1, padding=0),
121 | InPlaceABNSync(embedding),
122 | nn.Dropout2d(0.10),
123 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0)
124 | )
125 |
126 | def forward(self, current):
127 | curr_features = self.features(current)[0]
128 | batch, channel, h, w = curr_features.shape
129 | M = h * w
130 | curr_features = curr_features.view(batch, channel, M)
131 |
132 | p_1 = torch.matmul(curr_features.permute(0, 2, 1), curr_features)
133 | p_1 = F.softmax((channel**-.5) * p_1, dim=-1)
134 | feats_1 = torch.matmul(p_1, curr_features.permute(0, 2, 1)).permute(0, 2, 1)
135 | x = torch.cat([feats_1, curr_features], dim=1).view(batch, 2*channel, h, w)
136 | pred = self.cls(x)
137 |
138 | return pred
139 |
--------------------------------------------------------------------------------
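A rough end-to-end sketch of `AnchorDiffNet`: the anchor (first) frame and the current frame pass through the shared backbone, and the anchor-diffusion and intra-frame attention terms are concatenated to predict a one-channel mask logit at feature resolution. The shapes and thresholding below are illustrative, assume a built inplace_abn extension plus a GPU, and are not the evaluation protocol of `eval.py`.
```python
import torch
from networks.nets import AnchorDiffNet

net = AnchorDiffNet(backbone='ResNet50', embedding=128).cuda().eval()
reference = torch.randn(1, 3, 480, 854).cuda()  # anchor frame
current = torch.randn(1, 3, 480, 854).cuda()    # frame to segment
with torch.no_grad():
    logits = net(reference, current)            # (1, 1, h, w) at backbone resolution
mask = torch.sigmoid(logits) > 0.5              # upsample to frame size before scoring
```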
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.2.3
2 | opencv-python==4.1.0.25
3 | scikit-learn==0.20.2
4 | torchvision==0.2.1
5 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yz93/anchor-diff-VOS/b6ee4fcc9eb1b85b514215badea9d10c158d73c0/utils/__init__.py
--------------------------------------------------------------------------------
/utils/parallel.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | """Encoding Data Parallel"""
12 | import threading
13 | import functools
14 | import torch
15 | from torch.autograd import Variable, Function
16 | import torch.cuda.comm as comm
17 | from torch.nn.parallel.data_parallel import DataParallel
18 | from torch.nn.parallel.parallel_apply import get_a_var
19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
20 |
21 | torch_ver = torch.__version__[:3]
22 |
23 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
24 | 'patch_replication_callback']
25 |
26 | def allreduce(*inputs):
27 | """Cross GPU all reduce autograd operation for calculate mean and
28 | variance in SyncBN.
29 | """
30 | return AllReduce.apply(*inputs)
31 |
32 | class AllReduce(Function):
33 | @staticmethod
34 | def forward(ctx, num_inputs, *inputs):
35 | ctx.num_inputs = num_inputs
36 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
37 | inputs = [inputs[i:i + num_inputs]
38 | for i in range(0, len(inputs), num_inputs)]
39 | # sort before reduce sum
40 | inputs = sorted(inputs, key=lambda i: i[0].get_device())
41 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
42 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
43 | return tuple([t for tensors in outputs for t in tensors])
44 |
45 | @staticmethod
46 | def backward(ctx, *inputs):
47 | inputs = [i.data for i in inputs]
48 | inputs = [inputs[i:i + ctx.num_inputs]
49 | for i in range(0, len(inputs), ctx.num_inputs)]
50 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
51 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
52 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
53 |
54 |
55 | class Reduce(Function):
56 | @staticmethod
57 | def forward(ctx, *inputs):
58 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
59 | inputs = sorted(inputs, key=lambda i: i.get_device())
60 | return comm.reduce_add(inputs)
61 |
62 | @staticmethod
63 | def backward(ctx, gradOutput):
64 | return Broadcast.apply(ctx.target_gpus, gradOutput)
65 |
66 |
67 | class DataParallelModel(DataParallel):
68 | """Implements data parallelism at the module level.
69 |
70 | This container parallelizes the application of the given module by
71 | splitting the input across the specified devices by chunking in the
72 | batch dimension.
73 | In the forward pass, the module is replicated on each device,
74 | and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module.
75 | Note that the outputs are not gathered; use the compatible
76 | :class:`encoding.parallel.DataParallelCriterion`.
77 |
78 | The batch size should be larger than the number of GPUs used. It should
79 | also be an integer multiple of the number of GPUs so that each chunk is
80 | the same size (so that each GPU processes the same number of samples).
81 |
82 | Args:
83 | module: module to be parallelized
84 | device_ids: CUDA devices (default: all devices)
85 |
86 | Reference:
87 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
88 | Amit Agrawal. "Context Encoding for Semantic Segmentation."
89 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
90 |
91 | Example::
92 |
93 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
94 | >>> y = net(x)
95 | """
96 | def gather(self, outputs, output_device):
97 | return outputs
98 |
99 | def replicate(self, module, device_ids):
100 | modules = super(DataParallelModel, self).replicate(module, device_ids)
101 | execute_replication_callbacks(modules)
102 | return modules
103 |
104 |
105 | class DataParallelCriterion(DataParallel):
106 | """
107 | Calculate the loss on multiple GPUs, which balances the memory usage for
108 | semantic segmentation.
109 |
110 | The targets are split across the specified devices by chunking in
111 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
112 |
113 | Reference:
114 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
115 | Amit Agrawal. "Context Encoding for Semantic Segmentation."
116 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
117 |
118 | Example::
119 |
120 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
121 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
122 | >>> y = net(x)
123 | >>> loss = criterion(y, target)
124 | """
125 | def forward(self, inputs, *targets, **kwargs):
126 | # inputs are already scattered across devices
127 | # scattering the targets instead
128 | if not self.device_ids:
129 | return self.module(inputs, *targets, **kwargs)
130 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
131 | if len(self.device_ids) == 1:
132 | return self.module(inputs, *targets[0], **kwargs[0])
133 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
134 | targets = tuple(targets_per_gpu[0] for targets_per_gpu in targets) # fix bug
135 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
136 | return Reduce.apply(*outputs) / len(outputs)
137 | #return self.gather(outputs, self.output_device).mean()
138 |
139 |
140 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
141 | assert len(modules) == len(inputs)
142 | assert len(targets) == len(inputs)
143 | if kwargs_tup:
144 | assert len(modules) == len(kwargs_tup)
145 | else:
146 | kwargs_tup = ({},) * len(modules)
147 | if devices is not None:
148 | assert len(modules) == len(devices)
149 | else:
150 | devices = [None] * len(modules)
151 |
152 | lock = threading.Lock()
153 | results = {}
154 | if torch_ver != "0.3":
155 | grad_enabled = torch.is_grad_enabled()
156 |
157 | def _worker(i, module, input, target, kwargs, device=None):
158 | if torch_ver != "0.3":
159 | torch.set_grad_enabled(grad_enabled)
160 | if device is None:
161 | device = get_a_var(input).get_device()
162 | try:
163 | with torch.cuda.device(device):
164 | output = module(input, target, **kwargs)
165 | # output = module(*(input + target), **kwargs)
166 | with lock:
167 | results[i] = output
168 | except Exception as e:
169 | with lock:
170 | results[i] = e
171 |
172 | if len(modules) > 1:
173 | threads = [threading.Thread(target=_worker,
174 | args=(i, module, input, target,
175 | kwargs, device),)
176 | for i, (module, input, target, kwargs, device) in
177 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
178 |
179 | for thread in threads:
180 | thread.start()
181 | for thread in threads:
182 | thread.join()
183 | else:
184 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
185 |
186 | outputs = []
187 | for i in range(len(inputs)):
188 | output = results[i]
189 | if isinstance(output, Exception):
190 | raise output
191 | outputs.append(output)
192 | return outputs
193 |
194 |
195 | ###########################################################################
196 | # Adapted from Synchronized-BatchNorm-PyTorch.
197 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
198 | #
199 | class CallbackContext(object):
200 | pass
201 |
202 |
203 | def execute_replication_callbacks(modules):
204 | """
205 | Execute a replication callback `__data_parallel_replicate__` on each module created
206 | by the original replication.
207 |
208 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
209 |
210 | Note that, as all copies are isomorphic, we assign each sub-module a context
211 | (shared among multiple copies of this module on different devices).
212 | Through this context, different copies can share some information.
213 |
214 | We guarantee that the callback on the master copy (the first copy) will be called ahead
215 | of the callbacks on any slave copies.
216 | """
217 | master_copy = modules[0]
218 | nr_modules = len(list(master_copy.modules()))
219 | ctxs = [CallbackContext() for _ in range(nr_modules)]
220 |
221 | for i, module in enumerate(modules):
222 | for j, m in enumerate(module.modules()):
223 | if hasattr(m, '__data_parallel_replicate__'):
224 | m.__data_parallel_replicate__(ctxs[j], i)
225 |
226 |
227 | def patch_replication_callback(data_parallel):
228 | """
229 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
230 | Useful when you have a customized `DataParallel` implementation.
231 |
232 | Examples:
233 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
234 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
235 | > patch_replication_callback(sync_bn)
236 | # this is equivalent to
237 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
238 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
239 | """
240 |
241 | assert isinstance(data_parallel, DataParallel)
242 |
243 | old_replicate = data_parallel.replicate
244 |
245 | @functools.wraps(old_replicate)
246 | def new_replicate(module, device_ids):
247 | modules = old_replicate(module, device_ids)
248 | execute_replication_callbacks(modules)
249 | return modules
250 |
251 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
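The two wrappers are meant to be used together; a hypothetical pairing with this repository's `AnchorDiffNet` on two GPUs (`reference`, `current`, and `target` are placeholder tensors):
```python
import torch
from torch import nn
from networks.nets import AnchorDiffNet
from utils.parallel import DataParallelModel, DataParallelCriterion

# Outputs stay scattered across the GPUs; the criterion scatters the targets to
# match, evaluates the loss on each device, and reduces the result to a scalar.
model = DataParallelModel(AnchorDiffNet().cuda(), device_ids=[0, 1])
criterion = DataParallelCriterion(nn.BCEWithLogitsLoss().cuda(), device_ids=[0, 1])

# preds = model(reference, current)  # list with one output per GPU
# loss = criterion(preds, target)    # averaged over devices
# loss.backward()
```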