├── .gitignore ├── README.md ├── SECURITY.md ├── dataset ├── __init__.py ├── annotation_centroids.npy ├── davis.py ├── split_trainval.py └── transforms.py ├── figure └── fig1.png ├── inference.py ├── lib ├── __init__.py ├── loss.py ├── predict.py └── utils.py ├── main.py ├── modeling ├── __init__.py ├── backbone │ ├── __init__.py │ └── resnet.py ├── network.py └── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── test_epochs.py └── test_epochs.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | __pycache__ 3 | .vscode 4 | .idea 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Transductive Approach for Video Object Segmentation 2 | 3 | 4 | 5 | This repo contains the PyTorch implementation of the CVPR 2020 paper [A Transductive Approach for Video Object Segmentation](https://arxiv.org/abs/2004.07193). 6 | 7 | ## Pretrained Models and Results 8 | 9 | We provide three pretrained ResNet50 models, trained on the DAVIS 17 training set, the combined DAVIS 17 training and validation sets, and the YouTube-VOS training set, respectively: 10 | 11 | - [Davis-train](https://drive.google.com/open?id=1SWZ20zTHgOpha0MlF8iOdqEHFkALdZn7) 12 | - [Davis-trainval](https://drive.google.com/open?id=14Qm8UEQG-rYYepDYzKPTc1KqISzQkT95) 13 | - [Youtube-train](https://drive.google.com/open?id=1U6sX9EUpOvDRFyaqDVpsi3plTnI2Witp) 14 | 15 | Our pre-computed results can be downloaded [here](https://drive.google.com/open?id=1QdKaeoMU7KaEp0TIXZZNLdgm9n8IOQOj).
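`inference.py` wraps the network in `nn.DataParallel` before calling `load_state_dict(checkpoint['state_dict'])`, so the keys in the released checkpoints most likely carry the `module.` prefix that the wrapper adds. If you want to load one into an unwrapped model (for example on CPU), a minimal sketch, assuming that prefix convention, is to strip the prefix first. The helper `strip_module_prefix` is hypothetical and not part of this repo.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper: assumes the checkpoint was saved from a model wrapped in
    nn.DataParallel / DistributedDataParallel, which prepend "module." to names.
    Keys without the prefix are copied unchanged; the input is not modified.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Possible usage (requires PyTorch and this repo's `modeling` package):
#   import torch, modeling
#   checkpoint = torch.load("/path-to-pretrained-model", map_location="cpu")
#   model = modeling.VOSNet(model="resnet50")
#   model.load_state_dict(strip_module_prefix(checkpoint["state_dict"]))
```

Keeping un-prefixed keys as-is means the helper is also safe to call on a checkpoint that was saved without a wrapper.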
 16 | 17 | Our results on DAVIS17 and YouTube-VOS: 18 | 19 | | Dataset | J | F | 20 | | -------------------- | :--- | ---- | 21 | | DAVIS17 validation | 69.9 | 74.7 | 22 | | DAVIS17 test-dev | 58.8 | 67.4 | 23 | | YouTube-VOS (seen) | 67.1 | 69.4 | 24 | | YouTube-VOS (unseen) | 63.0 | 71.6 | 25 | 26 | ## Usage 27 | 28 | - Install Python 3, PyTorch >= 0.4, and the PIL package. 29 | 30 | - Clone this repo: 31 | 32 | ```shell 33 | git clone https://github.com/microsoft/transductive-vos.pytorch 34 | ``` 35 | 36 | - Prepare the DAVIS 17 train-val dataset: 37 | 38 | ```shell 39 | # first download the dataset 40 | cd /path-to-data-directory/ 41 | wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip 42 | # unzip 43 | unzip DAVIS-2017-trainval-480p.zip 44 | # split train-val dataset 45 | python /path-to-transductive-vos.pytorch/dataset/split_trainval.py -i ./DAVIS 46 | # clean up 47 | rm -rf ./DAVIS 48 | ``` 49 | 50 | Your data directory should now be structured like this: 51 | 52 | ``` 53 | . 54 | |-- DAVIS_train 55 | |-- JPEGImages/480p/ 56 | |-- bear 57 | |-- ... 58 | |-- Annotations/480p/ 59 | |-- DAVIS_val 60 | |-- JPEGImages/480p/ 61 | |-- bike-packing 62 | |-- ... 63 | |-- Annotations/480p/ 64 | ``` 65 | 66 | - Training on the DAVIS training set: 67 | 68 | ```shell 69 | python -m torch.distributed.launch --master_port 12347 --nproc_per_node=4 main.py --data /path-to-your-davis-directory/ 70 | ``` 71 | 72 | By default, all training parameters are set to our best setting for reproducing the ResNet50 model. This setting requires 4 GPUs with 16 GB of CUDA memory each. Feel free to contact the authors about parameter settings if you want to train with a different number of GPUs.
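The J column in the results table is the DAVIS region similarity: the Jaccard index (intersection over union) between a predicted object mask and its ground-truth mask, averaged over frames and objects to give J-mean. Official numbers should come from the DAVIS evaluation code linked in the inference instructions; the following is only a minimal NumPy sketch of the metric. The function names are hypothetical, and both the convention that two empty masks score J = 1 and the simple frame-then-object averaging are assumptions, so results may differ slightly from the official protocol.

```python
import numpy as np


def jaccard(pred, gt):
    """Region similarity J for one object in one frame: |pred & gt| / |pred | gt|.

    pred, gt: boolean (or 0/1) arrays of the same shape (H, W).
    Assumption: if both masks are empty, J is defined as 1.0.
    """
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred, gt).sum() / union)


def j_mean(pred_labels, gt_labels, object_ids):
    """Mean J for label maps of shape (T, H, W).

    Each object id k becomes a binary mask (labels == k); J is averaged over
    frames for each object, then over objects (a simplified protocol).
    """
    per_object = []
    for k in object_ids:
        scores = [jaccard(p == k, g == k) for p, g in zip(pred_labels, gt_labels)]
        per_object.append(np.mean(scores))
    return float(np.mean(per_object))
```

Per-object averaging keeps small objects from being swamped by large ones, which is why the metric is computed on binary masks rather than on the whole label map at once.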
 73 | 74 | If you want to change some parameters, see the comments in `main.py` or run 75 | 76 | ```shell 77 | python main.py -h 78 | ``` 79 | 80 | - Inference on the DAVIS validation set (requires 1 GPU with 12 GB of CUDA memory): 81 | 82 | ```shell 83 | python inference.py --data /path-to-your-davis-directory/ -r /path-to-pretrained-model -s /path-to-save-predictions 84 | ``` 85 | 86 | As with training, all inference parameters default to our best setting on the DAVIS validation set, which reproduces our result of a J-mean of 0.699. The saved predictions can be evaluated directly with the [DAVIS evaluation code](https://github.com/davisvideochallenge/davis2017-evaluation). 87 | 88 | ## Further Improvements 89 | This approach is simple and cleanly implemented; adding a few small tricks can improve its performance further. For example: 90 | - Performing an epoch test, i.e., selecting the best-performing epoch, gives a further improvement of about 1.5 points (absolute) on the DAVIS17 dataset. 91 | - Pretraining the model on other image datasets with mask annotations, such as semantic segmentation and salient object detection datasets, may bring further improvements. 92 | - ... ... 93 | 94 | ## Contact 95 | 96 | For any questions, please feel free to reach out to: 97 | 98 | ``` 99 | Yizhuo Zhang: criszhang004@gmail.com 100 | Zhirong Wu: xavibrowu@gmail.com 101 | ``` 102 | 103 | ## Citations 104 | ``` 105 | @inproceedings{zhang2020a, 106 | title={A Transductive Approach for Video Object Segmentation}, 107 | author={Zhang, Yizhuo and Wu, Zhirong and Peng, Houwen and Lin, Stephen}, 108 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 109 | year={2020} 110 | } 111 | ``` 112 | 113 | ## Contributing 114 | 115 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 116 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 117 | the rights to use your contribution.
For details, visit https://cla.opensource.microsoft.com. 118 | 119 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 120 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 121 | provided by the bot. You will only need to do this once across all repos using our CLA. 122 | 123 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 124 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 125 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 126 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .davis import DavisTrain, DavisInference 2 | -------------------------------------------------------------------------------- /dataset/annotation_centroids.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/transductive-vos.pytorch/1312e43b59bb203e2575a339643a6d6f5c292a16/dataset/annotation_centroids.npy -------------------------------------------------------------------------------- /dataset/davis.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from PIL import Image 3 | 4 | import torch 5 | import numpy as np 6 | from torchvision import datasets 7 | 8 | from dataset.transforms import * 9 | 10 | 11 | class DavisTrain(datasets.ImageFolder): 12 | def __init__(self, 13 | img_root, 14 | annotation_root, 15 | cropping=256, 16 | frame_num=10, 17 | transform=None, 18 | target_transform=None, 19 | color_jitter=False): 20 | super(DavisTrain, self).__init__(img_root, 21 | transform=transform, 22 | target_transform=target_transform) 23 | # img root and annotation root should have the same class_to_idx 24 | self.annotations = make_dataset(annotation_root, self.class_to_idx) 25 | self.cropping = cropping 26 | self.frame_num = frame_num 27 | self.color_jitter = color_jitter 28 | self.rgb_normalize = transforms.Compose([transforms.ToTensor(), 29 | transforms.Normalize( 30 | mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225])]) 32 | # read all jpgs and annotations into mem to speed up training 33 | self.img_bytes = [] 34 | self.annotation_bytes = [] 35 | idx = 0 36 | for path, _ in self.imgs: 37 | with open(path, 'rb') as f: 38 | self.img_bytes.append(f.read()) 39 | idx += 1 40 | if idx % 500 == 0: 41 | print("%d images 
 loaded." % idx) 42 | print("JPEGImages loaded: ", len(self.img_bytes)) 43 | idx = 0 44 | for path, _ in self.annotations: 45 | with open(path, 'rb') as f: 46 | self.annotation_bytes.append(f.read()) 47 | idx += 1 48 | if idx % 500 == 0: 49 | print("%d annotations loaded." % idx) 50 | print("Annotations loaded: ", len(self.annotation_bytes)) 51 | 52 | def __getitem__(self, index): 53 | img_output = [] 54 | annotation_output = [] 55 | 56 | # if index reaches end of dataset, get the last frames 57 | if index + self.frame_num > len(self.imgs): 58 | index = len(self.imgs) - self.frame_num 59 | while not self.__is_from_same_video__(index): 60 | index -= 1 61 | # get transform params 62 | if self.color_jitter: 63 | color_transform = FixedColorJitter(brightness=0.4, contrast=0.4, 64 | saturation=0.4, hue=0.4) 65 | crop_i, crop_j, th, tw = 0, 0, 0, 0 66 | h_flip = True if random.random() < 0.5 else False 67 | v_flip = True if random.random() < 0.5 else False 68 | for i in range(self.frame_num): 69 | path, video_index = self.imgs[index + i] 70 | img = Image.open(BytesIO(self.img_bytes[index + i])) 71 | img = img.convert('RGB') 72 | annotation = Image.open(BytesIO(self.annotation_bytes[index + i])) # (W, H), 'P' mode 73 | annotation = annotation.convert('RGB') 74 | 75 | if h_flip: 76 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 77 | annotation = annotation.transpose(Image.FLIP_LEFT_RIGHT) 78 | if v_flip: 79 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 80 | annotation = annotation.transpose(Image.FLIP_TOP_BOTTOM) 81 | if i == 0: 82 | W, H = img.size 83 | crop_i, crop_j, th, tw = get_crop_params((W, H), self.cropping) 84 | 85 | # all images and annotations should be cropped in the same way 86 | img_cropped = crop(img, crop_i, crop_j, th, tw) 87 | annotation_cropped = crop(annotation, crop_i, crop_j, th, tw) 88 | if self.color_jitter: 89 | img_cropped = color_transform(img_cropped) 90 | 91 | img_cropped = self.rgb_normalize(img_cropped).numpy() 92 | annotation_cropped =
np.asarray(annotation_cropped).transpose((2, 0, 1)) 93 | img_output.append(img_cropped) 94 | annotation_output.append(annotation_cropped) 95 | 96 | img_output = torch.from_numpy(np.asarray(img_output)).float() 97 | annotation_output = torch.from_numpy(np.asarray(annotation_output)).float() 98 | return img_output, annotation_output, video_index 99 | 100 | def __is_from_same_video__(self, index): 101 | _, indexStart = self.imgs[index] 102 | _, indexEnd = self.imgs[index + self.frame_num - 1] 103 | return indexStart == indexEnd 104 | 105 | 106 | class DavisInference(datasets.ImageFolder): 107 | """ 108 | Load one frame at a time. 109 | Used for inference. 110 | """ 111 | 112 | def __init__(self, 113 | root, 114 | transform=None, 115 | target_transform=None): 116 | super(DavisInference, self).__init__(root, 117 | transform=transform, 118 | target_transform=target_transform) 119 | self.img_bytes = [] 120 | self.rgb_normalize = transforms.Compose([transforms.ToTensor(), 121 | transforms.Normalize( 122 | mean=[0.485, 0.456, 0.406], 123 | std=[0.229, 0.224, 0.225])]) 124 | for path, _ in self.imgs: 125 | with open(path, 'rb') as f: 126 | self.img_bytes.append(f.read()) 127 | print("Tracking folder JPEGImages loaded: ", len(self.img_bytes)) 128 | 129 | def __getitem__(self, index): 130 | path, video_index = self.imgs[index] 131 | img = Image.open(BytesIO(self.img_bytes[index])) 132 | img = img.convert('RGB') 133 | 134 | img_original = np.asarray(img) 135 | 136 | output = self.rgb_normalize(img_original) 137 | return output, video_index, img_original 138 | 139 | def __len__(self): 140 | return len(self.imgs) 141 | -------------------------------------------------------------------------------- /dataset/split_trainval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('-i', type=str, required=True, 7 | help='path to DAVIS 
set') 8 | args = parser.parse_args() 9 | 10 | trainval_path = args.i 11 | 12 | train_list = os.path.join(trainval_path, 'ImageSets/2017/train.txt') 13 | train_list = open(train_list).readlines() 14 | for i in range(len(train_list)): 15 | train_list[i] = train_list[i].strip() 16 | 17 | val_list = os.path.join(trainval_path, 'ImageSets/2017/val.txt') 18 | val_list = open(val_list).readlines() 19 | for i in range(len(val_list)): 20 | val_list[i] = val_list[i].strip() 21 | 22 | full_img_path = os.path.join(trainval_path, 'JPEGImages/480p') 23 | full_annotation_path = os.path.join(trainval_path, 'Annotations/480p/') 24 | full_video_list = os.listdir(full_img_path) 25 | 26 | train_img = './DAVIS_train/JPEGImages/480p' 27 | train_annotations = './DAVIS_train/Annotations/480p' 28 | val_img = './DAVIS_val/JPEGImages/480p' 29 | val_annotations = './DAVIS_val/Annotations/480p' 30 | l = [train_img, train_annotations, val_img, val_annotations] 31 | for p in l: 32 | if not os.path.exists(p): 33 | os.makedirs(p) 34 | 35 | for video in full_video_list: 36 | src1 = os.path.join(full_annotation_path, video) 37 | src2 = os.path.join(full_img_path, video) 38 | if video in train_list: 39 | dest1 = train_annotations 40 | dest2 = train_img 41 | else: 42 | dest1 = val_annotations 43 | dest2 = val_img 44 | if not os.path.exists(dest1): 45 | os.makedirs(dest1) 46 | if not os.path.exists(dest2): 47 | os.makedirs(dest2) 48 | shutil.move(src1, dest1) 49 | shutil.move(src2, dest2) 50 | print('success') 51 | -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numbers 4 | 5 | from torchvision import transforms 6 | 7 | 8 | def get_crop_params(img_size, output_size): 9 | """ input: 10 | - img_size : tuple of (w, h), original image size 11 | - output_size: desired output size, one int or tuple 12 | return: 13 | - i 14 | - j 
 15 | - th 16 | - tw 17 | """ 18 | w, h = img_size 19 | if isinstance(output_size, numbers.Number): 20 | th, tw = (output_size, output_size) 21 | else: 22 | th, tw = output_size 23 | if w == tw and h == th: 24 | return 0, 0, h, w 25 | 26 | i = random.randint(0, h - th) 27 | j = random.randint(0, w - tw) 28 | return i, j, th, tw 29 | 30 | 31 | def crop(img, i, j, h, w): 32 | """Crop the given PIL Image. 33 | Args: 34 | img (PIL Image): Image to be cropped. 35 | i: Upper pixel coordinate. 36 | j: Left pixel coordinate. 37 | h: Height of the cropped image. 38 | w: Width of the cropped image. 39 | Returns: 40 | PIL Image: Cropped image. 41 | """ 42 | return img.crop((j, i, j + w, i + h)) 43 | 44 | 45 | class FixedColorJitter(transforms.ColorJitter): 46 | """ 47 | Same ColorJitter class, only fixes the transform params once instantiated. Relies on torchvision < 0.8, where `ColorJitter.get_params` returns a callable transform; newer versions return a tuple of sampled factors instead. 48 | """ 49 | 50 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 51 | super(FixedColorJitter, self).__init__(brightness, contrast, saturation, hue) 52 | self.transform = self.get_params(self.brightness, self.contrast, 53 | self.saturation, self.hue) 54 | 55 | def __call__(self, img): 56 | return self.transform(img) 57 | 58 | 59 | def make_dataset(dir, class_to_idx): 60 | images = [] 61 | dir = os.path.expanduser(dir) 62 | for target in sorted(class_to_idx.keys()): 63 | d = os.path.join(dir, target) 64 | if not os.path.isdir(d): 65 | continue 66 | 67 | for root, _, fnames in sorted(os.walk(d)): 68 | for fname in sorted(fnames): 69 | path = os.path.join(root, fname) 70 | item = (path, class_to_idx[target]) 71 | images.append(item) 72 | return images 73 | -------------------------------------------------------------------------------- /figure/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/transductive-vos.pytorch/1312e43b59bb203e2575a339643a6d6f5c292a16/figure/fig1.png
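The core of the transductive inference in `lib/predict.py` (`predict()` together with `get_spatial_weight()`) is label propagation: each target pixel takes a softmax-weighted average of reference-pixel labels, where the weights come from feature dot products scaled by `temperature` and are then multiplied by a Gaussian spatial prior exp(-||p - q||^2 / sigma^2). The NumPy sketch below restates that computation so it can be tried without a GPU. It is a simplification, not the repo's code: it uses a single sigma instead of the dense/sparse pair (`sigma1`, `sigma2`), skips the reference-frame sampling done by `sample_frames()`, and the names `spatial_weight` and `propagate` are hypothetical.

```python
import numpy as np


def spatial_weight(h, w, sigma):
    """Gaussian spatial prior exp(-||p - q||^2 / sigma^2) for all pixel pairs.

    Follows the formula in get_spatial_weight() on an h x w feature grid;
    returns an (h*w, h*w) array.
    """
    rows, cols = np.divmod(np.arange(h * w), w)            # integer (row, col) per pixel
    coords = np.stack([rows, cols], axis=1).astype(float)  # (h*w, 2)
    sq_dist = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dist / sigma ** 2)


def propagate(ref_feats, target_feats, ref_labels, temperature=1.0, sigma=8.0):
    """Propagate reference-pixel labels to every pixel of the target frame.

    ref_feats:    (num_ref, C, h, w) features of the reference frames
    target_feats: (C, h, w) features of the target frame
    ref_labels:   (d, num_ref, h*w) one-hot (or soft) labels of reference pixels
    Returns:      (d, h*w) class scores for the target frame.
    """
    num_ref, c, h, w = ref_feats.shape
    ref = ref_feats.transpose(0, 2, 3, 1).reshape(-1, c)   # (num_ref*h*w, C)
    tgt = target_feats.reshape(c, -1)                      # (C, h*w)
    sim = (ref @ tgt) * temperature                        # (num_ref*h*w, h*w)
    # Softmax over all reference pixels (axis 0), numerically stabilised.
    sim = np.exp(sim - sim.max(axis=0, keepdims=True))
    sim /= sim.sum(axis=0, keepdims=True)
    # Multiply by the spatial prior, broadcast over reference frames
    # (as in predict(), there is no renormalisation afterwards).
    sim = sim.reshape(num_ref, h * w, h * w) * spatial_weight(h, w, sigma)
    labels = ref_labels.reshape(ref_labels.shape[0], -1)   # (d, num_ref*h*w)
    return labels @ sim.reshape(-1, h * w)
```

As in `inference.py`, taking the argmax over the first axis of the result gives a class per target pixel; the repo additionally upsamples the scores to full resolution before that argmax.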
-------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | import dataset 10 | import modeling 11 | 12 | from lib.utils import AverageMeter, save_prediction, idx2onehot 13 | from lib.predict import predict, prepare_first_frame 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--ref_num', '-n', type=int, default=9, 17 | help='number of reference frames for inference') 18 | parser.add_argument('--dataset', '-ds', type=str, default='davis', 19 | help='name of dataset') 20 | parser.add_argument('--data', type=str, 21 | help='path to inference dataset') 22 | parser.add_argument('--resume', '-r', type=str, 23 | help='path to the resumed checkpoint') 24 | parser.add_argument('--model', type=str, default='resnet50', 25 | help='network architecture, resnet18, resnet50 or resnet101') 26 | parser.add_argument('--temperature', '-t', type=float, default=1.0, 27 | help='temperature parameter') 28 | parser.add_argument('--range', type=int, default=40, 29 | help='range of frames for inference') 30 | parser.add_argument('--sigma1', type=float, default=8.0, 31 | help='smaller sigma in the motion model for dense spatial weight') 32 | parser.add_argument('--sigma2', type=float, default=21.0, 33 | help='bigger sigma in the motion model for sparse spatial weight') 34 | parser.add_argument('--save', '-s', type=str, 35 | help='path to save predictions') 36 | 37 | device = torch.device("cuda") 38 | 39 | 40 | def main(): 41 | global args 42 | args = parser.parse_args() 43 | 44 | model = modeling.VOSNet(model=args.model) 45 | model = nn.DataParallel(model) 46 | model.cuda() 47 | 48 | if args.resume: 49 | if os.path.isfile(args.resume): 50 | print("=> loading checkpoint '{}'".format(args.resume)) 51 | checkpoint = 
torch.load(args.resume) 52 | model.load_state_dict(checkpoint['state_dict']) 53 | print("=> loaded checkpoint '{}'" 54 | .format(args.resume)) 55 | else: 56 | print("=> no checkpoint found at '{}'".format(args.resume)) 57 | model.eval() 58 | 59 | data_dir = os.path.join(args.data, 'DAVIS_val/JPEGImages/480p') 60 | inference_dataset = dataset.DavisInference(data_dir) 61 | inference_loader = torch.utils.data.DataLoader(inference_dataset, 62 | batch_size=1, 63 | shuffle=False, 64 | num_workers=8) 65 | inference(inference_loader, model, args) 66 | 67 | 68 | def inference(inference_loader, model, args): 69 | global pred_visualize, palette, d, feats_history, label_history, weight_dense, weight_sparse 70 | batch_time = AverageMeter() 71 | annotation_dir = os.path.join(args.data, 'DAVIS_val/Annotations/480p') 72 | annotation_list = sorted(os.listdir(annotation_dir)) 73 | 74 | last_video = 0 75 | frame_idx = 0 76 | with torch.no_grad(): 77 | for i, (input, curr_video, img_original) in enumerate(inference_loader): 78 | if curr_video != last_video: 79 | # save prediction 80 | pred_visualize = pred_visualize.cpu().numpy() 81 | for f in range(1, frame_idx): 82 | save_path = args.save 83 | save_name = str(f).zfill(5) 84 | video_name = annotation_list[last_video] 85 | save_prediction(np.asarray(pred_visualize[f - 1], dtype=np.int32), 86 | palette, save_path, save_name, video_name) 87 | 88 | frame_idx = 0 89 | print("End of video %d. Processing a new annotation..." 
% (last_video + 1)) 90 | if frame_idx == 0: 91 | input = input.to(device) 92 | with torch.no_grad(): 93 | feats_history = model(input) 94 | label_history, d, palette, weight_dense, weight_sparse = prepare_first_frame(curr_video, 95 | args.save, 96 | annotation_dir, 97 | args.sigma1, 98 | args.sigma2) 99 | frame_idx += 1 100 | last_video = curr_video 101 | continue 102 | (batch_size, num_channels, H, W) = input.shape 103 | input = input.to(device) 104 | 105 | start = time.time() 106 | features = model(input) 107 | (_, feature_dim, H_d, W_d) = features.shape 108 | prediction = predict(feats_history, 109 | features[0], 110 | label_history, 111 | weight_dense, 112 | weight_sparse, 113 | frame_idx, 114 | args 115 | ) 116 | # Store all frames' features 117 | new_label = idx2onehot(torch.argmax(prediction, 0), d).unsqueeze(1) 118 | label_history = torch.cat((label_history, new_label), 1) 119 | feats_history = torch.cat((feats_history, features), 0) 120 | 121 | last_video = curr_video 122 | frame_idx += 1 123 | 124 | # 1. upsample, 2. 
argmax 125 | prediction = torch.nn.functional.interpolate(prediction.view(1, d, H_d, W_d), 126 | size=(H, W), 127 | mode='bilinear', 128 | align_corners=False) 129 | prediction = torch.argmax(prediction, 1) # (1, H, W) 130 | 131 | if frame_idx == 2: 132 | pred_visualize = prediction 133 | else: 134 | pred_visualize = torch.cat((pred_visualize, prediction), 0) 135 | 136 | batch_time.update(time.time() - start) 137 | 138 | if i % 10 == 0: 139 | print('Validate: [{0}/{1}]\t' 140 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format( 141 | i, len(inference_loader), batch_time=batch_time)) 142 | # save last video's prediction 143 | pred_visualize = pred_visualize.cpu().numpy() 144 | for f in range(1, frame_idx): 145 | save_path = args.save 146 | save_name = str(f).zfill(5) 147 | video_name = annotation_list[last_video] 148 | save_prediction(np.asarray(pred_visualize[f - 1], dtype=np.int32), 149 | palette, save_path, save_name, video_name) 150 | print('Finished inference.') 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/transductive-vos.pytorch/1312e43b59bb203e2575a339643a6d6f5c292a16/lib/__init__.py -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def batch_get_similarity_matrix(ref, target): 6 | """ 7 | Get pixel-level similarity matrix. 
8 | :param ref: (batchSize, num_ref, feature_dim, H, W) 9 | :param target: (batchSize, feature_dim, H, W) 10 | :return: (batchSize, num_ref*H*W, H*W) 11 | """ 12 | (batchSize, num_ref, feature_dim, H, W) = ref.shape 13 | ref = ref.permute(0, 1, 3, 4, 2).reshape(batchSize, -1, feature_dim) 14 | target = target.reshape(batchSize, feature_dim, -1) 15 | T = ref.bmm(target) 16 | return T 17 | 18 | 19 | def batch_global_predict(global_similarity, ref_label): 20 | """ 21 | Get global prediction. 22 | :param global_similarity: (batchSize, num_ref*H*W, H*W) 23 | :param ref_label: onehot form (batchSize, num_ref, d, H, W) 24 | :return: (batchSize, d, H, W) 25 | """ 26 | (batchSize, num_ref, d, H, W) = ref_label.shape 27 | ref_label = ref_label.transpose(1, 2).reshape(batchSize, d, -1) 28 | return ref_label.bmm(global_similarity).reshape(batchSize, d, H, W) 29 | 30 | 31 | class CrossEntropy(nn.Module): 32 | def __init__(self, temperature=1.0): 33 | super(CrossEntropy, self).__init__() 34 | self.temperature = temperature 35 | self.nllloss = nn.NLLLoss() 36 | 37 | def forward(self, ref, target, ref_label, target_label): 38 | """ 39 | let Nt = num of target pixels, Nr = num of ref pixels 40 | :param ref: (batchSize, num_ref, feature_dim, H, W) 41 | :param target: (batchSize, feature_dim, H, W) 42 | :param ref_label: label for reference pixels 43 | (batchSize, num_ref, d, H, W) 44 | :param target_label: label for target pixels (ground truth) 45 | (batchSize, H, W) 46 | """ 47 | global_similarity = batch_get_similarity_matrix(ref, target) 48 | 49 | global_similarity = global_similarity * self.temperature 50 | 51 | global_similarity = global_similarity.softmax(dim=1) 52 | 53 | prediction = batch_global_predict(global_similarity, ref_label) 54 | prediction = torch.log(prediction + 1e-14) 55 | loss = self.nllloss(prediction, target_label) 56 | 57 | return loss 58 | -------------------------------------------------------------------------------- /lib/predict.py: 
 -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from PIL import Image 6 | 7 | from .utils import idx2onehot 8 | 9 | 10 | def predict(ref, 11 | target, 12 | ref_label, 13 | weight_dense, 14 | weight_sparse, 15 | frame_idx, 16 | args): 17 | """ 18 | Predict target-frame labels by propagating reference labels through feature similarity. 19 | :param ref: (N, feature_dim, H, W) 20 | :param target: (feature_dim, H, W) 21 | :param ref_label: (d, N, H*W) 22 | :param weight_dense: (H*W, H*W) 23 | :param weight_sparse: (H*W, H*W) 24 | :param frame_idx: index of the target frame in the current video 25 | :param args: parsed arguments (uses range, ref_num and temperature) 26 | :return: (d, H*W) 27 | """ 28 | # sample frames from history features 29 | d = ref_label.shape[0] 30 | sample_idx = sample_frames(frame_idx, args.range, args.ref_num) 31 | ref_selected = ref.index_select(0, sample_idx) 32 | ref_label_selected = ref_label.index_select(1, sample_idx).view(d, -1) 33 | 34 | # get similarity matrix 35 | (num_ref, feature_dim, H, W) = ref_selected.shape 36 | ref_selected = ref_selected.permute(0, 2, 3, 1).reshape(-1, feature_dim) 37 | target = target.reshape(feature_dim, -1) 38 | global_similarity = ref_selected.mm(target) 39 | 40 | # temperature step 41 | global_similarity *= args.temperature 42 | 43 | # softmax 44 | global_similarity = global_similarity.softmax(dim=0) 45 | 46 | # spatial weight and motion model 47 | global_similarity = global_similarity.contiguous().view(num_ref, H * W, H * W) 48 | if frame_idx > 15: 49 | continuous_frame = 4 50 | # interval frames 51 | global_similarity[:-continuous_frame] *= weight_sparse 52 | # continuous frames 53 | global_similarity[-continuous_frame:] *= weight_dense 54 | else: 55 | global_similarity = global_similarity.mul(weight_dense) 56 | global_similarity = global_similarity.view(-1, H * W) 57 | 58 | # get prediction 59 | prediction = ref_label_selected.mm(global_similarity) 60 | return prediction 61 | 62 | 63 | def sample_frames(frame_idx, 64 | take_range, 65 | num_refs): 66 | if frame_idx <=
 num_refs: 67 | sample_idx = list(range(frame_idx)) 68 | else: 69 | dense_num = 4 - 1 70 | sparse_num = num_refs - dense_num 71 | target_idx = frame_idx 72 | ref_end = target_idx - dense_num - 1 73 | ref_start = max(ref_end - take_range, 0) 74 | sample_idx = np.linspace(ref_start, ref_end, sparse_num).astype(int).tolist() 75 | for j in range(dense_num): 76 | sample_idx.append(target_idx - dense_num + j) 77 | 78 | return torch.Tensor(sample_idx).long().cuda() 79 | 80 | 81 | def prepare_first_frame(curr_video, 82 | save_prediction, 83 | annotation_dir, 84 | sigma1=8, 85 | sigma2=21): 86 | annotation_list = sorted(os.listdir(annotation_dir)) 87 | first_annotation = Image.open(os.path.join(annotation_dir, annotation_list[curr_video], '00000.png')) 88 | (H, W) = np.asarray(first_annotation).shape 89 | H_d = int(np.ceil(H / 8)) 90 | W_d = int(np.ceil(W / 8)) 91 | palette = first_annotation.getpalette() 92 | label = np.asarray(first_annotation) 93 | d = np.max(label) + 1 94 | label = torch.Tensor(label).long().cuda() # (H, W) 95 | label_1hot = idx2onehot(label.view(-1), d).reshape(1, d, H, W) 96 | label_1hot = torch.nn.functional.interpolate(label_1hot, 97 | size=(H_d, W_d), 98 | mode='bilinear', 99 | align_corners=False) 100 | label_1hot = label_1hot.reshape(d, -1).unsqueeze(1) 101 | weight_dense = get_spatial_weight((H_d, W_d), sigma1) 102 | weight_sparse = get_spatial_weight((H_d, W_d), sigma2) 103 | 104 | if save_prediction is not None: 105 | if not os.path.exists(save_prediction): 106 | os.makedirs(save_prediction) 107 | save_path = os.path.join(save_prediction, annotation_list[curr_video]) 108 | if not os.path.exists(save_path): 109 | os.makedirs(save_path) 110 | first_annotation.save(os.path.join(save_path, '00000.png')) 111 | 112 | return label_1hot, d, palette, weight_dense, weight_sparse 113 | 114 | 115 | def get_spatial_weight(shape, sigma): 116 | """ 117 | Get soft spatial weights for similarity matrix.
 118 | :param shape: (H, W) 119 | :param sigma: width of the Gaussian, w = exp(-dist^2 / sigma^2) 120 | :return: (H*W, H*W) 121 | """ 122 | (H, W) = shape 123 | 124 | index_matrix = torch.arange(H * W, dtype=torch.long).reshape(H * W, 1).cuda() 125 | index_matrix = torch.cat((index_matrix // W, index_matrix % W), -1) # (H*W, 2), integer (row, col) 126 | d = index_matrix - index_matrix.unsqueeze(1) # (H*W, H*W, 2) 127 | d = d.float().pow(2).sum(-1) # (H*W, H*W) 128 | w = (- d / sigma ** 2).exp() 129 | 130 | return w 131 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import functools 4 | import logging 5 | 6 | import torch 7 | 8 | from PIL import Image 9 | 10 | # so that calling setup_logger multiple times won't add many handlers 11 | @functools.lru_cache() 12 | def setup_logger( 13 | output=None, distributed_rank=0, *, color=True, name=None, abbrev_name=None 14 | ): 15 | """ 16 | Initialize the logger and set its verbosity level to "DEBUG". 17 | 18 | Args: 19 | output (str): a file name or a directory to save log. If None, will not save log file. 20 | If ends with ".txt" or ".log", assumed to be a file name. 21 | Otherwise, logs will be saved to `output/log.txt`.
22 | name (str): the root module name of this logger 23 | 24 | Returns: 25 | logging.Logger: a logger 26 | """ 27 | logger = logging.getLogger(name) 28 | logger.setLevel(logging.DEBUG) 29 | logger.propagate = False 30 | 31 | if abbrev_name is None: 32 | abbrev_name = name 33 | 34 | plain_formatter = logging.Formatter( 35 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 36 | ) 37 | # stdout logging: master only 38 | if distributed_rank == 0: 39 | ch = logging.StreamHandler(stream=sys.stdout) 40 | ch.setLevel(logging.DEBUG) 41 | formatter = plain_formatter 42 | ch.setFormatter(formatter) 43 | logger.addHandler(ch) 44 | 45 | # file logging: all workers 46 | if output is not None: 47 | if output.endswith(".txt") or output.endswith(".log"): 48 | filename = output 49 | else: 50 | filename = os.path.join(output, "log.txt") 51 | if distributed_rank > 0: 52 | filename = filename + f".rank{distributed_rank}" 53 | os.makedirs(os.path.dirname(filename), exist_ok=True) 54 | 55 | fh = logging.StreamHandler(_cached_log_stream(filename)) 56 | fh.setLevel(logging.DEBUG) 57 | fh.setFormatter(plain_formatter) 58 | logger.addHandler(fh) 59 | 60 | return logger 61 | 62 | # cache the opened file object, so that different calls to `setup_logger` 63 | # with the same file name can safely write to the same file. 
64 | @functools.lru_cache(maxsize=None) 65 | def _cached_log_stream(filename): 66 | return open(filename, "a") 67 | class AverageMeter(object): 68 | def __init__(self): 69 | self.val = 0 70 | self.avg = 0 71 | self.sum = 0 72 | self.count = 0 73 | 74 | def reset(self): 75 | self.val = 0 76 | self.avg = 0 77 | self.sum = 0 78 | self.count = 0 79 | 80 | def update(self, val, n=1): 81 | self.val = val 82 | self.sum += val * n 83 | self.count += n 84 | self.avg = self.sum / self.count 85 | 86 | 87 | def save_prediction(prediction, palette, save_path, save_name, video_name): 88 | img = Image.fromarray(prediction) 89 | img = img.convert('L') 90 | img.putpalette(palette) 91 | img = img.convert('P') 92 | video_path = os.path.join(save_path, video_name) 93 | if not os.path.exists(video_path): 94 | os.makedirs(video_path) 95 | img.save('{}/{}.png'.format(video_path, save_name)) 96 | 97 | 98 | def rgb2class(img, centroids): 99 | """ 100 | Convert an RGB annotation image into per-pixel class indices (nearest centroid). 101 | :param img: (batch_size, C, H, W) 102 | :param centroids: (num_classes, C), the color of each class 103 | :return: (batch_size, H, W) 104 | """ 105 | (batch_size, C, H, W) = img.shape 106 | img = img.permute(0, 2, 3, 1).reshape(-1, C) 107 | class_idx = torch.argmin(torch.sqrt(torch.sum((img.unsqueeze(1) - centroids) ** 2, 2)), 1) 108 | class_idx = torch.reshape(class_idx, (batch_size, H, W)) 109 | return class_idx 110 | 111 | 112 | def idx2onehot(idx, d): 113 | """ input: 114 | - idx: (H*W) 115 | return: 116 | - one_hot: (d, H*W) 117 | """ 118 | n = idx.shape[0] 119 | one_hot = torch.zeros(d, n, device=idx.device).scatter_(0, idx.view(1, -1), 1) 120 | 121 | return one_hot 122 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributed as dist 8 | import torch.utils.data.distributed 9 | from
torch.nn.parallel import DistributedDataParallel 10 | import torch.backends.cudnn as cudnn 11 | import numpy as np 12 | 13 | import dataset 14 | import modeling 15 | 16 | from lib.loss import CrossEntropy 17 | from lib.utils import AverageMeter, rgb2class, setup_logger 18 | 19 | SCALE = 0.125 20 | 21 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7' # machine-specific: select GPUs via CUDA_VISIBLE_DEVICES in the shell instead 22 | def parse_options(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--frame_num', '-n', type=int, default=10, 25 | help='number of frames sampled per training clip') 26 | parser.add_argument('--dataset', '-ds', type=str, default='davis', 27 | help='name of dataset') 28 | parser.add_argument('--data', type=str, 29 | help='path to dataset') 30 | parser.add_argument('--resume', '-r', type=str, 31 | help='path to a checkpoint to resume from') 32 | parser.add_argument('--save_model', '-m', type=str, default='./checkpoints', 33 | help='directory to save checkpoints') 34 | parser.add_argument('--epochs', type=int, default=240, 35 | help='number of epochs') 36 | parser.add_argument('--model', type=str, default='resnet50', 37 | help='network architecture, resnet18, resnet50 or resnet101') 38 | parser.add_argument('--temperature', '-t', type=float, default=1.0, 39 | help='temperature parameter') 40 | parser.add_argument('--bs', type=int, default=16, 41 | help='total batch size across all GPUs') 42 | parser.add_argument('--lr', type=float, default=0.02, 43 | help='initial learning rate') 44 | parser.add_argument('--wd', type=float, default=3e-4, 45 | help='weight decay') 46 | parser.add_argument('--iter_size', type=int, default=1, 47 | help='number of iterations to accumulate gradients over before each optimizer step') 48 | parser.add_argument('--cj', action='store_true', 49 | help='use color jitter') 50 | parser.add_argument('--local_rank', type=int, default=0, 51 | help='local GPU rank (set by torch.distributed.launch)') 52 | 53 | args = parser.parse_args() 54 | 55 | return args 56 | 57 | def main(args): 58 | 59 | # model = modeling.VOSNet(model=args.model).cuda() 60 | # model =
torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 61 | model = modeling.VOSNet(model=args.model, sync_bn=True).cuda() 62 | model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False) 63 | 64 | criterion = CrossEntropy(temperature=args.temperature).cuda() 65 | 66 | optimizer = torch.optim.SGD(model.parameters(), 67 | lr=args.lr, 68 | momentum=0.9, 69 | nesterov=True, 70 | weight_decay=args.wd) 71 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 72 | args.epochs, 73 | eta_min=4e-5) 74 | if args.dataset == 'davis': 75 | train_dataset = dataset.DavisTrain(os.path.join(args.data, 'DAVIS_train/JPEGImages/480p'), 76 | os.path.join(args.data, 'DAVIS_train/Annotations/480p'), 77 | frame_num=args.frame_num, 78 | color_jitter=args.cj) 79 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 80 | train_loader = torch.utils.data.DataLoader(train_dataset, 81 | batch_size=args.bs // dist.get_world_size(), 82 | shuffle=False, 83 | sampler = train_sampler, 84 | pin_memory = True, 85 | num_workers=8 // dist.get_world_size(), 86 | drop_last=True) 87 | val_dataset = dataset.DavisTrain(os.path.join(args.data, 'DAVIS_val/JPEGImages/480p'), 88 | os.path.join(args.data, 'DAVIS_val/Annotations/480p'), 89 | frame_num=args.frame_num, 90 | color_jitter=args.cj) 91 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 92 | val_loader = torch.utils.data.DataLoader(val_dataset, 93 | batch_size=args.bs // dist.get_world_size(), 94 | shuffle=False, 95 | sampler = val_sampler, 96 | pin_memory = True, 97 | num_workers=8 // dist.get_world_size(), 98 | drop_last=True) 99 | else: 100 | raise NotImplementedError 101 | start_epoch = 0 102 | if args.resume: 103 | if os.path.isfile(args.resume): 104 | logger.info("=> loading checkpoint '{}'".format(args.resume)) 105 | checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.local_rank)) 106 | start_epoch = checkpoint['epoch'] 107 | model.load_state_dict(checkpoint['state_dict'])
108 | optimizer.load_state_dict(checkpoint['optimizer']) 109 | scheduler.load_state_dict(checkpoint['scheduler']) 110 | logger.info("=> loaded checkpoint '{}' (epoch {})" 111 | .format(args.resume, checkpoint['epoch'])) 112 | else: 113 | logger.info("=> no checkpoint found at '{}'".format(args.resume)) 114 | 115 | for epoch in range(start_epoch, args.epochs): 116 | train_sampler.set_epoch(epoch) # reshuffle the distributed sampler every epoch 117 | train_loss = train(train_loader, model, criterion, optimizer, epoch, args) 118 | 119 | with torch.no_grad(): 120 | val_loss = validate(val_loader, model, criterion, args) 121 | 122 | scheduler.step() 123 | 124 | if dist.get_rank() == 0: 125 | os.makedirs(args.save_model, exist_ok=True) 126 | checkpoint_name = 'checkpoint-epoch-{}.pth.tar'.format(epoch) 127 | save_path = os.path.join(args.save_model, checkpoint_name) 128 | torch.save({ 129 | 'epoch': epoch + 1, 130 | 'state_dict': model.state_dict(), 131 | 'optimizer': optimizer.state_dict(), 132 | 'scheduler': scheduler.state_dict(), 133 | }, save_path) 134 | 135 | 136 | def train(train_loader, model, criterion, optimizer, epoch, args): 137 | logger.info('Starting training epoch {}'.format(epoch)) 138 | 139 | centroids = np.load("./dataset/annotation_centroids.npy") 140 | centroids = torch.Tensor(centroids).float().cuda() 141 | 142 | batch_time = AverageMeter() 143 | data_time = AverageMeter() 144 | losses = AverageMeter() 145 | 146 | model.train() 147 | 148 | end = time.time() 149 | for i, (img_input, annotation_input, _) in enumerate(train_loader): 150 | data_time.update(time.time() - end) 151 | 152 | (batch_size, num_frames, num_channels, H, W) = img_input.shape 153 | annotation_input = annotation_input.reshape(-1, 3, H, W).cuda() 154 | annotation_input_downsample = torch.nn.functional.interpolate(annotation_input, 155 | scale_factor=SCALE, 156 | mode='bilinear', 157 | align_corners=False) 158 | H_d = annotation_input_downsample.shape[-2] 159 | W_d = annotation_input_downsample.shape[-1] 160 | 161 | annotation_input =
rgb2class(annotation_input_downsample, centroids) 162 | annotation_input = annotation_input.reshape(batch_size, num_frames, H_d, W_d) 163 | 164 | img_input = img_input.reshape(-1, num_channels, H, W).cuda() 165 | 166 | features = model(img_input) 167 | feature_dim = features.shape[1] 168 | features = features.reshape(batch_size, num_frames, feature_dim, H_d, W_d) 169 | 170 | ref = features[:, 0:num_frames - 1, :, :, :] 171 | target = features[:, -1, :, :, :] 172 | ref_label = annotation_input[:, 0:num_frames - 1, :, :] 173 | target_label = annotation_input[:, -1, :, :] 174 | 175 | ref_label = torch.zeros(batch_size, num_frames - 1, centroids.shape[0], H_d, W_d).cuda().scatter_( 176 | 2, ref_label.unsqueeze(2), 1) 177 | 178 | loss = criterion(ref, target, ref_label, target_label) / args.iter_size 179 | loss.backward() 180 | 181 | losses.update(loss.item(), batch_size) 182 | 183 | if (i + 1) % args.iter_size == 0: 184 | optimizer.step() 185 | optimizer.zero_grad() 186 | 187 | batch_time.update(time.time() - end) 188 | end = time.time() 189 | 190 | if i % 25 == 0: 191 | logger.info('Epoch: [{0}][{1}/{2}]\t' 192 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 193 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 194 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 195 | epoch, i, len(train_loader), batch_time=batch_time, 196 | data_time=data_time, loss=losses)) 197 | 198 | logger.info('Finished training epoch {}'.format(epoch)) 199 | return losses.avg 200 | 201 | 202 | def validate(val_loader, model, criterion, args): 203 | logger.info('starting validation...') 204 | 205 | centroids = np.load("./dataset/annotation_centroids.npy") 206 | centroids = torch.Tensor(centroids).float().cuda() 207 | 208 | batch_time = AverageMeter() 209 | data_time = AverageMeter() 210 | losses = AverageMeter() 211 | 212 | model.eval() 213 | 214 | end = time.time() 215 | for i, (img_input, annotation_input, _) in enumerate(val_loader): 216 | 217 | data_time.update(time.time() - 
end) 218 | 219 | (batch_size, num_frames, num_channels, H, W) = img_input.shape 220 | 221 | annotation_input = annotation_input.reshape(-1, 3, H, W).cuda() 222 | annotation_input_downsample = torch.nn.functional.interpolate(annotation_input, 223 | scale_factor=SCALE, 224 | mode='bilinear', 225 | align_corners=False) 226 | H_d = annotation_input_downsample.shape[-2] 227 | W_d = annotation_input_downsample.shape[-1] 228 | 229 | annotation_input = rgb2class(annotation_input_downsample, centroids) 230 | annotation_input = annotation_input.reshape(batch_size, num_frames, H_d, W_d) 231 | 232 | img_input = img_input.reshape(-1, num_channels, H, W).cuda() 233 | 234 | features = model(img_input) 235 | feature_dim = features.shape[1] 236 | features = features.reshape(batch_size, num_frames, feature_dim, H_d, W_d) 237 | 238 | ref = features[:, 0:num_frames - 1, :, :, :] 239 | target = features[:, -1, :, :, :] 240 | ref_label = annotation_input[:, 0:num_frames - 1, :, :] 241 | target_label = annotation_input[:, -1, :, :] 242 | 243 | ref_label = torch.zeros(batch_size, num_frames - 1, centroids.shape[0], H_d, W_d).cuda().scatter_( 244 | 2, ref_label.unsqueeze(2), 1) 245 | 246 | loss = criterion(ref, target, ref_label, target_label) / args.iter_size 247 | 248 | losses.update(loss.item(), batch_size) 249 | 250 | batch_time.update(time.time() - end) 251 | end = time.time() 252 | 253 | if i % 25 == 0: 254 | logger.info('Validate: [{0}/{1}]\t' 255 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 256 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 257 | i, len(val_loader), batch_time=batch_time, loss=losses)) 258 | 259 | logger.info('Finished validation') 260 | return losses.avg 261 | 262 | 263 | if __name__ == '__main__': 264 | 265 | opt = parse_options() 266 | 267 | torch.cuda.set_device(opt.local_rank) 268 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 269 | cudnn.benchmark = True 270 | 271 | logger = setup_logger(output=opt.save_model, 
distributed_rank=dist.get_rank(), name='vos') 272 | 273 | main(opt) 274 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .network import VOSNet 2 | -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/transductive-vos.pytorch/1312e43b59bb203e2575a339643a6d6f5c292a16/modeling/backbone/__init__.py -------------------------------------------------------------------------------- /modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, BatchNorm=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = BatchNorm(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | 
self.bn2 = BatchNorm(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None, BatchNorm=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = BatchNorm(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = BatchNorm(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = BatchNorm(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, block, layers, BatchNorm, num_classes=1000): 97 | self.inplanes = 64 98 | super(ResNet, self).__init__() 99 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 100 | bias=False) 101 | self.bn1 = BatchNorm(64) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 104 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm) 105 | self.layer2 = 
self._make_layer(block, 128, layers[1], BatchNorm, stride=2) 106 | self.layer3 = self._make_layer(block, 256, layers[2], BatchNorm, stride=1) 107 | self.layer4 = self._make_layer(block, 256, layers[3], BatchNorm, stride=1) 108 | self.avgpool = nn.AdaptiveAvgPool2d(1) # layer3/layer4 keep stride 1, so the final map is larger than 7x7 109 | self.fc = nn.Linear(256 * block.expansion, num_classes) # layer4 outputs 256 * expansion channels here 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. / n)) 115 | elif isinstance(m, BatchNorm): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, BatchNorm, stride=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | BatchNorm(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes, BatchNorm=BatchNorm)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | x = self.fc(x) 150 | 151 | return x 152 | 153 | 154 | def resnet18(pretrained=False, BatchNorm=nn.BatchNorm2d, **kwargs): 155 | """Constructs a ResNet-18 model.
156 | Args: 157 | pretrained (bool): If True, loads ImageNet weights for all layers except layer4 and fc, whose shapes differ here 158 | """ 159 | model = ResNet(BasicBlock, [2, 2, 2, 2], BatchNorm=BatchNorm, **kwargs) 160 | if pretrained: 161 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 162 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if 163 | not (k.startswith('layer4') or k.startswith('fc'))} 164 | model_dict = model.state_dict() 165 | model_dict.update(pretrained_dict) 166 | model.load_state_dict(model_dict) 167 | 168 | return model 169 | 170 | 171 | def resnet34(pretrained=False, BatchNorm=nn.BatchNorm2d, **kwargs): 172 | """Constructs a ResNet-34 model. 173 | Args: 174 | pretrained (bool): If True, loads ImageNet weights for all layers except layer4 and fc, whose shapes differ here 175 | """ 176 | model = ResNet(BasicBlock, [3, 4, 6, 3], BatchNorm=BatchNorm, **kwargs) 177 | if pretrained: 178 | model.load_state_dict({**model.state_dict(), **{k: v for k, v in model_zoo.load_url(model_urls['resnet34']).items() if not k.startswith(('layer4', 'fc'))}}) 179 | return model 180 | 181 | 182 | def resnet50(pretrained=False, BatchNorm=nn.BatchNorm2d, **kwargs): 183 | """Constructs a ResNet-50 model. 184 | Args: 185 | pretrained (bool): If True, loads ImageNet weights for all layers except layer4 and fc, whose shapes differ here 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm, **kwargs) 188 | if pretrained: 189 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 190 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if 191 | not (k.startswith('layer4') or k.startswith('fc'))} 192 | model_dict = model.state_dict() 193 | model_dict.update(pretrained_dict) 194 | model.load_state_dict(model_dict) 195 | return model 196 | 197 | 198 | def resnet101(pretrained=False, BatchNorm=nn.BatchNorm2d, **kwargs): 199 | """Constructs a ResNet-101 model.
200 | Args: 201 | pretrained (bool): If True, loads ImageNet weights for all layers except layer4 and fc, whose shapes differ here 202 | """ 203 | model = ResNet(Bottleneck, [3, 4, 23, 3], BatchNorm=BatchNorm, **kwargs) 204 | if pretrained: 205 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 206 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if 207 | not (k.startswith('layer4') or k.startswith('fc'))} 208 | model_dict = model.state_dict() 209 | model_dict.update(pretrained_dict) 210 | model.load_state_dict(model_dict) 211 | return model 212 | 213 | 214 | def resnet152(pretrained=False, BatchNorm=nn.BatchNorm2d, **kwargs): 215 | """Constructs a ResNet-152 model. 216 | Args: 217 | pretrained (bool): If True, loads ImageNet weights for all layers except layer4 and fc, whose shapes differ here 218 | """ 219 | model = ResNet(Bottleneck, [3, 8, 36, 3], BatchNorm=BatchNorm, **kwargs) 220 | if pretrained: 221 | model.load_state_dict({**model.state_dict(), **{k: v for k, v in model_zoo.load_url(model_urls['resnet152']).items() if not k.startswith(('layer4', 'fc'))}}) 222 | return model 223 | -------------------------------------------------------------------------------- /modeling/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from modeling.backbone.resnet import resnet18, resnet50, resnet101 4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d # additional codes 5 | 6 | 7 | class VOSNet(nn.Module): 8 | 9 | def __init__(self, 10 | model='resnet18', sync_bn=False): 11 | 12 | super(VOSNet, self).__init__() 13 | self.model = model 14 | 15 | # note: SynchronizedBatchNorm2d only synchronizes statistics under DataParallelWithCallback; under DistributedDataParallel each process normalizes with its local batch 16 | if sync_bn: 17 | print("Using SynchronizedBatchNorm2d.") 18 | BatchNorm = SynchronizedBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | if model == 'resnet18': 23 | # resnet = resnet18(pretrained=True) 24 | resnet = resnet18(pretrained=True, BatchNorm=BatchNorm) # additional codes 25 | self.backbone = nn.Sequential(*list(resnet.children())[0:8]) 26 | elif model == 'resnet50': 27 | # resnet = resnet50(pretrained=True) 28 | resnet =
resnet50(pretrained=True, BatchNorm=BatchNorm) # additional codes 29 | self.backbone = nn.Sequential(*list(resnet.children())[0:8]) 30 | self.adjust_dim = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn256 = nn.BatchNorm2d(256) 32 | elif model == 'resnet101': 33 | # resnet = resnet101(pretrained=True) 34 | resnet = resnet101(pretrained=True, BatchNorm=BatchNorm) # additional codes 35 | self.backbone = nn.Sequential(*list(resnet.children())[0:8]) 36 | self.adjust_dim = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False) 37 | self.bn256 = nn.BatchNorm2d(256) 38 | else: 39 | raise NotImplementedError 40 | 41 | def forward(self, x): 42 | 43 | if self.model == 'resnet18': 44 | x = self.backbone(x) 45 | elif self.model == 'resnet50' or self.model == 'resnet101': 46 | x = self.backbone(x) 47 | x = self.adjust_dim(x) 48 | x = self.bn256(x) 49 | 50 | return x 51 | -------------------------------------------------------------------------------- /modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /modeling/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimensions""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or
is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 
135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalizes the tensor on each device using 137 | the statistics only on that device, which accelerates the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm. 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters.
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.randn(20, 100) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalizes the tensor on each device using 189 | the statistics only on that device, which accelerates the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size).
199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running estimates are updated with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.randn(20, 100, 35, 45) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs. 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 
239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalizes the tensor on each device using 241 | the statistics only on that device, which accelerates the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running estimates are updated with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm. 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.randn(20, 100, 35, 45, 10) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /modeling/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | while self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger a callback on each module, all slave devices should 59 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 60 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine the message passed 63 | back to each slave device. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register a slave device. 85 | Args: 86 | identifier: an identifier, usually the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | Messages are first collected from each device (including the master device), and then 101 | a callback is invoked to compute the message to be sent back to each device 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For an example of detailed usage, see `_SynchronizedBatchNorm`. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /modeling/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 31 | Note that, as all modules are isomorphic, we assign each sub-module a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies.
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation.
69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /modeling/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol, rtol=rtol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /test_epochs.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Zhipeng Zhang (zhangzhipeng2017@ia.ac.cn) 5 | # multi-gpu test for epochs 6 | # ------------------------------------------------------------------------------ 7 | 8 | import os 9 | import time 10 | import argparse 11 | from mpi4py import MPI 12 | 13 | 14 | parser = argparse.ArgumentParser(description='multi-gpu test all epochs') 15 | parser.add_argument('--start_epoch', default=160, type=int, help='start epoch') 16 | parser.add_argument('--end_epoch', default=240, type=int, help='end epoch') 17 | parser.add_argument('--gpu_nums', default=8, type=int, help='number of GPUs') 18 | parser.add_argument('--threads', default=16, type=int) 19 | args = parser.parse_args() 20 | 21 | # init gpu and epochs 22 | comm = MPI.COMM_WORLD 23 | size = comm.Get_size() 24 | rank = comm.Get_rank() 25 | GPU_ID = rank % args.gpu_nums 26 | node_name = MPI.Get_processor_name() # get the name of the node 27 | os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID) 28 | print("node name: {}, GPU_ID: {}".format(node_name, GPU_ID)) 29 | time.sleep(rank * 5) 30 | 31 | # run test scripts 
-- two epochs per thread 32 | for i in range(2): 33 | try: 34 | epoch_ID += args.threads // 2 * 5 # for 16 queue 35 | except NameError: # first iteration: epoch_ID is not defined yet 36 | epoch_ID = 5 * rank % (args.end_epoch - args.start_epoch + 1) + args.start_epoch 37 | 38 | if epoch_ID > args.end_epoch: 39 | continue 40 | 41 | resume = 'checkpoints/checkpoint-epoch-{}.pth.tar'.format(epoch_ID) 42 | print('==> test {}th epoch'.format(epoch_ID)) 43 | 44 | save_path = os.path.join('results/ck{}'.format(epoch_ID)) 45 | os.system('python inference.py -r {} -s {}'.format(resume, save_path)) 46 | -------------------------------------------------------------------------------- /test_epochs.sh: -------------------------------------------------------------------------------- 1 | mpiexec -n 16 python test_epochs.py 2 | --------------------------------------------------------------------------------
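The `SynchronizedBatchNorm*` docstrings in `modeling/sync_batchnorm/batchnorm.py` state that, unlike per-device normalization under `nn.DataParallel`, statistics are computed over all samples on all devices, and that the one-device case matches the built-in layer. The plain-Python sketch below illustrates only that idea and rests on a few assumptions. Lists stand in for per-device tensors. Each device sends a hypothetical `(count, sum, sum_of_squares)` message. The master then derives the global (biased) mean and variance used in `y = (x - mean[x]) / sqrt(Var[x] + eps) * gamma + beta`. It is not the repository's `_SynchronizedBatchNorm` code, which works on tensors and also maintains running statistics.

```python
# Sketch only (assumption): Python lists stand in for per-device tensors.
# Each "device" reports (count, sum, sum of squares); the master combines them
# into global statistics instead of letting each shard use its own.
from typing import List, Tuple

Stats = Tuple[int, float, float]


def local_stats(shard: List[float]) -> Stats:
    """Hypothetical per-device message: sample count, sum, sum of squares."""
    return len(shard), sum(shard), sum(x * x for x in shard)


def reduce_stats(messages: List[Stats]) -> Tuple[float, float]:
    """Combine per-device messages into the global biased mean and variance."""
    n = sum(m[0] for m in messages)
    s = sum(m[1] for m in messages)
    ss = sum(m[2] for m in messages)
    mean = s / n
    var = ss / n - mean * mean  # E[x^2] - E[x]^2 (naive; fine for small inputs)
    return mean, var


def normalize(shard: List[float], mean: float, var: float,
              eps: float = 1e-5, gamma: float = 1.0, beta: float = 0.0) -> List[float]:
    """y = (x - mean) / sqrt(var + eps) * gamma + beta, using the shared stats."""
    return [(x - mean) / (var + eps) ** 0.5 * gamma + beta for x in shard]


if __name__ == "__main__":
    shards = [[1.0, 2.0], [3.0, 4.0, 5.0, 6.0]]  # two devices, uneven batches
    mean, var = reduce_stats([local_stats(s) for s in shards])
    print("global mean/var:", mean, var)
    print("device 0 alone:", reduce_stats([local_stats(shards[0])]))
```

Sending sums rather than per-device means keeps the reduction exact when devices hold different numbers of samples. With a single device the same code reduces to ordinary batch statistics, which is the one-GPU equivalence the docstrings describe.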
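`SyncMaster`, `SlavePipe` and `FutureResult` in `modeling/sync_batchnorm/comm.py` implement one gather-then-scatter round per forward pass. Each slave puts `(identifier, msg)` on a shared queue and blocks on its own future. The master collects every message with its own first, runs the callback, and puts each slave's result into that slave's future. The sketch below reproduces that round, with threads standing in for device replicas. It is a simplified stand-in, not the library code. `OneShotFuture`, `run_round` and `sum_to_all` are hypothetical names. Slave identifiers are assumed to be `1..n`, `None` is reserved as the "empty" marker, and the `True` acknowledgements that `run_master` waits for are replaced by `Thread.join`.

```python
import queue
import threading
from typing import Any, Callable, Dict, List, Tuple

Message = Tuple[int, Any]


class OneShotFuture:
    """One-to-one result slot; None means "empty", so results must not be None."""

    def __init__(self) -> None:
        self._result = None
        self._cond = threading.Condition()

    def put(self, result: Any) -> None:
        with self._cond:
            assert self._result is None, "previous result has not been fetched"
            self._result = result
            self._cond.notify()

    def get(self) -> Any:
        with self._cond:
            # Re-check in a loop: wait() can return before the condition holds.
            while self._result is None:
                self._cond.wait()
            res, self._result = self._result, None
            return res


def run_round(master_msg: Any, slave_msgs: List[Any],
              callback: Callable[[List[Message]], List[Message]]) -> Dict[int, Any]:
    """Gather all messages on the master, run the callback, scatter the replies."""
    q = queue.Queue()
    futures = {i: OneShotFuture() for i in range(1, len(slave_msgs) + 1)}
    replies: Dict[int, Any] = {}

    def slave(ident: int, msg: Any) -> None:
        q.put((ident, msg))                    # send this replica's message
        replies[ident] = futures[ident].get()  # block until the master replies

    threads = [threading.Thread(target=slave, args=(i, m))
               for i, m in enumerate(slave_msgs, start=1)]
    for t in threads:
        t.start()

    # The master's own message goes first, as in SyncMaster.run_master.
    intermediates = [(0, master_msg)] + [q.get() for _ in slave_msgs]
    results = callback(intermediates)
    assert results[0][0] == 0, "the first result should belong to the master"
    for ident, res in results[1:]:
        futures[ident].put(res)
    for t in threads:
        t.join()
    replies[0] = results[0][1]
    return replies


def sum_to_all(msgs: List[Message]) -> List[Message]:
    """Example callback: every participant receives the total of all messages."""
    total = sum(m for _, m in msgs)
    return [(ident, total) for ident, _ in msgs]


if __name__ == "__main__":
    print(run_round(1, [2, 3, 4], sum_to_all))
```

In the synchronized BatchNorm setting, the messages would be per-device partial statistics and the callback would reduce them, much like the statistics sketch for the batchnorm docstrings. The `while` loop in `get` follows the pattern the Python `threading.Condition` documentation recommends for re-checking a condition after `wait()` returns.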
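`test_epochs.py` gives each MPI rank up to two checkpoints to evaluate. The first is `5 * rank % (end_epoch - start_epoch + 1) + start_epoch`. The second adds `threads // 2 * 5` to that value, and any epoch past `end_epoch` is skipped. The pure-Python sketch below reproduces that arithmetic without MPI, so the assignment can be checked offline. `epoch_schedule` and `schedule_for` are hypothetical helpers. The sketch assumes the script's default arguments and the `mpiexec -n 16` launch in `test_epochs.sh`.

```python
from typing import Dict, List


def epoch_schedule(rank: int, start_epoch: int = 160, end_epoch: int = 240,
                   threads: int = 16, rounds: int = 2) -> List[int]:
    """Epochs one rank would evaluate, mirroring the loop in test_epochs.py."""
    epoch = 5 * rank % (end_epoch - start_epoch + 1) + start_epoch  # first round
    epochs = []
    for i in range(rounds):
        if i > 0:
            epoch += threads // 2 * 5  # later rounds shift by threads // 2 * 5
        if epoch <= end_epoch:         # the script `continue`s past end_epoch
            epochs.append(epoch)
    return epochs


def schedule_for(n_ranks: int = 16) -> Dict[int, List[int]]:
    """Schedule for every rank launched by `mpiexec -n n_ranks`."""
    return {rank: epoch_schedule(rank) for rank in range(n_ranks)}


if __name__ == "__main__":
    for rank, epochs in schedule_for().items():
        print(rank, epochs)
```

Working through the defaults: the first round covers epochs 160 to 235, and the second round covers 200 to 240. Every multiple of 5 from 160 to 240 is therefore evaluated, but epochs 200 to 235 are each assigned to two ranks. That makes 25 `inference.py` runs for 17 distinct checkpoints.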