├── README.md
├── bib
├── detection_filter.py
├── eval.py
├── inplace_abn
│   ├── __init__.py
│   ├── bn.py
│   ├── functions.py
│   └── src
│       ├── common.h
│       ├── inplace_abn.cpp
│       ├── inplace_abn.h
│       ├── inplace_abn_cpu.cpp
│       └── inplace_abn_cuda.cu
├── metrics
│   └── iou.py
├── networks
│   ├── __init__.py
│   ├── deeplabv3.py
│   ├── inits.py
│   └── nets.py
├── requirements.txt
└── utils
    ├── __init__.py
    └── parallel.py
/README.md: --------------------------------------------------------------------------------
1 | # Anchor diffusion VOS
2 | 
3 | 
4 | This repository contains code for the paper
5 | 
6 | **Anchor Diffusion for Unsupervised Video Object Segmentation**<br>
7 | Zhao Yang\*, [Qiang Wang](http://www.robots.ox.ac.uk/~qwang/)\*, [Luca Bertinetto](http://www.robots.ox.ac.uk/~luca), [Weiming Hu](https://scholar.google.com/citations?user=Wl4tl4QAAAAJ&hl=en), [Song Bai](http://songbai.site/), [Philip H.S. Torr](http://www.robots.ox.ac.uk/~tvg/)
8 | **ICCV 2019** | **[PDF](https://arxiv.org/abs/1910.10895)** | **[BibTex](bib)**
9 | 
10 | ## Setup
11 | Code tested with Ubuntu 16.04, Python 3.7, PyTorch 0.4.1, and CUDA 9.2.
12 | 
13 | * Clone the repository and change to the new directory.
14 | ```
15 | git clone https://github.com/yz93/anchor-diff-VOS.git && cd anchor-diff-VOS
16 | ```
17 | * Save the working directory to an environment variable for reference.
18 | ```shell
19 | export AnchorDiff=$PWD
20 | ```
21 | * Set up a new conda environment.
22 |   * For installing PyTorch 0.4.1 with different versions of CUDA, see [here](https://pytorch.org/get-started/previous-versions/#via-conda).
23 | ```
24 | conda create -n anchordiff python=3.7 pytorch=0.4.1 cuda92 -c pytorch
25 | source activate anchordiff
26 | pip install -r requirements.txt
27 | ```
28 | 
29 | 
30 | ## Data preparation
31 | * Download the DAVIS data set.
32 | ```shell
33 | cd $AnchorDiff
34 | wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
35 | unzip DAVIS-2017-trainval-480p.zip -d data
36 | ```
37 | * Download the [pre-trained weights](https://drive.google.com/file/d/1A6ozn2FT2Ef1dn7HGNgUb1WboWFN2NON/view?usp=sharing) (1.5 GB) to $AnchorDiff.
38 | ```shell
39 | cd $AnchorDiff
40 | unzip snapshots.zip -d snapshots
41 | ```
42 | * (Skip this step if you do not intend to apply the instance pruning described in the paper.) Download the detection results that we computed with [ExtremeNet](https://github.com/xingyizhou/ExtremeNet),
43 | and generate the pruning masks.
44 | ```shell
45 | cd $AnchorDiff
46 | wget www.robots.ox.ac.uk/~yz/detection.zip
47 | unzip detection.zip
48 | python detection_filter.py
49 | ```
50 | 
51 | ## Evaluation on [DAVIS 2016](https://davischallenge.org/davis2016/code.html)
52 | * Examples of evaluating mean IoU on the validation set, with the following options:
53 |   * *save-mask* (default 'True') for saving the predicted masks,
54 |   * *ms-mirror* (default 'False') for multi-scale and mirrored input (slow),
55 |   * *inst-prune* (default 'False') for instance pruning,
56 |   * *model* (default 'ad') for selecting one of the models in Table 1 of the paper,
57 |   * *eval-sal* (default 'False') for computing the saliency measures MAE and F-score.
58 | ```shell
59 | cd $AnchorDiff
60 | python eval.py
61 | python eval.py --ms-mirror True --inst-prune True --eval-sal True
62 | ```
63 | * Use the [benchmark tool](https://github.com/davisvideochallenge/davis-matlab) to evaluate the saved masks under additional metrics.
64 | * [Pre-computed results](https://www.robots.ox.ac.uk/~yz/val_results.zip) are also available.
65 | 
66 | ## License
67 | The [MIT License](https://choosealicense.com/licenses/mit/).
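
## Single-video inference

For a quick check on a single sequence, `eval.py` also accepts a `--video` option (see `get_arguments` in eval.py). A minimal sketch, assuming the DAVIS 2016 validation sequence name `blackswan`:
```shell
cd $AnchorDiff
# restrict inference to one validation sequence; masks are saved under ./pred_masks/ by default
python eval.py --model ad --video blackswan
```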
68 | -------------------------------------------------------------------------------- /bib: -------------------------------------------------------------------------------- 1 | @inproceedings{yang2019anchor, 2 | title={Anchor Diffusion for Unsupervised Video Object Segmentation}, 3 | author={Yang, Zhao and Wang, Qiang and Bertinetto, Luca and Bai, Song and Hu, Weiming and Torr, Philip H.S.}, 4 | booktitle={ICCV}, 5 | year={2019} 6 | } 7 | -------------------------------------------------------------------------------- /detection_filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import numpy as np 5 | import pickle 6 | import shutil 7 | 8 | 9 | def _IoU(rect1, rect2): 10 | def inter(rect1, rect2): 11 | x1 = max(rect1[0], rect2[0]) 12 | y1 = max(rect1[1], rect2[1]) 13 | x2 = min(rect1[2], rect2[2]) 14 | y2 = min(rect1[3], rect2[3]) 15 | return max(x2 - x1 + 1, 0) * max(y2 - y1 + 1, 0) * 1. 16 | 17 | def area(rect): 18 | x1, y1, x2, y2 = rect 19 | return (x2 - x1 + 1) * (y2 - y1 + 1) 20 | 21 | ii = inter(rect1, rect2) 22 | iou = ii / (area(rect1) + area(rect2) - ii) 23 | return iou 24 | 25 | 26 | def vis_mask(img, mask, col, alpha=0.4, show_border=True, border_thick=2): 27 | """Visualizes a single binary mask.""" 28 | 29 | img = img.astype(np.float32) 30 | idx = np.nonzero(mask) 31 | 32 | img[idx[0], idx[1], :] *= 1.0 - alpha 33 | img[idx[0], idx[1], :] += alpha * col 34 | _WHITE = (255, 255, 255) 35 | if show_border: 36 | contours, _ = cv2.findContours(mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) 37 | cv2.drawContours(img, contours, -1, _WHITE, border_thick, cv2.LINE_AA) 38 | 39 | return img.astype(np.uint8) 40 | 41 | 42 | def colormap(rgb=False): 43 | color_list = np.array( 44 | [ 45 | 0.000, 0.447, 0.741, 46 | 0.850, 0.325, 0.098, 47 | 0.929, 0.694, 0.125, 48 | 0.494, 0.184, 0.556, 49 | 0.466, 0.674, 0.188, 50 | 0.301, 0.745, 0.933, 51 | 0.635, 0.078, 0.184, 52 | 0.300, 0.300, 0.300, 53 | 0.600, 0.600, 0.600, 54 | 1.000, 0.000, 0.000, 55 | 1.000, 0.500, 0.000, 56 | 0.749, 0.749, 0.000, 57 | 0.000, 1.000, 0.000, 58 | 0.000, 0.000, 1.000, 59 | 0.667, 0.000, 1.000, 60 | 0.333, 0.333, 0.000, 61 | 0.333, 0.667, 0.000, 62 | 0.333, 1.000, 0.000, 63 | 0.667, 0.333, 0.000, 64 | 0.667, 0.667, 0.000, 65 | 0.667, 1.000, 0.000, 66 | 1.000, 0.333, 0.000, 67 | 1.000, 0.667, 0.000, 68 | 1.000, 1.000, 0.000, 69 | 0.000, 0.333, 0.500, 70 | 0.000, 0.667, 0.500, 71 | 0.000, 1.000, 0.500, 72 | 0.333, 0.000, 0.500, 73 | 0.333, 0.333, 0.500, 74 | 0.333, 0.667, 0.500, 75 | 0.333, 1.000, 0.500, 76 | 0.667, 0.000, 0.500, 77 | 0.667, 0.333, 0.500, 78 | 0.667, 0.667, 0.500, 79 | 0.667, 1.000, 0.500, 80 | 1.000, 0.000, 0.500, 81 | 1.000, 0.333, 0.500, 82 | 1.000, 0.667, 0.500, 83 | 1.000, 1.000, 0.500, 84 | 0.000, 0.333, 1.000, 85 | 0.000, 0.667, 1.000, 86 | 0.000, 1.000, 1.000, 87 | 0.333, 0.000, 1.000, 88 | 0.333, 0.333, 1.000, 89 | 0.333, 0.667, 1.000, 90 | 0.333, 1.000, 1.000, 91 | 0.667, 0.000, 1.000, 92 | 0.667, 0.333, 1.000, 93 | 0.667, 0.667, 1.000, 94 | 0.667, 1.000, 1.000, 95 | 1.000, 0.000, 1.000, 96 | 1.000, 0.333, 1.000, 97 | 1.000, 0.667, 1.000, 98 | 0.167, 0.000, 0.000, 99 | 0.333, 0.000, 0.000, 100 | 0.500, 0.000, 0.000, 101 | 0.667, 0.000, 0.000, 102 | 0.833, 0.000, 0.000, 103 | 1.000, 0.000, 0.000, 104 | 0.000, 0.167, 0.000, 105 | 0.000, 0.333, 0.000, 106 | 0.000, 0.500, 0.000, 107 | 0.000, 0.667, 0.000, 108 | 0.000, 0.833, 0.000, 109 | 0.000, 1.000, 0.000, 110 | 0.000, 0.000, 0.167, 111 | 0.000, 0.000, 
0.333, 112 | 0.000, 0.000, 0.500, 113 | 0.000, 0.000, 0.667, 114 | 0.000, 0.000, 0.833, 115 | 0.000, 0.000, 1.000, 116 | 0.000, 0.000, 0.000, 117 | 0.143, 0.143, 0.143, 118 | 0.286, 0.286, 0.286, 119 | 0.429, 0.429, 0.429, 120 | 0.571, 0.571, 0.571, 121 | 0.714, 0.714, 0.714, 122 | 0.857, 0.857, 0.857, 123 | 1.000, 1.000, 1.000 124 | ] 125 | ).astype(np.float32) 126 | color_list = color_list.reshape((-1, 3)) * 255 127 | if not rgb: 128 | color_list = color_list[:, ::-1] 129 | return color_list 130 | 131 | 132 | videos = [i_id.strip() for i_id in open(os.path.join('./data/DAVIS/', 'ImageSets', '2016', 'val.txt'))] 133 | train_videos = [i_id.strip() for i_id in open(os.path.join('./data/DAVIS/', 'ImageSets', '2016', 'train.txt'))] 134 | frame_count = [] 135 | for video in train_videos: 136 | img_files = sorted( 137 | glob.glob(os.path.join('./data/DAVIS/', 'JPEGImages', '480p', video, '*.jpg'))) 138 | frame_count.append(len(img_files)) 139 | mean_frame_count = np.mean(frame_count) 140 | 141 | out_dir = './inst_prune' 142 | if os.path.exists(out_dir): 143 | shutil.rmtree(out_dir) 144 | 145 | for vid, video in enumerate(videos): 146 | def load_obj(name): 147 | with open('detection/' + name + '.pkl', 'rb') as f: 148 | return pickle.load(f) 149 | if not os.path.exists('detection/' + video + '.pkl'): 150 | print('no detection on:', video) 151 | continue 152 | detect_res = load_obj(video) 153 | frame_len = len(detect_res) 154 | bboxes_all = [] 155 | for frame_info in detect_res: 156 | for instance in detect_res[frame_info]: 157 | bboxes_all.append(instance['bbox']) 158 | 159 | mean_bboxes = len(bboxes_all)/frame_len 160 | first_remove = -1 161 | if mean_bboxes > 3: 162 | color_list = colormap(rgb=True) 163 | size_bboxes = [(bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) for bbox in bboxes_all] 164 | size_bboxes = sorted(size_bboxes) 165 | size_bboxes_target = size_bboxes[-frame_len] 166 | 167 | img_files = sorted( 168 | glob.glob(os.path.join('./data/DAVIS/', 'JPEGImages', '480p', video, '*.jpg'))) 169 | 170 | for f, img_file in enumerate(img_files): 171 | im = cv2.imread(img_file, cv2.IMREAD_COLOR) 172 | frame_bboxes = [] 173 | frame_masks = [] 174 | for id, instance in enumerate(detect_res[f]): 175 | frame_bboxes.append(instance['bbox']) 176 | frame_masks.append(instance['mask']) 177 | bbox = instance['bbox'] 178 | mask = instance['mask'] 179 | cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 3) 180 | im = vis_mask(im, mask, color_list[id % len(color_list), :3], alpha=0.4) 181 | 182 | size_bboxes = [(bbox[3]-bbox[1])*(bbox[2]-bbox[0])for bbox in frame_bboxes] 183 | score_bboxes = [bbox[4] for bbox in frame_bboxes] 184 | size_bboxes_ind = np.argsort(size_bboxes) 185 | size_bboxes = sorted(size_bboxes) 186 | 187 | target_box = frame_bboxes[size_bboxes_ind[-1]] 188 | 189 | im_mask = np.ones((im.shape[0], im.shape[1])) 190 | 191 | for bbox, mask in zip(frame_bboxes, frame_masks): 192 | static_object_count = 0 193 | if (bbox[3]-bbox[1])*(bbox[2]-bbox[0]) > 47000 or (bbox[3]-bbox[1])*(bbox[2]-bbox[0]) == size_bboxes[-1]: 194 | continue 195 | for i in range(len(bboxes_all)): 196 | if _IoU(bbox[:4], bboxes_all[i][:4]) > 0.6: 197 | static_object_count += 1 198 | 199 | if static_object_count > 0.4 * mean_frame_count: 200 | im_mask = im_mask * (1 - mask) 201 | cv2.putText(im, 'static', (bbox[0], bbox[1]), 2, 2, (0, 255, 0)) 202 | 203 | if len(size_bboxes) > 1: 204 | if size_bboxes[-1] > 10000 and size_bboxes[-1] > 2*size_bboxes[-2] and \ 205 | size_bboxes[-1] > size_bboxes_target and \ 206 
| (target_box[-1] == 0 or target_box[-1] == 2) and \ 207 | (target_box[2]-target_box[0])/(target_box[3]-target_box[1]) < 3: 208 | suppress_small = True 209 | if first_remove == -1: 210 | first_remove = f 211 | if first_remove > 20: 212 | break 213 | else: 214 | suppress_small = False 215 | 216 | if suppress_small: 217 | for bbox, mask in zip(frame_bboxes, frame_masks): 218 | cx = (bbox[3]+bbox[1])/2 219 | cy = (bbox[2]+bbox[0])/2 220 | cx0 = (target_box[3] + target_box[1]) / 2 221 | cy0 = (target_box[2] + target_box[0]) / 2 222 | d_dist = abs(cx-cx0)+ abs(cy-cy0) 223 | if ((bbox[3]-bbox[1])*(bbox[2]-bbox[0]) < size_bboxes[-1]//3 or 224 | ((cy < 300 or cy > 600) and d_dist > 200)) and \ 225 | _IoU(target_box[:4], bbox[:4]) <= 0.1 and \ 226 | bbox[-1] == frame_bboxes[size_bboxes_ind[-1]][-1]: 227 | # print(cx) 228 | 229 | # cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 5) 230 | # im_mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = 0 231 | im_mask = im_mask*(1 - mask) 232 | 233 | result_dir = os.path.join(out_dir, video) 234 | if not os.path.exists(result_dir): 235 | os.makedirs(result_dir) 236 | 237 | cv2.imwrite(os.path.join(result_dir, img_file.split('/')[-1].split('.')[0] + '.png'), im_mask*255) 238 | im_mask = (im_mask*255).astype(np.uint8) 239 | im = cv2.vconcat((cv2.cvtColor(im_mask.copy(), cv2.COLOR_GRAY2BGR), im)) 240 | im = cv2.resize(im, dsize=None, fx=0.5, fy=0.5) 241 | # cv2.imshow('mask', im_mask*255) 242 | # cv2.imshow(video, im) 243 | # cv2.waitKey(1) 244 | cv2.destroyAllWindows() 245 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import torch.backends.cudnn as cudnn 7 | import os 8 | import os.path as osp 9 | import glob 10 | import cv2 11 | from PIL import Image 12 | import pickle 13 | from networks.deeplabv3 import ResNetDeepLabv3 14 | from networks.nets import AnchorDiffNet, ConcatNet, InterFrameNet, IntraFrameNet 15 | 16 | import timeit 17 | from metrics.iou import get_iou 18 | 19 | from sklearn.metrics import precision_recall_curve 20 | 21 | 22 | start = timeit.default_timer() 23 | 24 | IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32) 25 | BACKBONE = 'ResNet101' 26 | BN_TYPE = 'sync' 27 | DATA_DIRECTORY = './data/DAVIS/' 28 | EMBEDDING_SIZE = 128 29 | THRESHOLD = 0.5 # the threshold over raw scores (not the output of a sigmoid function) 30 | PYRAMID_POOLING = 'deeplabv3' 31 | VISUALIZE = False # False True 32 | MS_MIRROR = False # False True 33 | INSTANCE_PRUNING = False # False True 34 | PARENT_DIR_WEIGHTS = './snapshots/' 35 | SAVE_MASK = True # False True 36 | SAVE_MASK_DIR = './pred_masks/' 37 | EVAL_SAL = False # False True 38 | SAVE_HEATMAP_DIR = './pred_heatmaps/' 39 | # MODEL = 'base' 40 | # MODEL = 'concat' 41 | # MODEL = 'intra' 42 | # MODEL = 'inter' 43 | MODEL = 'ad' 44 | 45 | 46 | def str2bool(v): 47 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 48 | return True 49 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 50 | return False 51 | else: 52 | raise argparse.ArgumentTypeError('Boolean value expected.') 53 | 54 | 55 | def bool2str(v): 56 | if v: 57 | return 'True' 58 | else: 59 | return 'False' 60 | 61 | 62 | def get_arguments(): 63 | parser = argparse.ArgumentParser(description="Anchor Diffusion VOS Test") 64 | parser.add_argument("--backbone", 
type=str, default=BACKBONE, 65 | help="Feature encoder.") 66 | parser.add_argument("--bn-type", type=str, default=BN_TYPE, 67 | help="BatchNorm MODE, [old/sync].") 68 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 69 | help="Path to the data directory.") 70 | parser.add_argument("--embedding-size", type=int, default=EMBEDDING_SIZE, 71 | help="Number of dimensions along the channel axis of pixel embeddings.") 72 | parser.add_argument("--eval-sal", type=str2bool, default=bool2str(EVAL_SAL), 73 | help="Whether to report MAE and F-score.") 74 | parser.add_argument("--inst-prune", type=str2bool, default=bool2str(INSTANCE_PRUNING), 75 | help="Whether to post-process the results with instance pruning") 76 | parser.add_argument("--ms-mirror", type=str2bool, default=bool2str(MS_MIRROR), 77 | help="Whether to mirror and re-scale the input image.") 78 | parser.add_argument("--model", type=str, default=MODEL, help="Overall models.") 79 | parser.add_argument("--pyramid-pooling", type=str, default=PYRAMID_POOLING, 80 | help="Pyramid pooling methods.") 81 | parser.add_argument("--parent-dir-weights", type=str, default=PARENT_DIR_WEIGHTS, 82 | help="Parent directory of pre-trained weights") 83 | parser.add_argument("--save-heatmap-dir", type=str, default=SAVE_HEATMAP_DIR, 84 | help="Path to save the outputs of sigmoid for MAE and F-score evaluation.") 85 | parser.add_argument("--save-mask", type=str2bool, default=bool2str(SAVE_MASK), 86 | help="Whether to save the predicted masks.") 87 | parser.add_argument("--save-mask-dir", type=str, default=SAVE_MASK_DIR, 88 | help="Path to save the predicted masks.") 89 | parser.add_argument("--threshold", type=float, default=THRESHOLD, 90 | help="Threshold on the raw scores (the logits before sigmoid normalization).") 91 | parser.add_argument("--visualize", type=str2bool, default=bool2str(VISUALIZE), 92 | help="Whether to visualize the predicted masks during inference.") 93 | parser.add_argument("--video", type=str, default='', 94 | help="If non-empty, then run inference on the specified video.") 95 | return parser.parse_args() 96 | 97 | 98 | args = get_arguments() 99 | 100 | 101 | def main(): 102 | cudnn.enabled = True 103 | 104 | if args.model == 'base': 105 | model = ResNetDeepLabv3(backbone=args.backbone) 106 | elif args.model == 'intra': 107 | model = IntraFrameNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling, 108 | embedding=args.embedding_size, batch_mode='sync') 109 | elif args.model == 'inter': 110 | model = InterFrameNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling, 111 | embedding=args.embedding_size, batch_mode='sync') 112 | elif args.model == 'concat': 113 | model = ConcatNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling, 114 | embedding=args.embedding_size, batch_mode='sync') 115 | elif args.model == 'ad': 116 | model = AnchorDiffNet(backbone=args.backbone, pyramid_pooling=args.pyramid_pooling, 117 | embedding=args.embedding_size, batch_mode='sync') 118 | 119 | model.load_state_dict(torch.load(osp.join(args.parent_dir_weights, args.model+'.pth'))) 120 | model.eval() 121 | model.float() 122 | model.cuda() 123 | 124 | with torch.no_grad(): 125 | video_mean_iou_list = [] 126 | model.eval() 127 | videos = [i_id.strip() for i_id in open(osp.join(args.data_dir, 'ImageSets', '2016', 'val.txt'))] 128 | if args.video and args.video in videos: 129 | videos = [args.video] 130 | 131 | for vid, video in enumerate(videos, start=1): 132 | curr_video_iou_list = [] 133 | img_files = 
sorted(glob.glob(osp.join(args.data_dir, 'JPEGImages', '480p', video, '*.jpg'))) 134 | ann_files = sorted(glob.glob(osp.join(args.data_dir, 'Annotations', '480p', video, '*.png'))) 135 | 136 | if args.ms_mirror: 137 | resize_shape = [(857*0.75, 481*0.75), (857, 481), (857*1.5, 481*1.5)] 138 | resize_shape = [(int((s[0]-1)//8*8+1), int((s[1]-1)//8*8+1)) for s in resize_shape] 139 | mirror = True 140 | else: 141 | resize_shape = [(857, 481)] 142 | mirror = False 143 | 144 | reference_img = [] 145 | for s in resize_shape: 146 | reference_img.append((np.asarray(cv2.resize(cv2.imread(img_files[0], cv2.IMREAD_COLOR), s), 147 | np.float32) - IMG_MEAN).transpose((2, 0, 1))) 148 | if mirror: 149 | for r in range(len(reference_img)): 150 | reference_img.append(reference_img[r][:, :, ::-1].copy()) 151 | reference_img = [torch.from_numpy(np.expand_dims(r, axis=0)).cuda() for r in reference_img] 152 | reference_mask = np.array(Image.open(ann_files[0])) > 0 153 | reference_mask = torch.from_numpy(np.expand_dims(np.expand_dims(reference_mask.astype(np.float32), 154 | axis=0), axis=0)).cuda() 155 | H, W = reference_mask.size(2), reference_mask.size(3) 156 | 157 | if args.visualize: 158 | colors = np.random.randint(128, 255, size=(1, 3), dtype="uint8") 159 | colors = np.vstack([[0, 0, 0], colors]).astype("uint8") 160 | 161 | last_mask_num = 0 162 | last_mask = None 163 | last_mask_final = None 164 | kernel1 = np.ones((15, 15), np.uint8) 165 | kernel2 = np.ones((101, 101), np.uint8) 166 | kernel3 = np.ones((31, 31), np.uint8) 167 | predictions_all = [] 168 | gt_all = [] 169 | 170 | for f, (img_file, ann_file) in enumerate(zip(img_files, ann_files)): 171 | current_img = [] 172 | for s in resize_shape: 173 | current_img.append((np.asarray(cv2.resize( 174 | cv2.imread(img_file, cv2.IMREAD_COLOR), s), 175 | np.float32) - IMG_MEAN).transpose((2, 0, 1))) 176 | 177 | if mirror: 178 | for c in range(len(current_img)): 179 | current_img.append(current_img[c][:, :, ::-1].copy()) 180 | 181 | current_img = [torch.from_numpy(np.expand_dims(c, axis=0)).cuda() for c in current_img] 182 | 183 | current_mask = np.array(Image.open(ann_file)) > 0 184 | current_mask = torch.from_numpy(np.expand_dims(np.expand_dims(current_mask.astype(np.float32), axis=0), axis=0)).cuda() 185 | 186 | if args.model in ['base']: 187 | predictions = [model(cur) for ref, cur in zip(reference_img, current_img)] 188 | predictions = [F.interpolate(input=p[0], size=(H, W), mode='bilinear', align_corners=True) for p in predictions] 189 | elif args.model in ['intra']: 190 | predictions = [model(cur) for ref, cur in zip(reference_img, current_img)] 191 | predictions = [F.interpolate(input=p, size=(H, W), mode='bilinear', align_corners=True) for p in predictions] 192 | elif args.model in ['inter', 'concat', 'ad']: 193 | predictions = [model(ref, cur) for ref, cur in zip(reference_img, current_img)] 194 | predictions = [F.interpolate(input=p, size=(H, W), mode='bilinear', align_corners=True) for p in predictions] 195 | 196 | if mirror: 197 | for r in range(len(predictions)//2, len(predictions)): 198 | predictions[r] = torch.flip(predictions[r], [3]) 199 | predictions = torch.mean(torch.stack(predictions, dim=0), 0) 200 | 201 | predictions_all.append(predictions.sigmoid().data.cpu().numpy()[0, 0].copy()) 202 | gt_all.append(current_mask.data.cpu().numpy()[0, 0].astype(np.uint8).copy()) 203 | 204 | if args.inst_prune: 205 | result_dir = os.path.join('inst_prune', video) 206 | if os.path.exists(os.path.join(result_dir, img_file.split('/')[-1].split('.')[0] + 
'.png')): 207 | detection_mask = np.array( 208 | Image.open(os.path.join(result_dir, img_file.split('/')[-1].split('.')[0] + '.png'))) > 0 209 | detection_mask = torch.from_numpy( 210 | np.expand_dims(np.expand_dims(detection_mask.astype(np.float32), axis=0), axis=0)).cuda() 211 | predictions = predictions * detection_mask 212 | 213 | process_now = (predictions > args.threshold).data.cpu().numpy().astype(np.uint8)[0, 0] 214 | if 100000 > process_now.sum() > 40000: 215 | last_mask_numpy = (predictions > args.threshold).data.cpu().numpy().astype(np.uint8)[0, 0] 216 | last_mask_numpy = cv2.morphologyEx(last_mask_numpy, cv2.MORPH_OPEN, kernel1) 217 | dilation = cv2.dilate(last_mask_numpy, kernel3, iterations=1) 218 | contours, _ = cv2.findContours(dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 219 | cnt_area = [cv2.contourArea(cnt) for cnt in contours] 220 | if len(contours) > 1: 221 | contour = contours[np.argmax(cnt_area)] 222 | polygon = contour.reshape(-1, 2) 223 | x, y, w, h = cv2.boundingRect(polygon) 224 | x0, y0 = x, y 225 | x1 = x + w 226 | y1 = y + h 227 | mask_rect = torch.from_numpy(np.zeros_like(dilation).astype(np.float32)).cuda() 228 | mask_rect[y0:y1, x0:x1] = 1 229 | mask_rect = mask_rect.unsqueeze(0).unsqueeze(0) 230 | if np.max(cnt_area) > 30000: 231 | if last_mask_final is None or get_iou(last_mask_final, mask_rect, thresh=args.threshold) > 0.3: 232 | predictions = predictions * mask_rect 233 | last_mask_final = predictions.clone() 234 | 235 | if 100000 > last_mask_num > 5000: 236 | last_mask_numpy = (last_mask > args.threshold).data.cpu().numpy().astype(np.uint8)[0, 0] 237 | last_mask_numpy = cv2.morphologyEx(last_mask_numpy, cv2.MORPH_OPEN, kernel1) 238 | dilation = cv2.dilate(last_mask_numpy, kernel2, iterations=1) 239 | dilation = torch.from_numpy(dilation.astype(np.float32)).cuda() 240 | 241 | last_mask = predictions.clone() 242 | last_mask_num = (predictions > args.threshold).sum() 243 | 244 | predictions = predictions*dilation 245 | else: 246 | last_mask = predictions.clone() 247 | last_mask_num = (predictions > args.threshold).sum() 248 | 249 | iou_temp = get_iou(predictions, current_mask, thresh=args.threshold) 250 | if 0 < f < (len(ann_files)-1): 251 | curr_video_iou_list.append(iou_temp) 252 | 253 | if args.visualize: 254 | mask = colors[predictions.squeeze() > args.threshold] 255 | output = ((0.4 * cv2.imread(img_file)) + (0.6 * mask)).astype("uint8") 256 | cv2.putText(output, "%.3f" % (iou_temp.item()), 257 | (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA) 258 | 259 | cv2.imshow(video, output) 260 | cv2.waitKey(1) 261 | 262 | suffix = args.ms_mirror*'ms_mirror'+(not args.ms_mirror)*'single'+args.inst_prune*'_prune' 263 | visual_path = osp.join('visualization', args.model + '_' + suffix, img_file.split('/')[-2]) 264 | if not osp.exists(visual_path): 265 | os.makedirs(visual_path) 266 | cv2.imwrite(osp.join(visual_path, ann_file.split('/')[-1]), output) 267 | 268 | if args.save_mask: 269 | suffix = args.ms_mirror*'ms_mirror'+(not args.ms_mirror)*'single'+args.inst_prune*'_prune' 270 | if not osp.exists(osp.join(args.save_mask_dir, args.model, suffix, video)): 271 | os.makedirs(osp.join(args.save_mask_dir, args.model, suffix, video)) 272 | cv2.imwrite(osp.join(args.save_mask_dir, args.model, suffix, video, ann_file.split('/')[-1]), 273 | (predictions.squeeze() > args.threshold).cpu().numpy()) 274 | 275 | cv2.destroyAllWindows() 276 | video_mean_iou_list.append(sum(curr_video_iou_list)/len(curr_video_iou_list)) 277 | print('{} {} 
{}'.format(vid, video, video_mean_iou_list[-1])) 278 | 279 | if args.eval_sal: 280 | if not osp.exists(args.save_heatmap_dir): 281 | os.makedirs(args.save_heatmap_dir) 282 | with open(args.save_heatmap_dir + video + '.pkl', 'wb') as f: 283 | pickle.dump({'pred': np.array(predictions_all), 'gt': np.array(gt_all)}, f, pickle.HIGHEST_PROTOCOL) 284 | 285 | mean_iou = sum(video_mean_iou_list)/len(video_mean_iou_list) 286 | print('mean_iou {}'.format(mean_iou)) 287 | end = timeit.default_timer() 288 | print(end-start, 'seconds') 289 | # ========================== 290 | if args.eval_sal: 291 | pkl_files = glob.glob(args.save_heatmap_dir + '*.pkl') 292 | heatmap_gt = [] 293 | heatmap_pred = [] 294 | for i, pkl_file in enumerate(pkl_files): 295 | with open(pkl_file, 'rb') as f: 296 | info = pickle.load(f) 297 | heatmap_gt.append(np.array(info['gt'][1:-1]).flatten()) 298 | heatmap_pred.append(np.array(info['pred'][1:-1]).flatten()) 299 | heatmap_gt = np.hstack(heatmap_gt).flatten() 300 | heatmap_pred = np.hstack(heatmap_pred).flatten() 301 | precision, recall, _ = precision_recall_curve(heatmap_gt, heatmap_pred) 302 | Fmax = 2 * (precision * recall) / (precision + recall) 303 | print('MAE', np.mean(abs(heatmap_pred - heatmap_gt))) 304 | print('F_max', Fmax.max()) 305 | 306 | n_sample = len(precision)//1000 307 | import scipy.io 308 | scipy.io.savemat('davis.mat', {'recall': recall[0::n_sample], 'precision': precision[0::n_sample]}) 309 | 310 | 311 | if __name__ == '__main__': 312 | main() 313 | -------------------------------------------------------------------------------- /inplace_abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNSync 2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE 3 | from .misc import GlobalAvgPool2d 4 | from .residual import IdentityResidualBlock 5 | from .dense import DenseModule 6 | -------------------------------------------------------------------------------- /inplace_abn/bn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | 6 | try: 7 | from queue import Queue 8 | except ImportError: 9 | from Queue import Queue 10 | 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(BASE_DIR) 13 | sys.path.append(os.path.join(BASE_DIR, '../src')) 14 | from functions import * 15 | 16 | 17 | class ABN(nn.Module): 18 | """Activated Batch Normalization 19 | 20 | This gathers a `BatchNorm2d` and an activation function in a single module 21 | """ 22 | 23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 24 | """Creates an Activated Batch Normalization module 25 | 26 | Parameters 27 | ---------- 28 | num_features : int 29 | Number of feature channels in the input and output. 30 | eps : float 31 | Small constant to prevent numerical issues. 32 | momentum : float 33 | Momentum factor applied to compute running statistics as. 34 | affine : bool 35 | If `True` apply learned scale and shift transformation after normalization. 36 | activation : str 37 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 38 | slope : float 39 | Negative slope for the `leaky_relu` activation. 
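        Examples
        --------
        A minimal usage sketch (editorial addition; assumes a 4D NCHW input):

        >>> import torch
        >>> abn = ABN(64, activation="leaky_relu", slope=0.01)
        >>> out = abn(torch.randn(2, 64, 32, 32))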
40 | """ 41 | super(ABN, self).__init__() 42 | self.num_features = num_features 43 | self.affine = affine 44 | self.eps = eps 45 | self.momentum = momentum 46 | self.activation = activation 47 | self.slope = slope 48 | if self.affine: 49 | self.weight = nn.Parameter(torch.ones(num_features)) 50 | self.bias = nn.Parameter(torch.zeros(num_features)) 51 | else: 52 | self.register_parameter('weight', None) 53 | self.register_parameter('bias', None) 54 | self.register_buffer('running_mean', torch.zeros(num_features)) 55 | self.register_buffer('running_var', torch.ones(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.constant_(self.running_mean, 0) 60 | nn.init.constant_(self.running_var, 1) 61 | if self.affine: 62 | nn.init.constant_(self.weight, 1) 63 | nn.init.constant_(self.bias, 0) 64 | 65 | def forward(self, x): 66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 67 | self.training, self.momentum, self.eps) 68 | 69 | if self.activation == ACT_RELU: 70 | return functional.relu(x, inplace=True) 71 | elif self.activation == ACT_LEAKY_RELU: 72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) 73 | elif self.activation == ACT_ELU: 74 | return functional.elu(x, inplace=True) 75 | else: 76 | return x 77 | 78 | def __repr__(self): 79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 80 | ' affine={affine}, activation={activation}' 81 | if self.activation == "leaky_relu": 82 | rep += ', slope={slope})' 83 | else: 84 | rep += ')' 85 | return rep.format(name=self.__class__.__name__, **self.__dict__) 86 | 87 | 88 | class InPlaceABN(ABN): 89 | """InPlace Activated Batch Normalization""" 90 | 91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 92 | """Creates an InPlace Activated Batch Normalization module 93 | 94 | Parameters 95 | ---------- 96 | num_features : int 97 | Number of feature channels in the input and output. 98 | eps : float 99 | Small constant to prevent numerical issues. 100 | momentum : float 101 | Momentum factor applied to compute running statistics as. 102 | affine : bool 103 | If `True` apply learned scale and shift transformation after normalization. 104 | activation : str 105 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 106 | slope : float 107 | Negative slope for the `leaky_relu` activation. 108 | """ 109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) 110 | 111 | def forward(self, x): 112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 113 | self.training, self.momentum, self.eps, self.activation, self.slope) 114 | 115 | 116 | class InPlaceABNSync(ABN): 117 | """InPlace Activated Batch Normalization with cross-GPU synchronization 118 | 119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`. 120 | """ 121 | 122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 123 | slope=0.01): 124 | """Creates a synchronized, InPlace Activated Batch Normalization module 125 | 126 | Parameters 127 | ---------- 128 | num_features : int 129 | Number of feature channels in the input and output. 130 | devices : list of int or None 131 | IDs of the GPUs that will run the replicas of this module. 132 | eps : float 133 | Small constant to prevent numerical issues. 
134 | momentum : float 135 | Momentum factor applied to compute running statistics as. 136 | affine : bool 137 | If `True` apply learned scale and shift transformation after normalization. 138 | activation : str 139 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 140 | slope : float 141 | Negative slope for the `leaky_relu` activation. 142 | """ 143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope) 144 | self.devices = devices if devices else list(range(torch.cuda.device_count())) 145 | 146 | # Initialize queues 147 | self.worker_ids = self.devices[1:] 148 | self.master_queue = Queue(len(self.worker_ids)) 149 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 150 | 151 | def forward(self, x): 152 | if x.get_device() == self.devices[0]: 153 | # Master mode 154 | extra = { 155 | "is_master": True, 156 | "master_queue": self.master_queue, 157 | "worker_queues": self.worker_queues, 158 | "worker_ids": self.worker_ids 159 | } 160 | else: 161 | # Worker mode 162 | extra = { 163 | "is_master": False, 164 | "master_queue": self.master_queue, 165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 166 | } 167 | 168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope) 170 | 171 | def __repr__(self): 172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 173 | ' affine={affine}, devices={devices}, activation={activation}' 174 | if self.activation == "leaky_relu": 175 | rep += ', slope={slope})' 176 | else: 177 | rep += ')' 178 | return rep.format(name=self.__class__.__name__, **self.__dict__) 179 | -------------------------------------------------------------------------------- /inplace_abn/functions.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import torch.autograd as autograd 4 | import torch.cuda.comm as comm 5 | from torch.autograd.function import once_differentiable 6 | from torch.utils.cpp_extension import load 7 | 8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src") 9 | _backend = load(name="inplace_abn", 10 | extra_cflags=["-O3"], 11 | sources=[path.join(_src_path, f) for f in [ 12 | "inplace_abn.cpp", 13 | "inplace_abn_cpu.cpp", 14 | "inplace_abn_cuda.cu" 15 | ]], 16 | extra_cuda_cflags=["--expt-extended-lambda"]) 17 | 18 | # Activation names 19 | ACT_RELU = "relu" 20 | ACT_LEAKY_RELU = "leaky_relu" 21 | ACT_ELU = "elu" 22 | ACT_NONE = "none" 23 | 24 | 25 | def _check(fn, *args, **kwargs): 26 | success = fn(*args, **kwargs) 27 | if not success: 28 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 29 | 30 | 31 | def _broadcast_shape(x): 32 | out_size = [] 33 | for i, s in enumerate(x.size()): 34 | if i != 1: 35 | out_size.append(1) 36 | else: 37 | out_size.append(s) 38 | return out_size 39 | 40 | 41 | def _reduce(x): 42 | if len(x.size()) == 2: 43 | return x.sum(dim=0) 44 | else: 45 | n, c = x.size()[0:2] 46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 47 | 48 | 49 | def _count_samples(x): 50 | count = 1 51 | for i, s in enumerate(x.size()): 52 | if i != 1: 53 | count *= s 54 | return count 55 | 56 | 57 | def _act_forward(ctx, x): 58 | if ctx.activation == ACT_LEAKY_RELU: 59 | _backend.leaky_relu_forward(x, ctx.slope) 60 | elif ctx.activation == ACT_ELU: 61 | _backend.elu_forward(x) 62 | elif ctx.activation == ACT_NONE: 63 | pass 64 | 65 | 66 | 
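# Note on the pair of helpers around this point: _act_forward applies the chosen
# activation in place, while _act_backward both rescales the incoming gradient and
# inverts the activation on the saved output (e.g. dividing negative values by the
# leaky-ReLU slope, or applying log1p for ELU). This is the core InPlace-ABN trick:
# only the post-activation output is saved for backward, so the pre-activation
# tensor never needs to be stored.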
def _act_backward(ctx, x, dx): 67 | if ctx.activation == ACT_LEAKY_RELU: 68 | _backend.leaky_relu_backward(x, dx, ctx.slope) 69 | elif ctx.activation == ACT_ELU: 70 | _backend.elu_backward(x, dx) 71 | elif ctx.activation == ACT_NONE: 72 | pass 73 | 74 | 75 | class InPlaceABN(autograd.Function): 76 | @staticmethod 77 | def forward(ctx, x, weight, bias, running_mean, running_var, 78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 79 | # Save context 80 | ctx.training = training 81 | ctx.momentum = momentum 82 | ctx.eps = eps 83 | ctx.activation = activation 84 | ctx.slope = slope 85 | ctx.affine = weight is not None and bias is not None 86 | 87 | # Prepare inputs 88 | count = _count_samples(x) 89 | x = x.contiguous() 90 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 91 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 92 | 93 | if ctx.training: 94 | mean, var = _backend.mean_var(x) 95 | 96 | # Update running stats 97 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 98 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 99 | 100 | # Mark in-place modified tensors 101 | ctx.mark_dirty(x, running_mean, running_var) 102 | else: 103 | mean, var = running_mean.contiguous(), running_var.contiguous() 104 | ctx.mark_dirty(x) 105 | 106 | # BN forward + activation 107 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 108 | _act_forward(ctx, x) 109 | 110 | # Output 111 | ctx.var = var 112 | ctx.save_for_backward(x, var, weight, bias) 113 | return x 114 | 115 | @staticmethod 116 | @once_differentiable 117 | def backward(ctx, dz): 118 | z, var, weight, bias = ctx.saved_tensors 119 | dz = dz.contiguous() 120 | 121 | # Undo activation 122 | _act_backward(ctx, z, dz) 123 | 124 | if ctx.training: 125 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 126 | else: 127 | # TODO: implement simplified CUDA backward for inference mode 128 | edz = dz.new_zeros(dz.size(1)) 129 | eydz = dz.new_zeros(dz.size(1)) 130 | 131 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 132 | dweight = dweight if ctx.affine else None 133 | dbias = dbias if ctx.affine else None 134 | 135 | return dx, dweight, dbias, None, None, None, None, None, None, None 136 | 137 | 138 | class InPlaceABNSync(autograd.Function): 139 | @classmethod 140 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 141 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 142 | # Save context 143 | cls._parse_extra(ctx, extra) 144 | ctx.training = training 145 | ctx.momentum = momentum 146 | ctx.eps = eps 147 | ctx.activation = activation 148 | ctx.slope = slope 149 | ctx.affine = weight is not None and bias is not None 150 | 151 | # Prepare inputs 152 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1) 153 | x = x.contiguous() 154 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 155 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 156 | 157 | if ctx.training: 158 | mean, var = _backend.mean_var(x) 159 | 160 | if ctx.is_master: 161 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)] 162 | for _ in range(ctx.master_queue.maxsize): 163 | mean_w, var_w = ctx.master_queue.get() 164 | ctx.master_queue.task_done() 165 | means.append(mean_w.unsqueeze(0)) 166 | vars.append(var_w.unsqueeze(0)) 167 | 168 | means = comm.gather(means) 169 | vars = comm.gather(vars) 170 | 171 | mean = 
means.mean(0) 172 | var = (vars + (mean - means) ** 2).mean(0) 173 | 174 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids) 175 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 176 | queue.put(ts) 177 | else: 178 | ctx.master_queue.put((mean, var)) 179 | mean, var = ctx.worker_queue.get() 180 | ctx.worker_queue.task_done() 181 | 182 | # Update running stats 183 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 184 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 185 | 186 | # Mark in-place modified tensors 187 | ctx.mark_dirty(x, running_mean, running_var) 188 | else: 189 | mean, var = running_mean.contiguous(), running_var.contiguous() 190 | ctx.mark_dirty(x) 191 | 192 | # BN forward + activation 193 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 194 | _act_forward(ctx, x) 195 | 196 | # Output 197 | ctx.var = var 198 | ctx.save_for_backward(x, var, weight, bias) 199 | return x 200 | 201 | @staticmethod 202 | @once_differentiable 203 | def backward(ctx, dz): 204 | z, var, weight, bias = ctx.saved_tensors 205 | dz = dz.contiguous() 206 | 207 | # Undo activation 208 | _act_backward(ctx, z, dz) 209 | 210 | if ctx.training: 211 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 212 | 213 | if ctx.is_master: 214 | edzs, eydzs = [edz], [eydz] 215 | for _ in range(len(ctx.worker_queues)): 216 | edz_w, eydz_w = ctx.master_queue.get() 217 | ctx.master_queue.task_done() 218 | edzs.append(edz_w) 219 | eydzs.append(eydz_w) 220 | 221 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1) 222 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1) 223 | 224 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids) 225 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 226 | queue.put(ts) 227 | else: 228 | ctx.master_queue.put((edz, eydz)) 229 | edz, eydz = ctx.worker_queue.get() 230 | ctx.worker_queue.task_done() 231 | else: 232 | edz = dz.new_zeros(dz.size(1)) 233 | eydz = dz.new_zeros(dz.size(1)) 234 | 235 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 236 | dweight = dweight if ctx.affine else None 237 | dbias = dbias if ctx.affine else None 238 | 239 | return dx, dweight, dbias, None, None, None, None, None, None, None, None 240 | 241 | @staticmethod 242 | def _parse_extra(ctx, extra): 243 | ctx.is_master = extra["is_master"] 244 | if ctx.is_master: 245 | ctx.master_queue = extra["master_queue"] 246 | ctx.worker_queues = extra["worker_queues"] 247 | ctx.worker_ids = extra["worker_ids"] 248 | else: 249 | ctx.master_queue = extra["master_queue"] 250 | ctx.worker_queue = extra["worker_queue"] 251 | 252 | 253 | inplace_abn = InPlaceABN.apply 254 | inplace_abn_sync = InPlaceABNSync.apply 255 | 256 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"] 257 | -------------------------------------------------------------------------------- /inplace_abn/src/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /* 6 | * General settings 7 | */ 8 | const int WARP_SIZE = 32; 9 | const int MAX_BLOCK_SIZE = 512; 10 | 11 | template 12 | struct Pair { 13 | T v1, v2; 14 | __device__ Pair() {} 15 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 16 | __device__ Pair(T v) : v1(v), v2(v) {} 17 | __device__ Pair(int v) : v1(v), v2(v) {} 18 | __device__ 
Pair &operator+=(const Pair &a) { 19 | v1 += a.v1; 20 | v2 += a.v2; 21 | return *this; 22 | } 23 | }; 24 | 25 | /* 26 | * Utility functions 27 | */ 28 | template 29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 30 | unsigned int mask = 0xffffffff) { 31 | #if CUDART_VERSION >= 9000 32 | return __shfl_xor_sync(mask, value, laneMask, width); 33 | #else 34 | return __shfl_xor(value, laneMask, width); 35 | #endif 36 | } 37 | 38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 39 | 40 | static int getNumThreads(int nElem) { 41 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 42 | for (int i = 0; i != 5; ++i) { 43 | if (nElem <= threadSizes[i]) { 44 | return threadSizes[i]; 45 | } 46 | } 47 | return MAX_BLOCK_SIZE; 48 | } 49 | 50 | template 51 | static __device__ __forceinline__ T warpSum(T val) { 52 | #if __CUDA_ARCH__ >= 300 53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 54 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 55 | } 56 | #else 57 | __shared__ T values[MAX_BLOCK_SIZE]; 58 | values[threadIdx.x] = val; 59 | __threadfence_block(); 60 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 61 | for (int i = 1; i < WARP_SIZE; i++) { 62 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 63 | } 64 | #endif 65 | return val; 66 | } 67 | 68 | template 69 | static __device__ __forceinline__ Pair warpSum(Pair value) { 70 | value.v1 = warpSum(value.v1); 71 | value.v2 = warpSum(value.v2); 72 | return value; 73 | } 74 | 75 | template 76 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 77 | T sum = (T)0; 78 | for (int batch = 0; batch < N; ++batch) { 79 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 80 | sum += op(batch, plane, x); 81 | } 82 | } 83 | 84 | // sum over NumThreads within a warp 85 | sum = warpSum(sum); 86 | 87 | // 'transpose', and reduce within warp again 88 | __shared__ T shared[32]; 89 | __syncthreads(); 90 | if (threadIdx.x % WARP_SIZE == 0) { 91 | shared[threadIdx.x / WARP_SIZE] = sum; 92 | } 93 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 94 | // zero out the other entries in shared 95 | shared[threadIdx.x] = (T)0; 96 | } 97 | __syncthreads(); 98 | if (threadIdx.x / WARP_SIZE == 0) { 99 | sum = warpSum(shared[threadIdx.x]); 100 | if (threadIdx.x == 0) { 101 | shared[0] = sum; 102 | } 103 | } 104 | __syncthreads(); 105 | 106 | // Everyone picks it up, should be broadcast into the whole gradInput 107 | return shared[0]; 108 | } -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "inplace_abn.h" 6 | 7 | std::vector mean_var(at::Tensor x) { 8 | if (x.is_cuda()) { 9 | return mean_var_cuda(x); 10 | } else { 11 | return mean_var_cpu(x); 12 | } 13 | } 14 | 15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 16 | bool affine, float eps) { 17 | if (x.is_cuda()) { 18 | return forward_cuda(x, mean, var, weight, bias, affine, eps); 19 | } else { 20 | return forward_cpu(x, mean, var, weight, bias, affine, eps); 21 | } 22 | } 23 | 24 | std::vector edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 25 | bool affine, float eps) { 26 | if (z.is_cuda()) { 27 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps); 28 | } else { 29 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps); 30 | 
} 31 | } 32 | 33 | std::vector backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 34 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 35 | if (z.is_cuda()) { 36 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps); 37 | } else { 38 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps); 39 | } 40 | } 41 | 42 | void leaky_relu_forward(at::Tensor z, float slope) { 43 | at::leaky_relu_(z, slope); 44 | } 45 | 46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) { 47 | if (z.is_cuda()) { 48 | return leaky_relu_backward_cuda(z, dz, slope); 49 | } else { 50 | return leaky_relu_backward_cpu(z, dz, slope); 51 | } 52 | } 53 | 54 | void elu_forward(at::Tensor z) { 55 | at::elu_(z); 56 | } 57 | 58 | void elu_backward(at::Tensor z, at::Tensor dz) { 59 | if (z.is_cuda()) { 60 | return elu_backward_cuda(z, dz); 61 | } else { 62 | return elu_backward_cpu(z, dz); 63 | } 64 | } 65 | 66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 67 | m.def("mean_var", &mean_var, "Mean and variance computation"); 68 | m.def("forward", &forward, "In-place forward computation"); 69 | m.def("edz_eydz", &edz_eydz, "First part of backward computation"); 70 | m.def("backward", &backward, "Second part of backward computation"); 71 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation"); 72 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion"); 73 | m.def("elu_forward", &elu_forward, "Elu forward computation"); 74 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion"); 75 | } -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | std::vector mean_var_cpu(at::Tensor x); 8 | std::vector mean_var_cuda(at::Tensor x); 9 | 10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 11 | bool affine, float eps); 12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 13 | bool affine, float eps); 14 | 15 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 16 | bool affine, float eps); 17 | std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 18 | bool affine, float eps); 19 | 20 | std::vector backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 21 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 22 | std::vector backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 23 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 24 | 25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope); 26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope); 27 | 28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz); 29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz); -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "inplace_abn.h" 6 | 7 | at::Tensor reduce_sum(at::Tensor x) { 8 | if (x.ndimension() == 2) { 9 | return x.sum(0); 10 | } 
else { 11 | auto x_view = x.view({x.size(0), x.size(1), -1}); 12 | return x_view.sum(-1).sum(0); 13 | } 14 | } 15 | 16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 17 | if (x.ndimension() == 2) { 18 | return v; 19 | } else { 20 | std::vector broadcast_size = {1, -1}; 21 | for (int64_t i = 2; i < x.ndimension(); ++i) 22 | broadcast_size.push_back(1); 23 | 24 | return v.view(broadcast_size); 25 | } 26 | } 27 | 28 | int64_t count(at::Tensor x) { 29 | int64_t count = x.size(0); 30 | for (int64_t i = 2; i < x.ndimension(); ++i) 31 | count *= x.size(i); 32 | 33 | return count; 34 | } 35 | 36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { 37 | if (affine) { 38 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); 39 | } else { 40 | return z; 41 | } 42 | } 43 | 44 | std::vector mean_var_cpu(at::Tensor x) { 45 | auto num = count(x); 46 | auto mean = reduce_sum(x) / num; 47 | auto diff = x - broadcast_to(mean, x); 48 | auto var = reduce_sum(diff.pow(2)) / num; 49 | 50 | return {mean, var}; 51 | } 52 | 53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 54 | bool affine, float eps) { 55 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var); 56 | auto mul = at::rsqrt(var + eps) * gamma; 57 | 58 | x.sub_(broadcast_to(mean, x)); 59 | x.mul_(broadcast_to(mul, x)); 60 | if (affine) x.add_(broadcast_to(bias, x)); 61 | 62 | return x; 63 | } 64 | 65 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 66 | bool affine, float eps) { 67 | auto edz = reduce_sum(dz); 68 | auto y = invert_affine(z, weight, bias, affine, eps); 69 | auto eydz = reduce_sum(y * dz); 70 | 71 | return {edz, eydz}; 72 | } 73 | 74 | std::vector backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 75 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 76 | auto y = invert_affine(z, weight, bias, affine, eps); 77 | auto mul = affine ? 
at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps); 78 | 79 | auto num = count(z); 80 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz); 81 | 82 | auto dweight = at::empty(z.type(), {0}); 83 | auto dbias = at::empty(z.type(), {0}); 84 | if (affine) { 85 | dweight = eydz * at::sign(weight); 86 | dbias = edz; 87 | } 88 | 89 | return {dx, dweight, dbias}; 90 | } 91 | 92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) { 93 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] { 94 | int64_t count = z.numel(); 95 | auto *_z = z.data(); 96 | auto *_dz = dz.data(); 97 | 98 | for (int64_t i = 0; i < count; ++i) { 99 | if (_z[i] < 0) { 100 | _z[i] *= 1 / slope; 101 | _dz[i] *= slope; 102 | } 103 | } 104 | })); 105 | } 106 | 107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) { 108 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] { 109 | int64_t count = z.numel(); 110 | auto *_z = z.data(); 111 | auto *_dz = dz.data(); 112 | 113 | for (int64_t i = 0; i < count; ++i) { 114 | if (_z[i] < 0) { 115 | _z[i] = log1p(_z[i]); 116 | _dz[i] *= (_z[i] + 1.f); 117 | } 118 | } 119 | })); 120 | } -------------------------------------------------------------------------------- /inplace_abn/src/inplace_abn_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "common.h" 9 | #include "inplace_abn.h" 10 | 11 | // Checks 12 | #ifndef AT_CHECK 13 | #define AT_CHECK AT_ASSERT 14 | #endif 15 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 18 | 19 | // Utilities 20 | void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) { 21 | num = x.size(0); 22 | chn = x.size(1); 23 | sp = 1; 24 | for (int64_t i = 2; i < x.ndimension(); ++i) 25 | sp *= x.size(i); 26 | } 27 | 28 | // Operations for reduce 29 | template 30 | struct SumOp { 31 | __device__ SumOp(const T *t, int c, int s) 32 | : tensor(t), chn(c), sp(s) {} 33 | __device__ __forceinline__ T operator()(int batch, int plane, int n) { 34 | return tensor[(batch * chn + plane) * sp + n]; 35 | } 36 | const T *tensor; 37 | const int chn; 38 | const int sp; 39 | }; 40 | 41 | template 42 | struct VarOp { 43 | __device__ VarOp(T m, const T *t, int c, int s) 44 | : mean(m), tensor(t), chn(c), sp(s) {} 45 | __device__ __forceinline__ T operator()(int batch, int plane, int n) { 46 | T val = tensor[(batch * chn + plane) * sp + n]; 47 | return (val - mean) * (val - mean); 48 | } 49 | const T mean; 50 | const T *tensor; 51 | const int chn; 52 | const int sp; 53 | }; 54 | 55 | template 56 | struct GradOp { 57 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s) 58 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} 59 | __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { 60 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight; 61 | T _dz = dz[(batch * chn + plane) * sp + n]; 62 | return Pair(_dz, _y * _dz); 63 | } 64 | const T weight; 65 | const T bias; 66 | const T *z; 67 | const T *dz; 68 | const int chn; 69 | const int sp; 70 | }; 71 | 72 | /*********** 73 | * mean_var 74 | ***********/ 75 | 76 | template 77 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int 
num, int chn, int sp) { 78 | int plane = blockIdx.x; 79 | T norm = T(1) / T(num * sp); 80 | 81 | T _mean = reduce>(SumOp(x, chn, sp), plane, num, chn, sp) * norm; 82 | __syncthreads(); 83 | T _var = reduce>(VarOp(_mean, x, chn, sp), plane, num, chn, sp) * norm; 84 | 85 | if (threadIdx.x == 0) { 86 | mean[plane] = _mean; 87 | var[plane] = _var; 88 | } 89 | } 90 | 91 | std::vector mean_var_cuda(at::Tensor x) { 92 | CHECK_INPUT(x); 93 | 94 | // Extract dimensions 95 | int64_t num, chn, sp; 96 | get_dims(x, num, chn, sp); 97 | 98 | // Prepare output tensors 99 | auto mean = at::empty(x.type(), {chn}); 100 | auto var = at::empty(x.type(), {chn}); 101 | 102 | // Run kernel 103 | dim3 blocks(chn); 104 | dim3 threads(getNumThreads(sp)); 105 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] { 106 | mean_var_kernel<<>>( 107 | x.data(), 108 | mean.data(), 109 | var.data(), 110 | num, chn, sp); 111 | })); 112 | 113 | return {mean, var}; 114 | } 115 | 116 | /********** 117 | * forward 118 | **********/ 119 | 120 | template 121 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias, 122 | bool affine, float eps, int num, int chn, int sp) { 123 | int plane = blockIdx.x; 124 | 125 | T _mean = mean[plane]; 126 | T _var = var[plane]; 127 | T _weight = affine ? abs(weight[plane]) + eps : T(1); 128 | T _bias = affine ? bias[plane] : T(0); 129 | 130 | T mul = rsqrt(_var + eps) * _weight; 131 | 132 | for (int batch = 0; batch < num; ++batch) { 133 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 134 | T _x = x[(batch * chn + plane) * sp + n]; 135 | T _y = (_x - _mean) * mul + _bias; 136 | 137 | x[(batch * chn + plane) * sp + n] = _y; 138 | } 139 | } 140 | } 141 | 142 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 143 | bool affine, float eps) { 144 | CHECK_INPUT(x); 145 | CHECK_INPUT(mean); 146 | CHECK_INPUT(var); 147 | CHECK_INPUT(weight); 148 | CHECK_INPUT(bias); 149 | 150 | // Extract dimensions 151 | int64_t num, chn, sp; 152 | get_dims(x, num, chn, sp); 153 | 154 | // Run kernel 155 | dim3 blocks(chn); 156 | dim3 threads(getNumThreads(sp)); 157 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] { 158 | forward_kernel<<>>( 159 | x.data(), 160 | mean.data(), 161 | var.data(), 162 | weight.data(), 163 | bias.data(), 164 | affine, eps, num, chn, sp); 165 | })); 166 | 167 | return x; 168 | } 169 | 170 | /*********** 171 | * edz_eydz 172 | ***********/ 173 | 174 | template 175 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias, 176 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) { 177 | int plane = blockIdx.x; 178 | 179 | T _weight = affine ? abs(weight[plane]) + eps : 1.f; 180 | T _bias = affine ? 
bias[plane] : 0.f;
181 | 
182 | Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, chn, sp);
183 | __syncthreads();
184 | 
185 | if (threadIdx.x == 0) {
186 | edz[plane] = res.v1;
187 | eydz[plane] = res.v2;
188 | }
189 | }
190 | 
191 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
192 | bool affine, float eps) {
193 | CHECK_INPUT(z);
194 | CHECK_INPUT(dz);
195 | CHECK_INPUT(weight);
196 | CHECK_INPUT(bias);
197 | 
198 | // Extract dimensions
199 | int64_t num, chn, sp;
200 | get_dims(z, num, chn, sp);
201 | 
202 | auto edz = at::empty(z.type(), {chn});
203 | auto eydz = at::empty(z.type(), {chn});
204 | 
205 | // Run kernel
206 | dim3 blocks(chn);
207 | dim3 threads(getNumThreads(sp));
208 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
209 | edz_eydz_kernel<scalar_t><<<blocks, threads>>>(
210 | z.data<scalar_t>(),
211 | dz.data<scalar_t>(),
212 | weight.data<scalar_t>(),
213 | bias.data<scalar_t>(),
214 | edz.data<scalar_t>(),
215 | eydz.data<scalar_t>(),
216 | affine, eps, num, chn, sp);
217 | }));
218 | 
219 | return {edz, eydz};
220 | }
221 | 
222 | /***********
223 | * backward
224 | ***********/
225 | 
226 | template<typename T>
227 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
228 | const T *eydz, T *dx, T *dweight, T *dbias,
229 | bool affine, float eps, int num, int chn, int sp) {
230 | int plane = blockIdx.x;
231 | 
232 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
233 | T _bias = affine ? bias[plane] : 0.f;
234 | T _var = var[plane];
235 | T _edz = edz[plane];
236 | T _eydz = eydz[plane];
237 | 
238 | T _mul = _weight * rsqrt(_var + eps);
239 | T count = T(num * sp);
240 | 
241 | for (int batch = 0; batch < num; ++batch) {
242 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
243 | T _dz = dz[(batch * chn + plane) * sp + n];
244 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
245 | 
246 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
247 | }
248 | }
249 | 
250 | if (threadIdx.x == 0) {
251 | if (affine) {
252 | dweight[plane] = weight[plane] > 0 ?
_eydz : -_eydz;
253 | dbias[plane] = _edz;
254 | }
255 | }
256 | }
257 | 
258 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
259 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
260 | CHECK_INPUT(z);
261 | CHECK_INPUT(dz);
262 | CHECK_INPUT(var);
263 | CHECK_INPUT(weight);
264 | CHECK_INPUT(bias);
265 | CHECK_INPUT(edz);
266 | CHECK_INPUT(eydz);
267 | 
268 | // Extract dimensions
269 | int64_t num, chn, sp;
270 | get_dims(z, num, chn, sp);
271 | 
272 | auto dx = at::zeros_like(z);
273 | auto dweight = at::zeros_like(weight);
274 | auto dbias = at::zeros_like(bias);
275 | 
276 | // Run kernel
277 | dim3 blocks(chn);
278 | dim3 threads(getNumThreads(sp));
279 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
280 | backward_kernel<scalar_t><<<blocks, threads>>>(
281 | z.data<scalar_t>(),
282 | dz.data<scalar_t>(),
283 | var.data<scalar_t>(),
284 | weight.data<scalar_t>(),
285 | bias.data<scalar_t>(),
286 | edz.data<scalar_t>(),
287 | eydz.data<scalar_t>(),
288 | dx.data<scalar_t>(),
289 | dweight.data<scalar_t>(),
290 | dbias.data<scalar_t>(),
291 | affine, eps, num, chn, sp);
292 | }));
293 | 
294 | return {dx, dweight, dbias};
295 | }
296 | 
297 | /**************
298 | * activations
299 | **************/
300 | 
301 | template<typename T>
302 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
303 | // Create thrust pointers
304 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
305 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
306 | 
307 | thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
308 | [slope] __device__ (const T& dz) { return dz * slope; },
309 | [] __device__ (const T& z) { return z < 0; });
310 | thrust::transform_if(th_z, th_z + count, th_z,
311 | [slope] __device__ (const T& z) { return z / slope; },
312 | [] __device__ (const T& z) { return z < 0; });
313 | }
314 | 
315 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
316 | CHECK_INPUT(z);
317 | CHECK_INPUT(dz);
318 | 
319 | int64_t count = z.numel();
320 | 
321 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
322 | leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
323 | }));
324 | }
325 | 
326 | template<typename T>
327 | inline void elu_backward_impl(T *z, T *dz, int64_t count) {
328 | // Create thrust pointers
329 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
330 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
331 | 
332 | thrust::transform_if(th_dz, th_dz + count, th_z, th_z, th_dz,
333 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
334 | [] __device__ (const T& z) { return z < 0; });
335 | thrust::transform_if(th_z, th_z + count, th_z,
336 | [] __device__ (const T& z) { return log1p(z); },
337 | [] __device__ (const T& z) { return z < 0; });
338 | }
339 | 
340 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
341 | CHECK_INPUT(z);
342 | CHECK_INPUT(dz);
343 | 
344 | int64_t count = z.numel();
345 | 
346 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
347 | elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
348 | }));
349 | }
350 | 
--------------------------------------------------------------------------------
/metrics/iou.py:
--------------------------------------------------------------------------------
1 | def get_iou(preds, labels, thresh=0.5):
2 | preds, labels = preds.squeeze(), labels.squeeze()
3 | preds = preds > thresh
4 | mask_sum = (preds == 1) + (labels > 0)
5 | intersection = (mask_sum == 2).sum().float()
6 | union = (mask_sum > 0).sum().float()
7 | 
8 | if union > 0:
9 | return
intersection / union 10 | 11 | return 1. 12 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yz93/anchor-diff-VOS/b6ee4fcc9eb1b85b514215badea9d10c158d73c0/networks/__init__.py -------------------------------------------------------------------------------- /networks/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | import torch 4 | affine_par = True 5 | import functools 6 | import os 7 | import sys 8 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | sys.path.append(BASE_DIR) 10 | sys.path.append(os.path.join(BASE_DIR, '../inplace_abn')) 11 | from bn import InPlaceABNSync 12 | 13 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | expansion = 4 23 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1): 24 | super(Bottleneck, self).__init__() 25 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 26 | self.bn1 = BatchNorm2d(planes) 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 28 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False) 29 | self.bn2 = BatchNorm2d(planes) 30 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 31 | self.bn3 = BatchNorm2d(planes * 4) 32 | self.relu = nn.ReLU(inplace=False) 33 | self.relu_inplace = nn.ReLU(inplace=True) 34 | self.downsample = downsample 35 | self.dilation = dilation 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv3(out) 50 | out = self.bn3(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out = out + residual 56 | out = self.relu_inplace(out) 57 | 58 | return out 59 | 60 | class ASPPModule(nn.Module): 61 | """ 62 | Reference: 63 | Chen, Liang-Chieh, et al. 
*"Rethinking Atrous Convolution for Semantic Image Segmentation."* 64 | """ 65 | # output stride 16: (6, 12, 18) 66 | # output stride 8: (12, 24, 36) 67 | def __init__(self, features, hidden_features=512, out_features=512, dilations=(12, 24, 36)): 68 | super(ASPPModule, self).__init__() 69 | self.conv1 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=1, bias=False), 70 | InPlaceABNSync(hidden_features)) 71 | self.conv2 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False), 72 | InPlaceABNSync(hidden_features)) 73 | self.conv3 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False), 74 | InPlaceABNSync(hidden_features)) 75 | self.conv4 = nn.Sequential(nn.Conv2d(features, hidden_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False), 76 | InPlaceABNSync(hidden_features)) 77 | self.image_pooling = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), 78 | nn.Conv2d(features, hidden_features, kernel_size=1, bias=False), 79 | InPlaceABNSync(hidden_features)) 80 | 81 | self.conv_bn_dropout = nn.Sequential( 82 | nn.Conv2d(hidden_features * 5, out_features, kernel_size=1, bias=False), 83 | InPlaceABNSync(out_features), 84 | nn.Dropout2d(0.1) 85 | ) 86 | 87 | def forward(self, x): 88 | _, _, h, w = x.size() 89 | 90 | feat1 = self.conv1(x) 91 | feat2 = self.conv2(x) 92 | feat3 = self.conv3(x) 93 | feat4 = self.conv4(x) 94 | 95 | img_feat = F.interpolate(self.image_pooling(x), size=(h, w), mode='bilinear', align_corners=True) 96 | concat_feat = torch.cat((feat1, feat2, feat3, feat4, img_feat), 1) 97 | 98 | out = self.conv_bn_dropout(concat_feat) 99 | return out 100 | 101 | class ResNet(nn.Module): 102 | def __init__(self, block, layers, num_classes): 103 | self.inplanes = 128 104 | super(ResNet, self).__init__() 105 | self.conv1 = conv3x3(3, 64, stride=2) 106 | self.bn1 = BatchNorm2d(64) 107 | self.relu1 = nn.ReLU(inplace=False) 108 | self.conv2 = conv3x3(64, 64) 109 | self.bn2 = BatchNorm2d(64) 110 | self.relu2 = nn.ReLU(inplace=False) 111 | self.conv3 = conv3x3(64, 128) 112 | self.bn3 = BatchNorm2d(128) 113 | self.relu3 = nn.ReLU(inplace=False) 114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 115 | self.relu = nn.ReLU(inplace=False) 116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,2,4)) 121 | 122 | self.head = nn.Sequential(ASPPModule(2048), 123 | nn.Conv2d(512, num_classes, kernel_size=1) 124 | ) 125 | 126 | self.dsn = nn.Sequential( 127 | nn.Conv2d(1024, 512, kernel_size=3, padding=1, bias=False), 128 | InPlaceABNSync(512), 129 | nn.Dropout2d(0.1), 130 | nn.Conv2d(512, num_classes, kernel_size=1) 131 | ) 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | nn.Conv2d(self.inplanes, planes * block.expansion, 138 | kernel_size=1, stride=stride, bias=False), 139 | BatchNorm2d(planes * block.expansion,affine = affine_par)) 140 | 141 | layers = [] 142 | generate_multi_grid = lambda index, grids: 
grids[index%len(grids)] if isinstance(grids, tuple) else 1 143 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid))) 144 | self.inplanes = planes * block.expansion 145 | for i in range(1, blocks): 146 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) 147 | 148 | return nn.Sequential(*layers) 149 | 150 | def forward(self, x): 151 | x = self.relu1(self.bn1(self.conv1(x))) 152 | x = self.relu2(self.bn2(self.conv2(x))) 153 | x = self.relu3(self.bn3(self.conv3(x))) 154 | x = self.maxpool(x) 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x_dsn = self.dsn(x) 159 | x = self.layer4(x) 160 | x = self.head(x) 161 | return [x, x_dsn] 162 | 163 | 164 | def ResNetDeepLabv3(backbone='ResNet50', num_classes=1, batch_mode='sync'): 165 | global BatchNorm2d 166 | if batch_mode=='sync': 167 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 168 | elif batch_mode=='old': 169 | BatchNorm2d = torch.nn.BatchNorm2d 170 | 171 | if backbone=='ResNet50': 172 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes) 173 | elif backbone=='ResNet101': 174 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes) 175 | else: 176 | raise RuntimeError('unknown backbone type') 177 | return model 178 | 179 | -------------------------------------------------------------------------------- /networks/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 12 | if tensor is not None: 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 18 | if tensor is not None: 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) 30 | 31 | 32 | def reset(nn): 33 | def _reset(item): 34 | if hasattr(item, 'reset_parameters'): 35 | item.reset_parameters() 36 | 37 | if nn is not None: 38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 39 | for item in nn.children(): 40 | _reset(item) 41 | else: 42 | _reset(nn) 43 | -------------------------------------------------------------------------------- /networks/nets.py: -------------------------------------------------------------------------------- 1 | from networks.deeplabv3 import ResNetDeepLabv3 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import torch 7 | import os 8 | import sys 9 | 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(BASE_DIR) 12 | sys.path.append(os.path.join(BASE_DIR, '../inplace_abn')) 13 | from bn import InPlaceABNSync 14 | 15 | 16 | class AnchorDiffNet(nn.Module): 17 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'): 18 | super(AnchorDiffNet, self).__init__() 19 | if pyramid_pooling == 'deeplabv3': 20 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode) 21 | elif pyramid_pooling == 'pspnet': 22 | raise RuntimeError('Pooling module not implemented') 23 | else: 24 | raise 
RuntimeError('Unknown pyramid pooling module') 25 | self.cls = nn.Sequential( 26 | nn.Conv2d(3 * embedding, embedding, kernel_size=1, stride=1, padding=0), 27 | InPlaceABNSync(embedding), 28 | nn.Dropout2d(0.10), 29 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0) 30 | ) 31 | 32 | def forward(self, reference, current): 33 | ref_features = self.features(reference)[0] 34 | curr_features = self.features(current)[0] 35 | batch, channel, h, w = curr_features.shape 36 | M = h * w 37 | ref_features = ref_features.view(batch, channel, M).permute(0, 2, 1) 38 | curr_features = curr_features.view(batch, channel, M) 39 | 40 | p_0 = torch.matmul(ref_features, curr_features) 41 | p_0 = F.softmax((channel ** -.5) * p_0, dim=-1) 42 | p_1 = torch.matmul(curr_features.permute(0, 2, 1), curr_features) 43 | p_1 = F.softmax((channel ** -.5) * p_1, dim=-1) 44 | feats_0 = torch.matmul(p_0, curr_features.permute(0, 2, 1)).permute(0, 2, 1) 45 | feats_1 = torch.matmul(p_1, curr_features.permute(0, 2, 1)).permute(0, 2, 1) 46 | x = torch.cat([feats_0, feats_1, curr_features], dim=1).view(batch, 3 * channel, h, w) 47 | pred = self.cls(x) 48 | 49 | return pred 50 | 51 | 52 | class ConcatNet(nn.Module): 53 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'): 54 | super(ConcatNet, self).__init__() 55 | if pyramid_pooling == 'deeplabv3': 56 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode) 57 | elif pyramid_pooling == 'pspnet': 58 | raise RuntimeError('Pooling module not implemented') 59 | else: 60 | raise RuntimeError('Unknown pyramid pooling module') 61 | self.cls = nn.Sequential( 62 | nn.Conv2d(2*embedding, embedding, kernel_size=1, stride=1, padding=0), 63 | InPlaceABNSync(embedding), 64 | nn.Dropout2d(0.10), 65 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0) 66 | ) 67 | 68 | def forward(self, reference, current): 69 | ref_features = self.features(reference)[0] 70 | curr_features = self.features(current)[0] 71 | x = torch.cat([ref_features, curr_features], dim=1) 72 | pred = self.cls(x) 73 | 74 | return pred 75 | 76 | 77 | class InterFrameNet(nn.Module): 78 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'): 79 | super(InterFrameNet, self).__init__() 80 | if pyramid_pooling == 'deeplabv3': 81 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode) 82 | elif pyramid_pooling == 'pspnet': 83 | raise RuntimeError('Pooling module not implemented') 84 | else: 85 | raise RuntimeError('Unknown pyramid pooling module') 86 | self.cls = nn.Sequential( 87 | nn.Conv2d(2 * embedding, embedding, kernel_size=1, stride=1, padding=0), 88 | InPlaceABNSync(embedding), 89 | nn.Dropout2d(0.10), 90 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0) 91 | ) 92 | 93 | def forward(self, reference, current): 94 | ref_features = self.features(reference)[0] 95 | curr_features = self.features(current)[0] 96 | batch, channel, h, w = curr_features.shape 97 | M = h * w 98 | ref_features = ref_features.view(batch, channel, M).permute(0, 2, 1) 99 | curr_features = curr_features.view(batch, channel, M) 100 | 101 | p_0 = torch.matmul(ref_features, curr_features) 102 | p_0 = F.softmax((channel ** -.5) * p_0, dim=-1) 103 | feats_0 = torch.matmul(p_0, curr_features.permute(0, 2, 1)).permute(0, 2, 1) 104 | x = torch.cat([feats_0, curr_features], dim=1).view(batch, 2 * channel, h, w) 105 | pred = self.cls(x) 106 | 107 | return pred 108 | 109 | 
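# Note: the heads in this file (AnchorDiffNet, InterFrameNet and IntraFrameNet below)
# share one attention step. The helper below is an illustrative sketch only and is not
# referenced anywhere else in the repository; it spells out the tensor shapes, with
# B = batch, C = embedding channels and M = H * W flattened positions:
#   p   = softmax((C ** -0.5) * key^T @ curr)   -> (B, M, M) affinity
#   out = (p @ curr^T)^T                        -> (B, C, M) propagated features
# AnchorDiffNet uses both the anchor frame and the current frame as keys, InterFrameNet
# only the anchor, and IntraFrameNet only the current frame.
def _propagate(key, curr):
    # key, curr: (B, C, M); returns attention-propagated features of shape (B, C, M).
    channel = curr.size(1)
    p = torch.matmul(key.permute(0, 2, 1), curr)
    p = F.softmax((channel ** -.5) * p, dim=-1)
    return torch.matmul(p, curr.permute(0, 2, 1)).permute(0, 2, 1)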
110 | class IntraFrameNet(nn.Module): 111 | def __init__(self, backbone='ResNet50', pyramid_pooling='deeplabv3', embedding=128, batch_mode='sync'): 112 | super(IntraFrameNet, self).__init__() 113 | if pyramid_pooling == 'deeplabv3': 114 | self.features = ResNetDeepLabv3(backbone, num_classes=embedding, batch_mode=batch_mode) 115 | elif pyramid_pooling == 'pspnet': 116 | raise RuntimeError('Pooling module not implemented') 117 | else: 118 | raise RuntimeError('Unknown pyramid pooling module') 119 | self.cls = nn.Sequential( 120 | nn.Conv2d(2*embedding, embedding, kernel_size=1, stride=1, padding=0), 121 | InPlaceABNSync(embedding), 122 | nn.Dropout2d(0.10), 123 | nn.Conv2d(embedding, 1, kernel_size=1, stride=1, padding=0) 124 | ) 125 | 126 | def forward(self, current): 127 | curr_features = self.features(current)[0] 128 | batch, channel, h, w = curr_features.shape 129 | M = h * w 130 | curr_features = curr_features.view(batch, channel, M) 131 | 132 | p_1 = torch.matmul(curr_features.permute(0, 2, 1), curr_features) 133 | p_1 = F.softmax((channel**-.5) * p_1, dim=-1) 134 | feats_1 = torch.matmul(p_1, curr_features.permute(0, 2, 1)).permute(0, 2, 1) 135 | x = torch.cat([feats_1, curr_features], dim=1).view(batch, 2*channel, h, w) 136 | pred = self.cls(x) 137 | 138 | return pred 139 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.2.3 2 | opencv-python==4.1.0.25 3 | scikit-learn==0.20.2 4 | torchvision==0.2.1 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yz93/anchor-diff-VOS/b6ee4fcc9eb1b85b514215badea9d10c158d73c0/utils/__init__.py -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Encoding Data Parallel""" 12 | import threading 13 | import functools 14 | import torch 15 | from torch.autograd import Variable, Function 16 | import torch.cuda.comm as comm 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.nn.parallel.parallel_apply import get_a_var 19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 20 | 21 | torch_ver = torch.__version__[:3] 22 | 23 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 24 | 'patch_replication_callback'] 25 | 26 | def allreduce(*inputs): 27 | """Cross GPU all reduce autograd operation for calculate mean and 28 | variance in SyncBN. 
29 | """ 30 | return AllReduce.apply(*inputs) 31 | 32 | class AllReduce(Function): 33 | @staticmethod 34 | def forward(ctx, num_inputs, *inputs): 35 | ctx.num_inputs = num_inputs 36 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 37 | inputs = [inputs[i:i + num_inputs] 38 | for i in range(0, len(inputs), num_inputs)] 39 | # sort before reduce sum 40 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 41 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 42 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 43 | return tuple([t for tensors in outputs for t in tensors]) 44 | 45 | @staticmethod 46 | def backward(ctx, *inputs): 47 | inputs = [i.data for i in inputs] 48 | inputs = [inputs[i:i + ctx.num_inputs] 49 | for i in range(0, len(inputs), ctx.num_inputs)] 50 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 51 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 52 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 53 | 54 | 55 | class Reduce(Function): 56 | @staticmethod 57 | def forward(ctx, *inputs): 58 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 59 | inputs = sorted(inputs, key=lambda i: i.get_device()) 60 | return comm.reduce_add(inputs) 61 | 62 | @staticmethod 63 | def backward(ctx, gradOutput): 64 | return Broadcast.apply(ctx.target_gpus, gradOutput) 65 | 66 | 67 | class DataParallelModel(DataParallel): 68 | """Implements data parallelism at the module level. 69 | 70 | This container parallelizes the application of the given module by 71 | splitting the input across the specified devices by chunking in the 72 | batch dimension. 73 | In the forward pass, the module is replicated on each device, 74 | and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. 75 | Note that the outputs are not gathered, please use compatible 76 | :class:`encoding.parallel.DataParallelCriterion`. 77 | 78 | The batch size should be larger than the number of GPUs used. It should 79 | also be an integer multiple of the number of GPUs so that each chunk is 80 | the same size (so that each GPU processes the same number of samples). 81 | 82 | Args: 83 | module: module to be parallelized 84 | device_ids: CUDA devices (default: all devices) 85 | 86 | Reference: 87 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 88 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 89 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 90 | 91 | Example:: 92 | 93 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 94 | >>> y = net(x) 95 | """ 96 | def gather(self, outputs, output_device): 97 | return outputs 98 | 99 | def replicate(self, module, device_ids): 100 | modules = super(DataParallelModel, self).replicate(module, device_ids) 101 | execute_replication_callbacks(modules) 102 | return modules 103 | 104 | 105 | class DataParallelCriterion(DataParallel): 106 | """ 107 | Calculate loss in multiple-GPUs, which balance the memory usage for 108 | Semantic Segmentation. 109 | 110 | The targets are splitted across the specified devices by chunking in 111 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 112 | 113 | Reference: 114 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 115 | Amit Agrawal. 
“Context Encoding for Semantic Segmentation. 116 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 117 | 118 | Example:: 119 | 120 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 121 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 122 | >>> y = net(x) 123 | >>> loss = criterion(y, target) 124 | """ 125 | def forward(self, inputs, *targets, **kwargs): 126 | # input should be already scatterd 127 | # scattering the targets instead 128 | if not self.device_ids: 129 | return self.module(inputs, *targets, **kwargs) 130 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 131 | if len(self.device_ids) == 1: 132 | return self.module(inputs, *targets[0], **kwargs[0]) 133 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 134 | targets = tuple(targets_per_gpu[0] for targets_per_gpu in targets) # fix bug 135 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 136 | return Reduce.apply(*outputs) / len(outputs) 137 | #return self.gather(outputs, self.output_device).mean() 138 | 139 | 140 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 141 | assert len(modules) == len(inputs) 142 | assert len(targets) == len(inputs) 143 | if kwargs_tup: 144 | assert len(modules) == len(kwargs_tup) 145 | else: 146 | kwargs_tup = ({},) * len(modules) 147 | if devices is not None: 148 | assert len(modules) == len(devices) 149 | else: 150 | devices = [None] * len(modules) 151 | 152 | lock = threading.Lock() 153 | results = {} 154 | if torch_ver != "0.3": 155 | grad_enabled = torch.is_grad_enabled() 156 | 157 | def _worker(i, module, input, target, kwargs, device=None): 158 | if torch_ver != "0.3": 159 | torch.set_grad_enabled(grad_enabled) 160 | if device is None: 161 | device = get_a_var(input).get_device() 162 | try: 163 | with torch.cuda.device(device): 164 | output = module(input, target, **kwargs) 165 | # output = module(*(input + target), **kwargs) 166 | with lock: 167 | results[i] = output 168 | except Exception as e: 169 | with lock: 170 | results[i] = e 171 | 172 | if len(modules) > 1: 173 | threads = [threading.Thread(target=_worker, 174 | args=(i, module, input, target, 175 | kwargs, device),) 176 | for i, (module, input, target, kwargs, device) in 177 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 178 | 179 | for thread in threads: 180 | thread.start() 181 | for thread in threads: 182 | thread.join() 183 | else: 184 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 185 | 186 | outputs = [] 187 | for i in range(len(inputs)): 188 | output = results[i] 189 | if isinstance(output, Exception): 190 | raise output 191 | outputs.append(output) 192 | return outputs 193 | 194 | 195 | ########################################################################### 196 | # Adapted from Synchronized-BatchNorm-PyTorch. 197 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 198 | # 199 | class CallbackContext(object): 200 | pass 201 | 202 | 203 | def execute_replication_callbacks(modules): 204 | """ 205 | Execute an replication callback `__data_parallel_replicate__` on each module created 206 | by original replication. 207 | 208 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 209 | 210 | Note that, as all modules are isomorphism, we assign each sub-module with a context 211 | (shared among multiple copies of this module on different devices). 
212 | Through this context, different copies can share some information. 213 | 214 | We guarantee that the callback on the master copy (the first copy) will be called ahead 215 | of calling the callback of any slave copies. 216 | """ 217 | master_copy = modules[0] 218 | nr_modules = len(list(master_copy.modules())) 219 | ctxs = [CallbackContext() for _ in range(nr_modules)] 220 | 221 | for i, module in enumerate(modules): 222 | for j, m in enumerate(module.modules()): 223 | if hasattr(m, '__data_parallel_replicate__'): 224 | m.__data_parallel_replicate__(ctxs[j], i) 225 | 226 | 227 | def patch_replication_callback(data_parallel): 228 | """ 229 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 230 | Useful when you have customized `DataParallel` implementation. 231 | 232 | Examples: 233 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 234 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 235 | > patch_replication_callback(sync_bn) 236 | # this is equivalent to 237 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 238 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 239 | """ 240 | 241 | assert isinstance(data_parallel, DataParallel) 242 | 243 | old_replicate = data_parallel.replicate 244 | 245 | @functools.wraps(old_replicate) 246 | def new_replicate(module, device_ids): 247 | modules = old_replicate(module, device_ids) 248 | execute_replication_callbacks(modules) 249 | return modules 250 | 251 | data_parallel.replicate = new_replicate --------------------------------------------------------------------------------
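As a quick orientation for how the pieces above fit together, the sketch below builds the anchor-diffusion head, runs one reference/current frame pair through it, and scores the result with `metrics/iou.py`. It is illustrative only: it assumes the `inplace_abn` extension is compiled and a GPU is available, uses dummy tensors instead of pre-processed DAVIS frames, and leaves the checkpoint path as a placeholder; `eval.py` remains the reference entry point.

```python
import torch
import torch.nn.functional as F

from networks.nets import AnchorDiffNet
from metrics.iou import get_iou

# Anchor-diffusion head on a ResNet-50 DeepLabv3 backbone (requires the inplace_abn extension).
model = AnchorDiffNet(backbone='ResNet50', pyramid_pooling='deeplabv3',
                      embedding=128, batch_mode='sync').cuda().eval()
# model.load_state_dict(torch.load('snapshots/...'))  # placeholder path, see the README

with torch.no_grad():
    # reference: the anchor (first) frame of the video; current: the frame to segment.
    reference = torch.randn(1, 3, 480, 854).cuda()
    current = torch.randn(1, 3, 480, 854).cuda()
    logits = model(reference, current)            # single-channel, stride-8 mask logits
    prob = torch.sigmoid(F.interpolate(logits, size=current.shape[2:],
                                       mode='bilinear', align_corners=True))

gt = torch.zeros_like(prob)                       # ground-truth mask goes here
print(get_iou(prob, gt))                          # IoU at the default 0.5 threshold
```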