├── .gitignore ├── Adversarial ├── ColorFool.py ├── Samples │ ├── ILSVRC2012_val_00003533_alexnet.png │ ├── ILSVRC2012_val_00003533_resnet18.png │ └── ILSVRC2012_val_00003533_resnet50.png ├── misc_functions.py └── script.sh ├── ColorFool.gif ├── Dataset └── ILSVRC2012_val_00003533.JPEG ├── License.txt ├── README.md ├── Sample_results ├── ILSVRC2012_val_00003533_alexnet.png ├── ILSVRC2012_val_00003533_resnet18.png └── ILSVRC2012_val_00003533_resnet50.png ├── Segmentation ├── SemanticMasks.py ├── data │ ├── ADE20K_object150_train.txt │ ├── ADE20K_object150_val.txt │ ├── color150.mat │ ├── object150_info.csv │ ├── train.odgt │ └── validation.odgt ├── dataset.py ├── lib │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── batchnorm.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ ├── tests │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ └── test_sync_batchnorm.py │ │ │ └── unittest.py │ │ └── parallel │ │ │ ├── __init__.py │ │ │ └── data_parallel.py │ └── utils │ │ ├── __init__.py │ │ ├── data │ │ ├── __init__.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── distributed.py │ │ └── sampler.py │ │ └── th.py ├── models │ ├── __init__.py │ ├── mobilenet.py │ ├── models.py │ ├── resnet.py │ └── resnext.py ├── script.sh └── utils.py ├── TutorialDemoColorFool ├── ColorFool.ipynb ├── Image │ └── ILSVRC2012_val_00003533.JPEG └── Masks │ ├── Person │ └── ILSVRC2012_val_00003533.JPEG │ ├── Sky │ └── ILSVRC2012_val_00003533.JPEG │ ├── Vegetation │ └── ILSVRC2012_val_00003533.JPEG │ └── Water │ └── ILSVRC2012_val_00003533.JPEG └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Latex files 2 | 3 | *.aux 4 | *.glo 5 | *.idx 6 | *.log 7 | *.toc 8 | *.ist 9 | *.acn 10 | *.acr 11 | *.alg 12 | *.bbl 13 | *.blg 14 | *.tui 15 | *.top 16 | *.tmp 17 | *.mp 18 | *.dvi 19 | *.glg 20 | *.gls 21 | *.ilg 22 | *.ind 23 | *.lof 24 | *.lot 25 | *.maf 26 | *.mtc 27 | *.mtc1 28 | *.out 29 | *.gz 30 | *.pyc 31 | 32 | # Mac IDE files 33 | *.swp 34 | *~ 35 | *(Autosaved).rtfd/ 36 | Backup[ ]of[ ]*.pages/ 37 | Backup[ ]of[ ]*.key/ 38 | Backup[ ]of[ ]*.numbers/ 39 | 40 | # Mac finder files and hidden folders 41 | .DS_Store 42 | *.fdb_latexmk 43 | 44 | *.fls 45 | 46 | *.sublime-workspace 47 | paper.pdf 48 | changelog.txt 49 | *.sublime-project 50 | -------------------------------------------------------------------------------- /Adversarial/ColorFool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import scipy 5 | from skimage import io, color 6 | 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Variable 10 | from torch.nn import functional as F 11 | 12 | import glob, os 13 | from os.path import join,isfile 14 | 15 | from os import listdir 16 | 17 | from PIL import Image 18 | from tqdm import tqdm 19 | from torchvision import models 20 | from numpy import pi 21 | from numpy import sin 22 | from numpy import zeros 23 | from numpy import r_ 24 | 25 | from scipy import signal 26 | from scipy import misc 27 | import torchvision.transforms as T 28 | from skimage.filters import rank 29 | from skimage.morphology import disk 30 | 31 | import argparse 32 | import pdb 33 | from copy import copy as copy 34 | 35 | 36 | from misc_functions import prepareImageMasks, initialise, createLogFiles, createDirectories 37 | 38 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 39 | 40 | selem = disk(20) 41 | 42 | # 
Normalization values for ImageNet 43 | trf = T.Compose([T.ToPILImage(), 44 | T.ToTensor(), 45 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 46 | 47 | class attack(): 48 | 49 | def __init__(self, model, args): 50 | 51 | self.model = model 52 | 53 | # Create the folder to export adversarial images if it does not exist 54 | self.adv_path = createDirectories(args) 55 | 56 | def generate(self, original_image, sky_mask, water_mask, green_mask, person_mask, img_name, org_class, args): 57 | 58 | misclassified=0 59 | maxTrials = 1000 60 | 61 | # Transfer the clean image from RGB to Lab color space 62 | original_image_lab=color.rgb2lab(original_image) 63 | 64 | # Start iteration 65 | for trial in range(maxTrials): 66 | 67 | X_lab = original_image_lab.copy() 68 | 69 | margin = 127 70 | mult = float(trial+1) / float(maxTrials) 71 | 72 | # Adversarial color perturbation for Water regions 73 | water_mask_binary = copy(water_mask) 74 | water_mask_binary[water_mask_binary>0] = 1 75 | water = X_lab[water_mask_binary == 1] 76 | if water.size != 0: 77 | a_min = water[:,1].min() 78 | a_max = np.clip(water[:,1].max(), a_min=None, a_max = 0) 79 | b_min = water[:,2].min() 80 | b_max = np.clip(water[:,2].max(), a_min=None, a_max = 0) 81 | a_blue = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-a_min), mult*(-a_max), size=(1))) * water_mask 82 | b_blue = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-b_min), mult*(-b_max), size=(1))) * water_mask 83 | else: 84 | a_blue = np.full((X_lab.shape[0], X_lab.shape[1]), 0.) 85 | b_blue = np.full((X_lab.shape[0], X_lab.shape[1]), 0.) 86 | 87 | # Adversarial color perturbation for Vegetation regions 88 | green_mask_binary = copy(green_mask) 89 | green_mask_binary[green_mask_binary>0] = 1 90 | green = X_lab[green_mask_binary == 1] 91 | if green.size != 0: 92 | a_min = green[:,1].min() 93 | a_max = np.clip(green[:,1].max(), a_min=None, a_max = 0) 94 | b_min = np.clip(green[:,2].min(), a_min=0, a_max = None) 95 | b_max = green[:,2].max() 96 | a_green = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-a_min), mult*(-a_max), size=(1))) * green_mask 97 | b_green = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-b_min), mult*(margin-b_max), size=(1))) * green_mask 98 | else: 99 | a_green = np.full((X_lab.shape[0], X_lab.shape[1]), 0.) 100 | b_green = np.full((X_lab.shape[0], X_lab.shape[1]), 0.) 101 | 102 | # Adversarial color perturbation for Sky regions 103 | sky_mask_binary = copy(sky_mask) 104 | sky_mask_binary[sky_mask_binary>0] = 1 105 | sky = X_lab[sky_mask_binary == 1] 106 | if sky.size != 0: 107 | a_min = sky[:,1].min() 108 | a_max = np.clip(sky[:,1].max(), a_min=None, a_max = 0) 109 | b_min = sky[:,2].min() 110 | b_max = np.clip(sky[:,2].max(), a_min=None, a_max = 0) 111 | a_sky = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-a_min), mult*(-a_max), size=(1))) * sky_mask 112 | b_sky = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-b_min), mult*(-b_max), size=(1))) * sky_mask 113 | else: 114 | a_sky = np.full((X_lab.shape[0], X_lab.shape[1]), 0.) 115 | b_sky = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
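# Note on the ranges above: in the Lab color space the a channel runs from
# green (negative) to red (positive) and the b channel from blue (negative)
# to yellow (positive), each spanning roughly [-127, 127]. The clipping of
# the sampled shifts therefore keeps water and sky in the blue/green
# quadrant (a <= 0, b <= 0) and vegetation green but possibly yellowish
# (a <= 0, b >= 0), so the perturbed regions stay naturally colored;
# person regions are deliberately left unperturbed.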
116 | 117 | 118 | mask = (person_mask + water_mask + green_mask + sky_mask) 119 | mask[mask>1] = 1 120 | 121 | # Smooth boundaries between sensitive regions 122 | kernel = np.ones((5, 5), np.uint8) 123 | mask = cv2.blur(mask,(10,10)) 124 | 125 | # Adversarial color perturbation for non-sensitive regions 126 | random_mask = 1 - mask 127 | a_random = np.full((X_lab.shape[0],X_lab.shape[1]), np.random.uniform(mult*(-margin), mult*(margin), size=(1))) 128 | b_random = np.full((X_lab.shape[0],X_lab.shape[1]), np.random.uniform(mult*(-margin), mult*(margin), size=(1))) 129 | a_random_mask = a_random * random_mask 130 | b_random_mask = b_random * random_mask 131 | 132 | 133 | # Adversarially perturb the color (i.e. the a and b channels in the Lab color space) of the clean image 134 | noise_mask = np.zeros((X_lab.shape), dtype=float) 135 | noise_mask[:,:,1] = a_blue + a_green + a_sky + a_random_mask 136 | noise_mask[:,:,2] = b_blue + b_green + b_sky + b_random_mask 137 | X_lab_mask = np.zeros((X_lab.shape), dtype=float) 138 | X_lab_mask[:,:,0] = X_lab[:,:,0] 139 | X_lab_mask[:,:,1] = np.clip(X_lab[:,:,1] + noise_mask[:,:,1], -margin, margin) 140 | X_lab_mask[:,:,2] = np.clip(X_lab[:,:,2] + noise_mask[:,:,2], -margin, margin) 141 | 142 | # Transfer from Lab to RGB 143 | X_rgb_mask = np.uint8(color.lab2rgb(X_lab_mask)*255.) 144 | 145 | # Predict the label of the adversarial image 146 | logit = self.model(trf(cv2.resize(X_rgb_mask, (224, 224), interpolation=cv2.INTER_LINEAR)).to(device).unsqueeze_(0)) 147 | h_x = F.softmax(logit, dim=1).data.squeeze() 148 | probs, idx = h_x.sort(0, True) 149 | 150 | current_class = idx[0] 151 | current_class_prob = probs[0] 152 | org_class_prob = h_x[org_class] 153 | 154 | # Check if the generated adversarial image misleads the model 155 | if (current_class != org_class): 156 | misclassified=1 157 | break 158 | 159 | # Transfer the adversarial image from RGB to BGR to save with OpenCV 160 | X_bgr = X_rgb_mask[:, :, (2, 1, 0)] 161 | cv2.imwrite('{}/{}.png'.format(self.adv_path, img_name.split('.')[0]), X_bgr) 162 | return misclassified, trial, current_class, current_class_prob 163 | 164 | 165 | if __name__ == '__main__': 166 | 167 | # Parse arguments 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument('--model', type=str, required=True) 170 | parser.add_argument('--dataset', type=str, default='../Dataset/') 171 | args = parser.parse_args() 172 | 173 | # Initialization.
Load the model under attack, the path of the dataset and the list of all clean images inside it 174 | model, image_list = initialise(args) 175 | 176 | # Log files to save numerical results 177 | f1, f1_name = createLogFiles(args) 178 | 179 | # Number of successful adversarial images 180 | misleads=0 181 | 182 | # Generate adversarial images for all clean images in the image_list 183 | NumImg=len(image_list) 184 | for idx in tqdm(range(NumImg)): 185 | 186 | # Load the clean image and predict its label using the model 187 | original_image, sky_mask, water_mask, grass_mask, person_mask, img_name, org_class, org_class_prob = prepareImageMasks(args, image_list, idx, model) 188 | 189 | f1 = open(f1_name, 'a+') 190 | 191 | # Perform the ColorFool attack 192 | LAB = attack(model, args) 193 | mislead, numTrials, current_class, current_class_prob = LAB.generate(original_image, sky_mask, water_mask, grass_mask, person_mask, img_name, org_class, args) 194 | misleads += mislead 195 | text = '{}\t{}\t{}\t{:.5f}\t{}\t{:.5f}\n'.format(img_name, numTrials+1, org_class, org_class_prob, current_class, current_class_prob) 196 | 197 | f1.write(text) 198 | f1.close() 199 | print('Success rate {:.1f}%'.format(100*float(misleads) / (NumImg)) ) 200 | -------------------------------------------------------------------------------- /Adversarial/Samples/ILSVRC2012_val_00003533_alexnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Adversarial/Samples/ILSVRC2012_val_00003533_alexnet.png -------------------------------------------------------------------------------- /Adversarial/Samples/ILSVRC2012_val_00003533_resnet18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Adversarial/Samples/ILSVRC2012_val_00003533_resnet18.png -------------------------------------------------------------------------------- /Adversarial/Samples/ILSVRC2012_val_00003533_resnet50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Adversarial/Samples/ILSVRC2012_val_00003533_resnet50.png -------------------------------------------------------------------------------- /Adversarial/misc_functions.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import cv2 3 | import numpy as np 4 | from skimage import io, color 5 | import csv 6 | import os 7 | from os import listdir 8 | from os.path import isfile,join 9 | import torch 10 | import torchvision 11 | from torch.autograd import Variable 12 | from torchvision import models 13 | import torchvision.transforms as T 14 | from torch.nn import functional as F 15 | from torch.autograd import Variable as V 16 | import torch.nn as nn 17 | import scipy.sparse 18 | 19 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 20 | 21 | def initialise(args): 22 | 23 | image_list = [f for f in listdir(args.dataset) if isfile(join(args.dataset,f))] # Format of images for ImageNet is .JPEG; change for other datasets 24 | 25 | # Load model 26 | if args.model == 'resnet18': 27 | model = models.resnet18(pretrained=True) 28 | elif args.model == 'resnet50': 29 | model = models.resnet50(pretrained=True) 30 | elif args.model == 'alexnet': 31 | model = models.alexnet(pretrained=True) 32 | 
model.eval() 33 | model.to(device) 34 | return model, image_list 35 | 36 | 37 | def createLogFiles(args): 38 | log_path = 'Results/Logs/' 39 | 40 | if not os.path.exists(log_path): 41 | os.makedirs(log_path) 42 | f1_name = log_path+'log_{}.txt'.format(args.model) 43 | f1 = open(f1_name,"w") 44 | return f1, f1_name 45 | 46 | 47 | def createDirectories(args): 48 | main_path = 'Results/ColorFoolImgs/' 49 | adv_path = main_path + 'adv_{}'.format(args.model) 50 | 51 | if not os.path.exists(adv_path): 52 | os.makedirs(adv_path) 53 | 54 | return adv_path 55 | 56 | # For ImageNet the mean and std are: 57 | mean = np.asarray([ 0.485, 0.456, 0.406 ]) 58 | std = np.asarray([ 0.229, 0.224, 0.225 ]) 59 | 60 | trf = T.Compose([T.ToPILImage(), 61 | T.ToTensor(), 62 | T.Normalize(mean=mean, std=std)]) 63 | 64 | def prepareImageMasks(args, image_list, index, model): 65 | 66 | # Paths to segmentation outputs done in the prior step 67 | sky_mask_path = '../Segmentation/SegmentationResults/sky/' 68 | water_mask_path = '../Segmentation/SegmentationResults/water/' 69 | grass_mask_path = '../Segmentation/SegmentationResults/grass/' 70 | person_mask_path = '../Segmentation/SegmentationResults/person/' 71 | 72 | # Read images 73 | img_name = image_list[index] 74 | 75 | # Load the clean image with its four corresponding masks that represent Sky, Person, Vegetation and Water 76 | original_image = cv2.imread(args.dataset+img_name, 1) 77 | person_mask = cv2.imread('{}.png'.format(person_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255. 78 | water_mask = cv2.imread('{}.png'.format(water_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255. 79 | grass_mask = cv2.imread('{}.png'.format(grass_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255. 80 | sky_mask = cv2.imread('{}.png'.format(sky_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255.
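# Each mask is a single-channel float array in [0, 1]; with --mask_type
# smooth the segmentation confidence is baked into the mask, so it also
# scales the strength of the color perturbation per pixel.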
81 | 82 | 83 | # Convert the image from BGR (OpenCV ordering) to RGB 84 | original_image = original_image[:, :, (2, 1, 0)] 85 | 86 | # Resize image to the input size of the model 87 | image = cv2.resize(original_image, (224, 224), interpolation=cv2.INTER_LINEAR) 88 | # Forward pass 89 | logit = model(trf(image).to(device).unsqueeze_(0)) 90 | h_x = F.softmax(logit, dim=1).data.squeeze() 91 | probs, idx = h_x.sort(0, True) 92 | 93 | probs = np.array(probs.cpu()) 94 | idx = np.array(idx.cpu()) 95 | 96 | org_class = idx[0] 97 | org_class_prob = probs[0] 98 | 99 | return original_image, sky_mask, water_mask, grass_mask, person_mask, img_name, org_class, org_class_prob 100 | 101 | -------------------------------------------------------------------------------- /Adversarial/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODELS=(alexnet resnet18 resnet50) 4 | 5 | clear 6 | for model in "${MODELS[@]}" 7 | do 8 | 9 | echo ColorFool attacking $model 10 | python -W ignore ColorFool.py --model=$model 11 | 12 | done 13 | -------------------------------------------------------------------------------- /ColorFool.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/ColorFool.gif -------------------------------------------------------------------------------- /Dataset/ILSVRC2012_val_00003533.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Dataset/ILSVRC2012_val_00003533.JPEG -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | # License 2 | This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 3 | 4 | Creative Commons Legal Code 5 | Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) 6 | 7 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 8 | Creative Commons Attribution-NonCommercial 4.0 International Public License 9 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 10 | Section 1 – Definitions. 11 | a.
Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 12 | b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 13 | c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 14 | d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 15 | e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 16 | f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 17 | g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 18 | h. Licensor means the individual(s) or entity(ies) granting rights under this Public License. 19 | i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 20 | j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 21 | k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 22 | l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 23 | Section 2 – Scope. 24 | a. License grant. 25 | 1. 
Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 26 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 27 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 28 | 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 29 | 3. Term. The term of this Public License is specified in Section 6(a). 30 | 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 31 | 5. Downstream recipients. 32 | A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 33 | B. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 34 | 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 35 | b. Other rights. 36 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 37 | 2. Patent and trademark rights are not licensed under this Public License. 38 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 39 | Section 3 – License Conditions. 40 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 41 | a. Attribution. 42 | 1. If You Share the Licensed Material (including in modified form), You must: 43 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 44 | i. 
identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 45 | ii. a copyright notice; 46 | iii. a notice that refers to this Public License; 47 | iv. a notice that refers to the disclaimer of warranties; 48 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 49 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 50 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 51 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 52 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 53 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. 54 | Section 4 – Sui Generis Database Rights. 55 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 56 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 57 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 58 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 59 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 60 | Section 5 – Disclaimer of Warranties and Limitation of Liability. 61 | a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You. 62 | b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. 
Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You. 63 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 64 | Section 6 – Term and Termination. 65 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 66 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 67 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 68 | 2. upon express reinstatement by the Licensor. 69 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 70 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 71 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 72 | Section 7 – Other Terms and Conditions. 73 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 74 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 75 | Section 8 – Interpretation. 76 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 77 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 78 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 79 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 80 | 81 | 82 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. 
Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 83 | 84 | Creative Commons may be contacted at creativecommons.org. 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ColorFool 2 | 3 | This is the official repository of [ColorFool: Semantic Adversarial Colorization](https://arxiv.org/pdf/1911.10891.pdf), a work published in The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Seattle, Washington, USA, 14-19 June, 2020.
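In a nutshell, ColorFool first segments an image into semantically sensitive regions (sky, water, vegetation, person) and then searches for an adversarial recoloring by randomly shifting the a and b (chromaticity) channels of the image in the Lab color space, restricted to each region's natural color range (the Description section below covers the full pipeline). The snippet below is a minimal, self-contained sketch of that color perturbation; `perturb_region` is an illustrative helper written for this README, not a function of this repository, which implements the per-region constraints in `Adversarial/ColorFool.py`.

```
import numpy as np
from skimage import color

def perturb_region(img_rgb, mask, trial, max_trials=1000, margin=127):
    """Shift the a/b (chromaticity) channels of img_rgb inside mask.

    The lightness channel L is left untouched, and the random shift grows
    with the trial index, mirroring ColorFool's progressively larger trials.
    """
    lab = color.rgb2lab(img_rgb)            # L in [0, 100]; a, b roughly in [-127, 127]
    mult = float(trial + 1) / max_trials    # small shifts first
    for ch in (1, 2):                       # perturb only a and b
        shift = np.random.uniform(-mult * margin, mult * margin)
        lab[..., ch] = np.clip(lab[..., ch] + shift * mask, -margin, margin)
    return np.uint8(color.lab2rgb(lab) * 255.)
```

The attack accepts the first trial whose recolored image changes the classifier's prediction, so early (low-`mult`) successes yield barely perceptible color changes.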
4 | 5 | ![Alt Text](ColorFool.gif) 6 | Example results 7 | 8 | | Original Image | Attack AlexNet | Attack ResNet18 | Attack ResNet50 | 9 | |---|---|---|---| 10 | | ![Original Image](Dataset/ILSVRC2012_val_00003533.JPEG) | ![Attack AlexNet](Sample_results/ILSVRC2012_val_00003533_alexnet.png) | ![Attack ResNet18](Sample_results/ILSVRC2012_val_00003533_resnet18.png) | ![Attack ResNet50](Sample_results/ILSVRC2012_val_00003533_resnet50.png) | 11 | 12 | 13 | ## Setup 14 | 1. Download the source code from GitHub 15 | ``` 16 | git clone https://github.com/smartcameras/ColorFool.git 17 | ``` 18 | 2. Create a [conda](https://docs.conda.io/en/latest/miniconda.html) virtual environment 19 | ``` 20 | conda create --name ColorFool python=3.5.6 21 | ``` 22 | 3. Activate the conda environment 23 | ``` 24 | source activate ColorFool 25 | ``` 26 | 4. Install the requirements 27 | ``` 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | 32 | ## Description 33 | The code works in two steps: 34 | 1. Identify image regions with a semantic segmentation model 35 | 2. Generate adversarial images by perturbing the colors of the semantic regions within their natural color ranges (see the sketch near the top of this README) 36 | 37 | 38 | ### Semantic Segmentation 39 | 40 | 1. Go to the Segmentation directory 41 | ``` 42 | cd Segmentation 43 | ``` 44 | 2. Download the segmentation model (both encoder and decoder) from [here](https://drive.google.com/drive/folders/1FjZTweIsWWgxhXkzKHyIzEgBO5VTCe68) and place it in the "models" directory. 45 | 46 | 47 | 3. Run the segmentation for all images within the Dataset directory (requires a GPU) 48 | ``` 49 | bash script.sh 50 | ``` 51 | 52 | The semantic regions of the four categories are saved in the Segmentation/SegmentationResults/$Dataset/ directory as smooth masks of the same size as the image, with the same names as their corresponding original images. 53 | 54 | ### Generate ColorFool Adversarial Images 55 | 56 | 1. Go to the Adversarial directory 57 | ``` 58 | cd ../Adversarial 59 | ``` 60 | 2. In script.sh set 61 | (i) the names of the target models to attack, and (ii) the name of the dataset. 62 | The current implementation supports three classifiers (ResNet18, ResNet50 and AlexNet) trained on ImageNet. 63 | 3. Run ColorFool for all images within the Dataset directory (works on both GPU and CPU) 64 | ``` 65 | bash script.sh 66 | ``` 67 | 68 | ### Outputs 69 | * Adversarial images, saved with the same names as the clean images in the Adversarial/Results/ColorFoolImgs directory; 70 | * Metadata with the following structure: filename, number of trials, predicted class of the clean image with its probability, and predicted class of the adversarial image with its probability, in the Adversarial/Results/Logs directory. 71 | 72 | 73 | ## Authors 74 | * [Ali Shahin Shamsabadi](mailto:a.shahinshamsabadi@qmul.ac.uk) 75 | * [Ricardo Sanchez-Matilla](mailto:ricardo.sanchezmatilla@qmul.ac.uk) 76 | * [Andrea Cavallaro](mailto:a.cavallaro@qmul.ac.uk) 77 | 78 | 79 | ## References 80 | If you use our code, please cite the following paper: 81 | 82 | @InProceedings{shamsabadi2020colorfool, 83 | title = {ColorFool: Semantic Adversarial Colorization}, 84 | author = {Shamsabadi, Ali Shahin and Sanchez-Matilla, Ricardo and Cavallaro, Andrea}, 85 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 86 | year = {2020}, 87 | address = {Seattle, Washington, USA}, 88 | month = June 89 | } 90 | 91 | ## License 92 | This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 93 | -------------------------------------------------------------------------------- /Sample_results/ILSVRC2012_val_00003533_alexnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Sample_results/ILSVRC2012_val_00003533_alexnet.png -------------------------------------------------------------------------------- /Sample_results/ILSVRC2012_val_00003533_resnet18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Sample_results/ILSVRC2012_val_00003533_resnet18.png -------------------------------------------------------------------------------- /Sample_results/ILSVRC2012_val_00003533_resnet50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Sample_results/ILSVRC2012_val_00003533_resnet50.png -------------------------------------------------------------------------------- /Segmentation/SemanticMasks.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | import os 3 | import argparse 4 | from distutils.version import LooseVersion 5 | # Numerical libs 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from scipy.io import loadmat 10 | # Our libs 11 | from dataset import TestDataset 12 | from models import ModelBuilder, SegmentationModule 13 | from utils import colorEncode, find_recursive 14 | from lib.nn import user_scattered_collate, async_copy_to 15 | from lib.utils import as_numpy 16 | import lib.utils.data as torchdata 17 | import cv2 18 | from tqdm import tqdm 19 | 20 | colors = loadmat('data/color150.mat')['colors'] 21 | 22 | 23 | def visualize_result(data, pred, pred_prob, args): 24 | (img, info) = data 25 | img_name = info.split('/')[-1] 26 | 27 | ### water mask: water, sea, swimming pool, waterfalls, lake and river 28 | water_mask = (pred == 21) 29 | sea_mask = (pred == 26) 30 | river_mask = (pred == 60) 31 | pool_mask = (pred == 109) 32 | fall_mask = (pred == 113) 33 | lake_mask = (pred == 128) 34 | water_mask = (water_mask | sea_mask | river_mask | pool_mask | fall_mask | lake_mask).astype(int) 35 | if args.mask_type=='smooth': 36 | water_mask = water_mask.astype(float) * pred_prob 37 | 38 | water_mask = water_mask * 255. 39 | cv2.imwrite('{}/water/{}.png' .format(args.result,img_name.split('.')[0]), water_mask) 40 | 41 | 42 | ### Sky mask 43 | sky_mask = (pred == 2).astype(int) 44 | if args.mask_type=='smooth': 45 | sky_mask = sky_mask.astype(float) * pred_prob 46 | sky_mask = sky_mask * 255. 47 | cv2.imwrite('{}/sky/{}.png' .format(args.result,img_name.split('.')[0]), sky_mask) 48 | 49 | 50 | ### Grass mask 51 | grass_mask = (pred == 9).astype(int) 52 | if args.mask_type=='smooth': 53 | grass_mask = grass_mask.astype(float) * pred_prob 54 | 55 | grass_mask = grass_mask * 255. 
56 | cv2.imwrite('{}/grass/{}.png' .format(args.result,img_name.split('.')[0]), grass_mask) 57 | 58 | 59 | ### Person mask 60 | person_mask = (pred == 12).astype(int) 61 | if args.mask_type=='smooth': 62 | person_mask = person_mask.astype(float) * pred_prob 63 | person_mask = person_mask * 255. 64 | cv2.imwrite('{}/person/{}.png' .format(args.result,img_name.split('.')[0]), person_mask) 65 | 66 | 67 | def test(segmentation_module, loader, args): 68 | segmentation_module.eval() 69 | 70 | pbar = tqdm(total=len(loader)) 71 | for batch_data in loader: 72 | # process data 73 | batch_data = batch_data[0] 74 | segSize = (batch_data['img_ori'].shape[0], 75 | batch_data['img_ori'].shape[1]) 76 | img_resized_list = batch_data['img_data'] 77 | 78 | with torch.no_grad(): 79 | scores = torch.zeros(1, args.num_class, segSize[0], segSize[1]) 80 | 81 | for img in img_resized_list: 82 | feed_dict = batch_data.copy() 83 | feed_dict['img_data'] = img 84 | del feed_dict['img_ori'] 85 | del feed_dict['info'] 86 | 87 | # forward pass 88 | pred_tmp = segmentation_module(feed_dict, segSize=segSize) 89 | scores += (pred_tmp.cpu() / len(args.imgSize)) 90 | 91 | 92 | pred_prob, pred = torch.max(scores, dim=1) 93 | pred = as_numpy(pred.squeeze(0).cpu()) 94 | pred_prob = as_numpy(pred_prob.squeeze(0).cpu()) 95 | 96 | # visualization 97 | visualize_result((batch_data['img_ori'], batch_data['info']), pred, pred_prob, args) 98 | 99 | pbar.update(1) 100 | 101 | 102 | def main(args): 103 | 104 | # Network Builders 105 | builder = ModelBuilder() 106 | net_encoder = builder.build_encoder( 107 | arch=args.arch_encoder, 108 | fc_dim=args.fc_dim, 109 | weights=args.weights_encoder) 110 | net_decoder = builder.build_decoder( 111 | arch=args.arch_decoder, 112 | fc_dim=args.fc_dim, 113 | num_class=args.num_class, 114 | weights=args.weights_decoder, 115 | use_softmax=True) 116 | 117 | crit = nn.NLLLoss(ignore_index=-1) 118 | 119 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) 120 | 121 | # Dataset and Loader 122 | if len(args.dataset) == 1 and os.path.isdir(args.dataset[0]): 123 | test_imgs = find_recursive(args.dataset[0], ext='.*') 124 | else: 125 | test_imgs = args.dataset 126 | 127 | list_test = [{'fpath_img': x} for x in test_imgs] 128 | dataset_test = TestDataset(list_test, args, max_sample=args.num_val) 129 | loader_test = torchdata.DataLoader( 130 | dataset_test, 131 | batch_size=args.batch_size, 132 | shuffle=False, 133 | collate_fn=user_scattered_collate, 134 | num_workers=5, 135 | drop_last=True) 136 | 137 | # Main loop 138 | test(segmentation_module, loader_test, args) 139 | 140 | print('Segmentation completed') 141 | 142 | 143 | if __name__ == '__main__': 144 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \ 145 | 'PyTorch>=0.4.0 is required' 146 | 147 | parser = argparse.ArgumentParser() 148 | # Path related arguments 149 | parser.add_argument('--dataset', required=True, nargs='+', type=str, 150 | help='a list of image paths, or a directory name') 151 | parser.add_argument('--model_path', required=True, 152 | help='path to the model folder') 153 | parser.add_argument('--suffix', default='_epoch_20.pth', 154 | help="which snapshot to load") 155 | 156 | # Model related arguments 157 | parser.add_argument('--arch_encoder', default='resnet50dilated', 158 | help="architecture of net_encoder") 159 | parser.add_argument('--arch_decoder', default='ppm_deepsup', 160 | help="architecture of net_decoder") 161 | parser.add_argument('--fc_dim', default=2048, type=int, 162 | help='number of
features between encoder and decoder') 163 | 164 | # Data related arguments 165 | parser.add_argument('--num_val', default=-1, type=int, 166 | help='number of images to evaluate') 167 | parser.add_argument('--num_class', default=150, type=int, 168 | help='number of classes') 169 | parser.add_argument('--batch_size', default=1, type=int, 170 | help='batch size; currently only 1 is supported') 171 | parser.add_argument('--imgSize', default=[300, 400, 500, 600], 172 | nargs='+', type=int, 173 | help='list of input image sizes.' 174 | 'for multiscale testing, e.g. 300 400 500') 175 | parser.add_argument('--imgMaxSize', default=1000, type=int, 176 | help='maximum input image size of long edge') 177 | parser.add_argument('--padding_constant', default=8, type=int, 178 | help='maximum downsampling rate of the network') 179 | parser.add_argument('--segm_downsampling_rate', default=8, type=int, 180 | help='downsampling rate of the segmentation label') 181 | 182 | # Misc arguments 183 | parser.add_argument('--result', default='.', 184 | help='folder to output visualization results') 185 | parser.add_argument('--mask_type', required=True, 186 | help='Type of mask: binary or smooth') 187 | parser.add_argument('--gpu', default=0, type=int, 188 | help='gpu id for evaluation') 189 | 190 | args = parser.parse_args() 191 | 192 | args.arch_encoder = args.arch_encoder.lower() 193 | args.arch_decoder = args.arch_decoder.lower() 194 | print("Input arguments:") 195 | for key, val in vars(args).items(): 196 | print("{:16} {}".format(key, val)) 197 | 198 | # absolute paths of model weights 199 | args.weights_encoder = os.path.join(args.model_path, 200 | 'encoder' + args.suffix) 201 | args.weights_decoder = os.path.join(args.model_path, 202 | 'decoder' + args.suffix) 203 | 204 | assert os.path.exists(args.weights_encoder) and \ 205 | os.path.exists(args.weights_decoder), 'checkpoint does not exist!'
206 | 207 | if not os.path.isdir('{}/'.format(args.result)): 208 | os.makedirs('{}/'.format(args.result)) 209 | if not os.path.isdir('{}/sky/'.format(args.result)): 210 | os.makedirs('{}/sky/'.format(args.result)) 211 | if not os.path.isdir('{}/water/'.format(args.result)): 212 | os.makedirs('{}/water/'.format(args.result)) 213 | if not os.path.isdir('{}/grass/'.format(args.result)): 214 | os.makedirs('{}/grass/'.format(args.result)) 215 | if not os.path.isdir('{}/person/'.format(args.result)): 216 | os.makedirs('{}/person/'.format(args.result)) 217 | 218 | main(args) 219 | -------------------------------------------------------------------------------- /Segmentation/data/color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Segmentation/data/color150.mat -------------------------------------------------------------------------------- /Segmentation/data/object150_info.csv: -------------------------------------------------------------------------------- 1 | Idx,Ratio,Train,Val,Stuff,Name 2 | 1,0.1576,11664,1172,1,wall 3 | 2,0.1072,6046,612,1,building;edifice 4 | 3,0.0878,8265,796,1,sky 5 | 4,0.0621,9336,917,1,floor;flooring 6 | 5,0.0480,6678,641,0,tree 7 | 6,0.0450,6604,643,1,ceiling 8 | 7,0.0398,4023,408,1,road;route 9 | 8,0.0231,1906,199,0,bed 10 | 9,0.0198,4688,460,0,windowpane;window 11 | 10,0.0183,2423,225,1,grass 12 | 11,0.0181,2874,294,0,cabinet 13 | 12,0.0166,3068,310,1,sidewalk;pavement 14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul 15 | 14,0.0151,1804,190,1,earth;ground 16 | 15,0.0118,6666,796,0,door;double;door 17 | 16,0.0110,4269,411,0,table 18 | 17,0.0109,1691,160,1,mountain;mount 19 | 18,0.0104,3999,441,0,plant;flora;plant;life 20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall 21 | 20,0.0103,3261,318,0,chair 22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar 23 | 22,0.0074,709,75,1,water 24 | 23,0.0067,3296,315,0,painting;picture 25 | 24,0.0065,1191,106,0,sofa;couch;lounge 26 | 25,0.0061,1516,162,0,shelf 27 | 26,0.0060,667,69,1,house 28 | 27,0.0053,651,57,1,sea 29 | 28,0.0052,1847,224,0,mirror 30 | 29,0.0046,1158,128,1,rug;carpet;carpeting 31 | 30,0.0044,480,44,1,field 32 | 31,0.0044,1172,98,0,armchair 33 | 32,0.0044,1292,184,0,seat 34 | 33,0.0033,1386,138,0,fence;fencing 35 | 34,0.0031,698,61,0,desk 36 | 35,0.0030,781,73,0,rock;stone 37 | 36,0.0027,380,43,0,wardrobe;closet;press 38 | 37,0.0026,3089,302,0,lamp 39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub 40 | 39,0.0024,804,99,0,railing;rail 41 | 40,0.0023,1453,153,0,cushion 42 | 41,0.0023,411,37,0,base;pedestal;stand 43 | 42,0.0022,1440,162,0,box 44 | 43,0.0022,800,77,0,column;pillar 45 | 44,0.0020,2650,298,0,signboard;sign 46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser 47 | 46,0.0019,367,36,0,counter 48 | 47,0.0018,311,30,1,sand 49 | 48,0.0018,1181,122,0,sink 50 | 49,0.0018,287,23,1,skyscraper 51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace 52 | 51,0.0018,402,43,0,refrigerator;icebox 53 | 52,0.0018,130,12,1,grandstand;covered;stand 54 | 53,0.0018,561,64,1,path 55 | 54,0.0017,880,102,0,stairs;steps 56 | 55,0.0017,86,12,1,runway 57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine 58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table 59 | 58,0.0017,930,109,0,pillow 60 | 59,0.0015,139,18,0,screen;door;screen 61 | 60,0.0015,564,52,1,stairway;staircase 62 | 61,0.0015,320,26,1,river 63 | 
62,0.0015,261,29,1,bridge;span 64 | 63,0.0014,275,22,0,bookcase 65 | 64,0.0014,335,60,0,blind;screen 66 | 65,0.0014,792,75,0,coffee;table;cocktail;table 67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne 68 | 67,0.0014,1309,138,0,flower 69 | 68,0.0013,1112,113,0,book 70 | 69,0.0013,266,27,1,hill 71 | 70,0.0013,659,66,0,bench 72 | 71,0.0012,331,31,0,countertop 73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove 74 | 73,0.0012,369,36,0,palm;palm;tree 75 | 74,0.0012,144,9,0,kitchen;island 76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system 77 | 76,0.0010,324,33,0,swivel;chair 78 | 77,0.0009,304,27,0,boat 79 | 78,0.0009,170,20,0,bar 80 | 79,0.0009,68,6,0,arcade;machine 81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty 82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle 83 | 82,0.0008,492,49,0,towel 84 | 83,0.0008,2510,269,0,light;light;source 85 | 84,0.0008,440,39,0,truck;motortruck 86 | 85,0.0008,147,18,1,tower 87 | 86,0.0008,583,56,0,chandelier;pendant;pendent 88 | 87,0.0007,533,61,0,awning;sunshade;sunblind 89 | 88,0.0007,1989,239,0,streetlight;street;lamp 90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk 91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box 92 | 91,0.0007,135,12,0,airplane;aeroplane;plane 93 | 92,0.0007,83,5,1,dirt;track 94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes 95 | 94,0.0006,1003,104,0,pole 96 | 95,0.0006,182,12,1,land;ground;soil 97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail 98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway 99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock 100 | 99,0.0006,965,114,0,bottle 101 | 100,0.0006,117,13,0,buffet;counter;sideboard 102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card 103 | 102,0.0006,108,9,1,stage 104 | 103,0.0006,557,55,0,van 105 | 104,0.0006,52,4,0,ship 106 | 105,0.0005,99,5,0,fountain 107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter 108 | 107,0.0005,292,31,0,canopy 109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine 110 | 109,0.0005,340,38,0,plaything;toy 111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium 112 | 111,0.0005,465,49,0,stool 113 | 112,0.0005,50,4,0,barrel;cask 114 | 113,0.0005,622,75,0,basket;handbasket 115 | 114,0.0005,80,9,1,waterfall;falls 116 | 115,0.0005,59,3,0,tent;collapsible;shelter 117 | 116,0.0005,531,72,0,bag 118 | 117,0.0005,282,30,0,minibike;motorbike 119 | 118,0.0005,73,7,0,cradle 120 | 119,0.0005,435,44,0,oven 121 | 120,0.0005,136,25,0,ball 122 | 121,0.0005,116,24,0,food;solid;food 123 | 122,0.0004,266,31,0,step;stair 124 | 123,0.0004,58,12,0,tank;storage;tank 125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque 126 | 125,0.0004,319,43,0,microwave;microwave;oven 127 | 126,0.0004,1193,139,0,pot;flowerpot 128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna 129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle 130 | 129,0.0004,52,5,1,lake 131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine 132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen 133 | 132,0.0004,201,30,0,blanket;cover 134 | 133,0.0004,285,21,0,sculpture 135 | 134,0.0004,268,27,0,hood;exhaust;hood 136 | 135,0.0003,1020,108,0,sconce 137 | 136,0.0003,1282,122,0,vase 138 | 
137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight 139 | 138,0.0003,453,57,0,tray 140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin 141 | 140,0.0003,397,44,0,fan 142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock 143 | 142,0.0003,228,18,0,crt;screen 144 | 143,0.0003,570,59,0,plate 145 | 144,0.0003,217,22,0,monitor;monitoring;device 146 | 145,0.0003,206,19,0,bulletin;board;notice;board 147 | 146,0.0003,130,14,0,shower 148 | 147,0.0003,178,28,0,radiator 149 | 148,0.0002,504,57,0,glass;drinking;glass 150 | 149,0.0002,775,96,0,clock 151 | 150,0.0002,421,56,0,flag 152 | -------------------------------------------------------------------------------- /Segmentation/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import lib.utils.data as torchdata 5 | import cv2 6 | from torchvision import transforms 7 | import numpy as np 8 | 9 | 10 | class BaseDataset(torchdata.Dataset): 11 | def __init__(self, odgt, opt, **kwargs): 12 | # parse options 13 | self.imgSize = opt.imgSize 14 | self.imgMaxSize = opt.imgMaxSize 15 | 16 | # max down sampling rate of network to avoid rounding during conv or pooling 17 | self.padding_constant = opt.padding_constant 18 | 19 | # parse the input list 20 | self.parse_input_list(odgt, **kwargs) 21 | 22 | # mean and std 23 | self.normalize = transforms.Normalize( 24 | mean=[102.9801, 115.9465, 122.7717], 25 | std=[1., 1., 1.]) 26 | 27 | def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1): 28 | if isinstance(odgt, list): 29 | self.list_sample = odgt 30 | elif isinstance(odgt, str): 31 | self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] 32 | 33 | if max_sample > 0: 34 | self.list_sample = self.list_sample[0:max_sample] 35 | if start_idx >= 0 and end_idx >= 0: # divide file list 36 | self.list_sample = self.list_sample[start_idx:end_idx] 37 | 38 | self.num_sample = len(self.list_sample) 39 | assert self.num_sample > 0 40 | print('# samples: {}'.format(self.num_sample)) 41 | 42 | def img_transform(self, img): 43 | # image to float 44 | img = img.astype(np.float32) 45 | img = img.transpose((2, 0, 1)) 46 | img = self.normalize(torch.from_numpy(img.copy())) 47 | return img 48 | 49 | # Round x up to the nearest multiple of p (so that x' >= x) 50 | def round2nearest_multiple(self, x, p): 51 | return ((x - 1) // p + 1) * p 52 | 53 | 54 | class TrainDataset(BaseDataset): 55 | def __init__(self, odgt, opt, batch_per_gpu=1, **kwargs): 56 | super(TrainDataset, self).__init__(odgt, opt, **kwargs) 57 | self.root_dataset = opt.root_dataset 58 | self.random_flip = opt.random_flip 59 | # down-sampling rate of the segmentation label 60 | self.segm_downsampling_rate = opt.segm_downsampling_rate 61 | self.batch_per_gpu = batch_per_gpu 62 | 63 | # classify images into two classes: 1. h > w and 2.
h <= w 64 | self.batch_record_list = [[], []] 65 | 66 | # override dataset length when training with batch_per_gpu > 1 67 | self.cur_idx = 0 68 | self.if_shuffled = False 69 | 70 | def _get_sub_batch(self): 71 | while True: 72 | # get a sample record 73 | this_sample = self.list_sample[self.cur_idx] 74 | if this_sample['height'] > this_sample['width']: 75 | self.batch_record_list[0].append(this_sample) # h > w, go to 1st class 76 | else: 77 | self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class 78 | 79 | # update current sample pointer 80 | self.cur_idx += 1 81 | if self.cur_idx >= self.num_sample: 82 | self.cur_idx = 0 83 | np.random.shuffle(self.list_sample) 84 | 85 | if len(self.batch_record_list[0]) == self.batch_per_gpu: 86 | batch_records = self.batch_record_list[0] 87 | self.batch_record_list[0] = [] 88 | break 89 | elif len(self.batch_record_list[1]) == self.batch_per_gpu: 90 | batch_records = self.batch_record_list[1] 91 | self.batch_record_list[1] = [] 92 | break 93 | return batch_records 94 | 95 | def __getitem__(self, index): 96 | # NOTE: random shuffle for the first time. shuffle in __init__ is useless 97 | if not self.if_shuffled: 98 | np.random.shuffle(self.list_sample) 99 | self.if_shuffled = True 100 | 101 | # get sub-batch candidates 102 | batch_records = self._get_sub_batch() 103 | 104 | # resize all images' short edges to the chosen size 105 | if isinstance(self.imgSize, list): 106 | this_short_size = np.random.choice(self.imgSize) 107 | else: 108 | this_short_size = self.imgSize 109 | 110 | # calculate the BATCH's height and width 111 | # since we concat more than one samples, the batch's h and w shall be larger than EACH sample 112 | batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32) 113 | for i in range(self.batch_per_gpu): 114 | img_height, img_width = batch_records[i]['height'], batch_records[i]['width'] 115 | this_scale = min( 116 | this_short_size / min(img_height, img_width), \ 117 | self.imgMaxSize / max(img_height, img_width)) 118 | img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale 119 | batch_resized_size[i, :] = img_resized_height, img_resized_width 120 | batch_resized_height = np.max(batch_resized_size[:, 0]) 121 | batch_resized_width = np.max(batch_resized_size[:, 1]) 122 | 123 | # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w' 124 | batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant)) 125 | batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant)) 126 | 127 | assert self.padding_constant >= self.segm_downsampling_rate,\ 128 | 'padding constant must be equal to or larger than the segm downsampling rate' 129 | batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width) 130 | batch_segms = torch.zeros( 131 | self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \ 132 | batch_resized_width // self.segm_downsampling_rate).long() 133 | 134 | for i in range(self.batch_per_gpu): 135 | this_record = batch_records[i] 136 | 137 | # load image and label 138 | image_path = os.path.join(self.root_dataset, this_record['fpath_img']) 139 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) 140 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 141 | segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE) 142 | 143 | assert(img.ndim == 3) 144 | assert(segm.ndim == 2) 145 | assert(img.shape[0] == segm.shape[0])
146 | assert(img.shape[1] == segm.shape[1]) 147 | 148 | if self.random_flip is True: 149 | random_flip = np.random.choice([0, 1]) 150 | if random_flip == 1: 151 | img = cv2.flip(img, 1) 152 | segm = cv2.flip(segm, 1) 153 | 154 | # note that each sample within a mini-batch has a different scale param 155 | img = cv2.resize(img, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_LINEAR) 156 | segm = cv2.resize(segm, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_NEAREST) 157 | 158 | # to avoid seg label misalignment 159 | segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate) 160 | segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate) 161 | segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8') 162 | segm_rounded[:segm.shape[0], :segm.shape[1]] = segm 163 | 164 | segm = cv2.resize( 165 | segm_rounded, 166 | (segm_rounded.shape[1] // self.segm_downsampling_rate, \ 167 | segm_rounded.shape[0] // self.segm_downsampling_rate), \ 168 | interpolation=cv2.INTER_NEAREST) 169 | 170 | # image transform 171 | img = self.img_transform(img) 172 | 173 | batch_images[i][:, :img.shape[1], :img.shape[2]] = img 174 | batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int64)).long() # np.int64: plain np.int has been removed from NumPy 175 | 176 | batch_segms = batch_segms - 1 # label from -1 to 149 177 | output = dict() 178 | output['img_data'] = batch_images 179 | output['seg_label'] = batch_segms 180 | return output 181 | 182 | def __len__(self): 183 | return int(1e10) # a fake length: every loader maintains its own (shuffled) sample list 184 | # return self.num_sample 185 | 186 | 187 | class ValDataset(BaseDataset): 188 | def __init__(self, odgt, opt, **kwargs): 189 | super(ValDataset, self).__init__(odgt, opt, **kwargs) 190 | self.root_dataset = opt.root_dataset 191 | 192 | def __getitem__(self, index): 193 | this_record = self.list_sample[index] 194 | # load image and label 195 | image_path = os.path.join(self.root_dataset, this_record['fpath_img']) 196 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) 197 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 198 | segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE) 199 | 200 | ori_height, ori_width, _ = img.shape 201 | 202 | img_resized_list = [] 203 | for this_short_size in self.imgSize: 204 | # calculate target height and width 205 | scale = min(this_short_size / float(min(ori_height, ori_width)), 206 | self.imgMaxSize / float(max(ori_height, ori_width))) 207 | target_height, target_width = int(ori_height * scale), int(ori_width * scale) 208 | 209 | # to avoid rounding in network 210 | target_height = self.round2nearest_multiple(target_height, self.padding_constant) 211 | target_width = self.round2nearest_multiple(target_width, self.padding_constant) 212 | 213 | # resize 214 | img_resized = cv2.resize(img.copy(), (target_width, target_height)) 215 | 216 | # image transform 217 | img_resized = self.img_transform(img_resized) 218 | 219 | img_resized = torch.unsqueeze(img_resized, 0) 220 | img_resized_list.append(img_resized) 221 | 222 | segm = torch.from_numpy(segm.astype(np.int64)).long() 223 | batch_segms = torch.unsqueeze(segm, 0) 224 | 225 | batch_segms = batch_segms - 1 # label from -1 to 149 226 | output = dict() 227 | output['img_ori'] = img.copy() 228 | output['img_data'] = [x.contiguous() for x in img_resized_list] 229 | output['seg_label'] = batch_segms.contiguous() 230 |
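# Editor's note (illustrative, values hypothetical): 'img_data' holds one
# normalized tensor per entry of self.imgSize, so the validation loop can run
# multi-scale inference, and 'seg_label' uses -1 for unannotated pixels after
# the shift above. As a worked example of the resize-and-round step: for a
# 300x500 image with this_short_size=450, imgMaxSize=1000, padding_constant=8,
# scale = min(450/300, 1000/500) = 1.5, giving 450x750, which becomes 456x752
# assuming round2nearest_multiple rounds up to the nearest multiple of 8.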
output['info'] = this_record['fpath_img'] 231 | return output 232 | 233 | def __len__(self): 234 | return self.num_sample 235 | 236 | 237 | class TestDataset(BaseDataset): 238 | def __init__(self, odgt, opt, **kwargs): 239 | super(TestDataset, self).__init__(odgt, opt, **kwargs) 240 | 241 | def __getitem__(self, index): 242 | this_record = self.list_sample[index] 243 | # load image and label 244 | image_path = this_record['fpath_img'] 245 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 246 | 247 | ori_height, ori_width, _ = img.shape 248 | 249 | img_resized_list = [] 250 | for this_short_size in self.imgSize: 251 | # calculate target height and width 252 | scale = min(this_short_size / float(min(ori_height, ori_width)), 253 | self.imgMaxSize / float(max(ori_height, ori_width))) 254 | target_height, target_width = int(ori_height * scale), int(ori_width * scale) 255 | 256 | # to avoid rounding in network 257 | target_height = self.round2nearest_multiple(target_height, self.padding_constant) 258 | target_width = self.round2nearest_multiple(target_width, self.padding_constant) 259 | 260 | # resize 261 | img_resized = cv2.resize(img.copy(), (target_width, target_height)) 262 | 263 | # image transform 264 | img_resized = self.img_transform(img_resized) 265 | img_resized = torch.unsqueeze(img_resized, 0) 266 | img_resized_list.append(img_resized) 267 | 268 | output = dict() 269 | output['img_ori'] = img.copy() 270 | output['img_data'] = [x.contiguous() for x in img_resized_list] 271 | output['info'] = this_record['fpath_img'] 272 | return output 273 | 274 | def __len__(self): 275 | return self.num_sample 276 | -------------------------------------------------------------------------------- /Segmentation/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Segmentation/lib/__init__.py -------------------------------------------------------------------------------- /Segmentation/lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
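#
# Editor's usage sketch (not part of the original file; assumes the
# Segmentation/lib package is importable and at least two CUDA devices exist):
#
#   from lib.nn import SynchronizedBatchNorm2d, DataParallelWithCallback
#   net = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), SynchronizedBatchNorm2d(16))
#   net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
#   out = net(torch.randn(8, 3, 32, 32).cuda())  # BN statistics reduced over both GPUs
#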
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimensions""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | # customized batch norm statistics 49 | self._moving_average_fraction = 1. - momentum 50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) 51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) 52 | self.register_buffer('_running_iter', torch.ones(1)) 53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter 54 | self._tmp_running_var = self.running_var.clone() * self._running_iter 55 | 56 | def forward(self, input): 57 | # If this is not a parallel computation, or we are in evaluation mode, use PyTorch's implementation. 58 | if not (self._is_parallel and self.training): 59 | return F.batch_norm( 60 | input, self.running_mean, self.running_var, self.weight, self.bias, 61 | self.training, self.momentum, self.eps) 62 | 63 | # Resize the input to (B, C, -1). 64 | input_shape = input.size() 65 | input = input.view(input.size(0), self.num_features, -1) 66 | 67 | # Compute the sum and square-sum. 68 | sum_size = input.size(0) * input.size(2) 69 | input_sum = _sum_ft(input) 70 | input_ssum = _sum_ft(input ** 2) 71 | 72 | # Reduce-and-broadcast the statistics. 73 | if self._parallel_id == 0: 74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | else: 76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 77 | 78 | # Compute the output. 79 | if self.affine: 80 | # MJY:: Fuse the multiplication for speed. 81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 82 | else: 83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 84 | 85 | # Reshape it. 86 | return output.view(input_shape) 87 | 88 | def __data_parallel_replicate__(self, ctx, copy_id): 89 | self._is_parallel = True 90 | self._parallel_id = copy_id 91 | 92 | # parallel_id == 0 means master device.
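# Editor's note: copy 0 (the master) keeps the SyncMaster built in __init__;
# every other replica registers below and receives a SlavePipe. In forward,
# run_master()/run_slave() then rendezvous once per batch to exchange the
# per-device (sum, ssum, sum_size) messages.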
93 | if self._parallel_id == 0: 94 | ctx.sync_master = self._sync_master 95 | else: 96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 97 | 98 | def _data_parallel_master(self, intermediates): 99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 101 | 102 | to_reduce = [i[1][:2] for i in intermediates] 103 | to_reduce = [j for i in to_reduce for j in i] # flatten 104 | target_gpus = [i[1].sum.get_device() for i in intermediates] 105 | 106 | sum_size = sum([i[1].sum_size for i in intermediates]) 107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 108 | 109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 110 | 111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 112 | 113 | outputs = [] 114 | for i, rec in enumerate(intermediates): 115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 116 | 117 | return outputs 118 | 119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): 120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`""" 121 | return dest * alpha + delta * beta + bias 122 | 123 | def _compute_mean_std(self, sum_, ssum, size): 124 | """Compute the mean and standard-deviation with sum and square-sum. This method 125 | also maintains the moving average on the master device.""" 126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 127 | mean = sum_ / size 128 | sumvar = ssum - sum_ * mean 129 | unbias_var = sumvar / (size - 1) 130 | bias_var = sumvar / size 131 | 132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) 133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) 134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) 135 | 136 | self.running_mean = self._tmp_running_mean / self._running_iter 137 | self.running_var = self._tmp_running_var / self._running_iter 138 | 139 | return mean, bias_var.clamp(self.eps) ** -0.5 140 | 141 | 142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 144 | mini-batch. 145 | 146 | .. math:: 147 | 148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 149 | 150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 151 | standard-deviation are reduced across all devices during training. 152 | 153 | For example, when one uses `nn.DataParallel` to wrap the network during 154 | training, PyTorch's implementation normalize the tensor on each device using 155 | the statistics only on that device, which accelerated the computation and 156 | is also easy to implement, but the statistics might be inaccurate. 157 | Instead, in this synchronized version, the statistics will be computed 158 | over all training samples distributed on multiple devices. 159 | 160 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 161 | as the built-in PyTorch implementation. 162 | 163 | The mean and standard-deviation are calculated per-dimension over 164 | the mini-batches and gamma and beta are learnable parameter vectors 165 | of size C (where C is the input size). 166 | 167 | During training, this layer keeps a running estimate of its computed mean 168 | and variance. 
The running sum is kept with a default momentum of 0.1. 169 | 170 | During evaluation, this running mean/variance is used for normalization. 171 | 172 | Because the BatchNorm is done over the `C` dimension, computing statistics 173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 174 | 175 | Args: 176 | num_features: num_features from an expected input of size 177 | `batch_size x num_features [x width]` 178 | eps: a value added to the denominator for numerical stability. 179 | Default: 1e-5 180 | momentum: the value used for the running_mean and running_var 181 | computation. Default: 0.1 182 | affine: a boolean value that when set to ``True``, gives the layer learnable 183 | affine parameters. Default: ``True`` 184 | 185 | Shape: 186 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 188 | 189 | Examples: 190 | >>> # With Learnable Parameters 191 | >>> m = SynchronizedBatchNorm1d(100) 192 | >>> # Without Learnable Parameters 193 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 194 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 195 | >>> output = m(input) 196 | """ 197 | 198 | def _check_input_dim(self, input): 199 | if input.dim() != 2 and input.dim() != 3: 200 | raise ValueError('expected 2D or 3D input (got {}D input)' 201 | .format(input.dim())) 202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 203 | 204 | 205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 207 | of 3d inputs 208 | 209 | .. math:: 210 | 211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 212 | 213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 214 | standard-deviation are reduced across all devices during training. 215 | 216 | For example, when one uses `nn.DataParallel` to wrap the network during 217 | training, PyTorch's implementation normalize the tensor on each device using 218 | the statistics only on that device, which accelerated the computation and 219 | is also easy to implement, but the statistics might be inaccurate. 220 | Instead, in this synchronized version, the statistics will be computed 221 | over all training samples distributed on multiple devices. 222 | 223 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 224 | as the built-in PyTorch implementation. 225 | 226 | The mean and standard-deviation are calculated per-dimension over 227 | the mini-batches and gamma and beta are learnable parameter vectors 228 | of size C (where C is the input size). 229 | 230 | During training, this layer keeps a running estimate of its computed mean 231 | and variance. The running sum is kept with a default momentum of 0.1. 232 | 233 | During evaluation, this running mean/variance is used for normalization. 234 | 235 | Because the BatchNorm is done over the `C` dimension, computing statistics 236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 237 | 238 | Args: 239 | num_features: num_features from an expected input of 240 | size batch_size x num_features x height x width 241 | eps: a value added to the denominator for numerical stability. 242 | Default: 1e-5 243 | momentum: the value used for the running_mean and running_var 244 | computation. Default: 0.1 245 | affine: a boolean value that when set to ``True``, gives the layer learnable 246 | affine parameters. 
Default: ``True`` 247 | 248 | Shape: 249 | - Input: :math:`(N, C, H, W)` 250 | - Output: :math:`(N, C, H, W)` (same shape as input) 251 | 252 | Examples: 253 | >>> # With Learnable Parameters 254 | >>> m = SynchronizedBatchNorm2d(100) 255 | >>> # Without Learnable Parameters 256 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 258 | >>> output = m(input) 259 | """ 260 | 261 | def _check_input_dim(self, input): 262 | if input.dim() != 4: 263 | raise ValueError('expected 4D input (got {}D input)' 264 | .format(input.dim())) 265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 266 | 267 | 268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 270 | of 4d inputs 271 | 272 | .. math:: 273 | 274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 275 | 276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 277 | standard-deviation are reduced across all devices during training. 278 | 279 | For example, when one uses `nn.DataParallel` to wrap the network during 280 | training, PyTorch's implementation normalize the tensor on each device using 281 | the statistics only on that device, which accelerated the computation and 282 | is also easy to implement, but the statistics might be inaccurate. 283 | Instead, in this synchronized version, the statistics will be computed 284 | over all training samples distributed on multiple devices. 285 | 286 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 287 | as the built-in PyTorch implementation. 288 | 289 | The mean and standard-deviation are calculated per-dimension over 290 | the mini-batches and gamma and beta are learnable parameter vectors 291 | of size C (where C is the input size). 292 | 293 | During training, this layer keeps a running estimate of its computed mean 294 | and variance. The running sum is kept with a default momentum of 0.1. 295 | 296 | During evaluation, this running mean/variance is used for normalization. 297 | 298 | Because the BatchNorm is done over the `C` dimension, computing statistics 299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 300 | or Spatio-temporal BatchNorm 301 | 302 | Args: 303 | num_features: num_features from an expected input of 304 | size batch_size x num_features x depth x height x width 305 | eps: a value added to the denominator for numerical stability. 306 | Default: 1e-5 307 | momentum: the value used for the running_mean and running_var 308 | computation. Default: 0.1 309 | affine: a boolean value that when set to ``True``, gives the layer learnable 310 | affine parameters. 
Default: ``True`` 311 | 312 | Shape: 313 | - Input: :math:`(N, C, D, H, W)` 314 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 315 | 316 | Examples: 317 | >>> # With Learnable Parameters 318 | >>> m = SynchronizedBatchNorm3d(100) 319 | >>> # Without Learnable Parameters 320 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 322 | >>> output = m(input) 323 | """ 324 | 325 | def _check_input_dim(self, input): 326 | if input.dim() != 5: 327 | raise ValueError('expected 5D input (got {}D input)' 328 | .format(input.dim())) 329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 330 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During replication, as data parallel triggers a callback on each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device gathers the information and determines the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register a slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually the device id.
84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | Messages are first collected from each device (including the master device), and then 100 | a callback is invoked to compute the message to be sent back to each device 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belong to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all module copies are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called before the callback 38 | of any slave copies.
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
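#
# Editor's note: unlike the numeric test above, the Sync* cases below wrap
# the layer in DataParallelWithCallback(..., device_ids=[0, 1]), so they need
# at least two visible CUDA devices to run.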
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /Segmentation/lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch._C import _set_worker_signal_handlers, _set_worker_pids, \ 4 | _remove_worker_pids, _error_if_any_worker_fails 5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler 6 | import signal 7 | 
import functools 8 | import collections 9 | import re 10 | import sys 11 | import threading 12 | import traceback 13 | from torch._six import string_classes, int_classes 14 | import numpy as np 15 | 16 | if sys.version_info[0] == 2: 17 | import Queue as queue 18 | else: 19 | import queue 20 | 21 | 22 | class ExceptionWrapper(object): 23 | r"Wraps an exception plus traceback to communicate across threads" 24 | 25 | def __init__(self, exc_info): 26 | self.exc_type = exc_info[0] 27 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 28 | 29 | 30 | _use_shared_memory = False 31 | """Whether to use shared memory in default_collate""" 32 | 33 | 34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): 35 | global _use_shared_memory 36 | _use_shared_memory = True 37 | 38 | # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal 39 | # module's handlers are executed after Python returns from C low-level 40 | # handlers, likely when the same fatal signal happened again already. 41 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 42 | _set_worker_signal_handlers() 43 | 44 | torch.set_num_threads(1) 45 | torch.manual_seed(seed) 46 | np.random.seed(seed) 47 | 48 | if init_fn is not None: 49 | init_fn(worker_id) 50 | 51 | while True: 52 | r = index_queue.get() 53 | if r is None: 54 | break 55 | idx, batch_indices = r 56 | try: 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | except Exception: 59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 60 | else: 61 | data_queue.put((idx, samples)) 62 | 63 | 64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): 65 | if pin_memory: 66 | torch.cuda.set_device(device_id) 67 | 68 | while True: 69 | try: 70 | r = in_queue.get() 71 | except Exception: 72 | if done_event.is_set(): 73 | return 74 | raise 75 | if r is None: 76 | break 77 | if isinstance(r[1], ExceptionWrapper): 78 | out_queue.put(r) 79 | continue 80 | idx, batch = r 81 | try: 82 | if pin_memory: 83 | batch = pin_memory_batch(batch) 84 | except Exception: 85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 86 | else: 87 | out_queue.put((idx, batch)) 88 | 89 | numpy_type_map = { 90 | 'float64': torch.DoubleTensor, 91 | 'float32': torch.FloatTensor, 92 | 'float16': torch.HalfTensor, 93 | 'int64': torch.LongTensor, 94 | 'int32': torch.IntTensor, 95 | 'int16': torch.ShortTensor, 96 | 'int8': torch.CharTensor, 97 | 'uint8': torch.ByteTensor, 98 | } 99 | 100 | 101 | def default_collate(batch): 102 | "Puts each data field into a tensor with outer dimension batch size" 103 | 104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 105 | elem_type = type(batch[0]) 106 | if torch.is_tensor(batch[0]): 107 | out = None 108 | if _use_shared_memory: 109 | # If we're in a background process, concatenate directly into a 110 | # shared memory tensor to avoid an extra copy 111 | numel = sum([x.numel() for x in batch]) 112 | storage = batch[0].storage()._new_shared(numel) 113 | out = batch[0].new(storage) 114 | return torch.stack(batch, 0, out=out) 115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 116 | and elem_type.__name__ != 'string_': 117 | elem = batch[0] 118 | if elem_type.__name__ == 'ndarray': 119 | # array of string classes and object 120 | if re.search('[SaUO]', elem.dtype.str) is not None: 121 | raise TypeError(error_msg.format(elem.dtype)) 122 | 123 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
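# (Editor's illustration, hypothetical inputs: default_collate applied to
# [np.ones((2, 2)), np.zeros((2, 2))] returns a 2x2x2 DoubleTensor via the
# stack above, since from_numpy preserves float64; zero-dimensional numpy
# scalars instead fall through to the branch below and are converted via
# numpy_type_map.)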
124 | if elem.shape == (): # scalars 125 | py_type = float if elem.dtype.name.startswith('float') else int 126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 127 | elif isinstance(batch[0], int_classes): 128 | return torch.LongTensor(batch) 129 | elif isinstance(batch[0], float): 130 | return torch.DoubleTensor(batch) 131 | elif isinstance(batch[0], string_classes): 132 | return batch 133 | elif isinstance(batch[0], collections.Mapping): 134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 135 | elif isinstance(batch[0], collections.Sequence): 136 | transposed = zip(*batch) 137 | return [default_collate(samples) for samples in transposed] 138 | 139 | raise TypeError((error_msg.format(type(batch[0])))) 140 | 141 | 142 | def pin_memory_batch(batch): 143 | if torch.is_tensor(batch): 144 | return batch.pin_memory() 145 | elif isinstance(batch, string_classes): 146 | return batch 147 | elif isinstance(batch, collections.Mapping): 148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 149 | elif isinstance(batch, collections.Sequence): 150 | return [pin_memory_batch(sample) for sample in batch] 151 | else: 152 | return batch 153 | 154 | 155 | _SIGCHLD_handler_set = False 156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one 157 | handler needs to be set for all DataLoaders in a process.""" 158 | 159 | 160 | def _set_SIGCHLD_handler(): 161 | # Windows doesn't support SIGCHLD handler 162 | if sys.platform == 'win32': 163 | return 164 | # can't set signal in child threads 165 | if not isinstance(threading.current_thread(), threading._MainThread): 166 | return 167 | global _SIGCHLD_handler_set 168 | if _SIGCHLD_handler_set: 169 | return 170 | previous_handler = signal.getsignal(signal.SIGCHLD) 171 | if not callable(previous_handler): 172 | previous_handler = None 173 | 174 | def handler(signum, frame): 175 | # This following call uses `waitid` with WNOHANG from C side. Therefore, 176 | # Python can still get and update the process status successfully. 
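# Editor's note: the call below asks PyTorch's C-side bookkeeping to raise in
# the main process if any registered worker died, and the handler then chains
# to whatever SIGCHLD handler was installed before it, so other libraries'
# child-reaping logic keeps working.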
177 | _error_if_any_worker_fails() 178 | if previous_handler is not None: 179 | previous_handler(signum, frame) 180 | 181 | signal.signal(signal.SIGCHLD, handler) 182 | _SIGCHLD_handler_set = True 183 | 184 | 185 | class DataLoaderIter(object): 186 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 187 | 188 | def __init__(self, loader): 189 | self.dataset = loader.dataset 190 | self.collate_fn = loader.collate_fn 191 | self.batch_sampler = loader.batch_sampler 192 | self.num_workers = loader.num_workers 193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 194 | self.timeout = loader.timeout 195 | self.done_event = threading.Event() 196 | 197 | self.sample_iter = iter(self.batch_sampler) 198 | 199 | if self.num_workers > 0: 200 | self.worker_init_fn = loader.worker_init_fn 201 | self.index_queue = multiprocessing.SimpleQueue() 202 | self.worker_result_queue = multiprocessing.SimpleQueue() 203 | self.batches_outstanding = 0 204 | self.worker_pids_set = False 205 | self.shutdown = False 206 | self.send_idx = 0 207 | self.rcvd_idx = 0 208 | self.reorder_dict = {} 209 | 210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] 211 | self.workers = [ 212 | multiprocessing.Process( 213 | target=_worker_loop, 214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, 215 | base_seed + i, self.worker_init_fn, i)) 216 | for i in range(self.num_workers)] 217 | 218 | if self.pin_memory or self.timeout > 0: 219 | self.data_queue = queue.Queue() 220 | if self.pin_memory: 221 | maybe_device_id = torch.cuda.current_device() 222 | else: 223 | # do not initialize cuda context if not necessary 224 | maybe_device_id = None 225 | self.worker_manager_thread = threading.Thread( 226 | target=_worker_manager_loop, 227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 228 | maybe_device_id)) 229 | self.worker_manager_thread.daemon = True 230 | self.worker_manager_thread.start() 231 | else: 232 | self.data_queue = self.worker_result_queue 233 | 234 | for w in self.workers: 235 | w.daemon = True # ensure that the worker exits on process exit 236 | w.start() 237 | 238 | _set_worker_pids(id(self), tuple(w.pid for w in self.workers)) 239 | _set_SIGCHLD_handler() 240 | self.worker_pids_set = True 241 | 242 | # prime the prefetch loop 243 | for _ in range(2 * self.num_workers): 244 | self._put_indices() 245 | 246 | def __len__(self): 247 | return len(self.batch_sampler) 248 | 249 | def _get_batch(self): 250 | if self.timeout > 0: 251 | try: 252 | return self.data_queue.get(timeout=self.timeout) 253 | except queue.Empty: 254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) 255 | else: 256 | return self.data_queue.get() 257 | 258 | def __next__(self): 259 | if self.num_workers == 0: # same-process loading 260 | indices = next(self.sample_iter) # may raise StopIteration 261 | batch = self.collate_fn([self.dataset[i] for i in indices]) 262 | if self.pin_memory: 263 | batch = pin_memory_batch(batch) 264 | return batch 265 | 266 | # check if the next sample has already been generated 267 | if self.rcvd_idx in self.reorder_dict: 268 | batch = self.reorder_dict.pop(self.rcvd_idx) 269 | return self._process_next_batch(batch) 270 | 271 | if self.batches_outstanding == 0: 272 | self._shutdown_workers() 273 | raise StopIteration 274 | 275 | while True: 276 | assert (not self.shutdown and self.batches_outstanding > 0) 277 | idx, batch = self._get_batch() 278 | self.batches_outstanding -= 1 
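# Editor's note: worker processes may finish out of order; a batch whose
# index is ahead of rcvd_idx is parked in reorder_dict below and yielded
# later, so batches come back in the sampler's original order.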
279 | if idx != self.rcvd_idx: 280 | # store out-of-order samples 281 | self.reorder_dict[idx] = batch 282 | continue 283 | return self._process_next_batch(batch) 284 | 285 | next = __next__ # Python 2 compatibility 286 | 287 | def __iter__(self): 288 | return self 289 | 290 | def _put_indices(self): 291 | assert self.batches_outstanding < 2 * self.num_workers 292 | indices = next(self.sample_iter, None) 293 | if indices is None: 294 | return 295 | self.index_queue.put((self.send_idx, indices)) 296 | self.batches_outstanding += 1 297 | self.send_idx += 1 298 | 299 | def _process_next_batch(self, batch): 300 | self.rcvd_idx += 1 301 | self._put_indices() 302 | if isinstance(batch, ExceptionWrapper): 303 | raise batch.exc_type(batch.exc_msg) 304 | return batch 305 | 306 | def __getstate__(self): 307 | # TODO: add limited pickling support for sharing an iterator 308 | # across multiple threads for HOGWILD. 309 | # Probably the best way to do this is by moving the sample pushing 310 | # to a separate thread and then just sharing the data queue 311 | # but signalling the end is tricky without a non-blocking API 312 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 313 | 314 | def _shutdown_workers(self): 315 | try: 316 | if not self.shutdown: 317 | self.shutdown = True 318 | self.done_event.set() 319 | # if worker_manager_thread is waiting to put 320 | while not self.data_queue.empty(): 321 | self.data_queue.get() 322 | for _ in self.workers: 323 | self.index_queue.put(None) 324 | # done_event should be sufficient to exit worker_manager_thread, 325 | # but be safe here and put another None 326 | self.worker_result_queue.put(None) 327 | finally: 328 | # removes pids no matter what 329 | if self.worker_pids_set: 330 | _remove_worker_pids(id(self)) 331 | self.worker_pids_set = False 332 | 333 | def __del__(self): 334 | if self.num_workers > 0: 335 | self._shutdown_workers() 336 | 337 | 338 | class DataLoader(object): 339 | """ 340 | Data loader. Combines a dataset and a sampler, and provides 341 | single- or multi-process iterators over the dataset. 342 | 343 | Arguments: 344 | dataset (Dataset): dataset from which to load the data. 345 | batch_size (int, optional): how many samples per batch to load 346 | (default: 1). 347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 348 | at every epoch (default: False). 349 | sampler (Sampler, optional): defines the strategy to draw samples from 350 | the dataset. If specified, ``shuffle`` must be False. 351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 352 | indices at a time. Mutually exclusive with batch_size, shuffle, 353 | sampler, and drop_last. 354 | num_workers (int, optional): how many subprocesses to use for data 355 | loading. 0 means that the data will be loaded in the main process. 356 | (default: 0) 357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 359 | into CUDA pinned memory before returning them. 360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 361 | if the dataset size is not divisible by the batch size. If ``False`` and 362 | the size of dataset is not divisible by the batch size, then the last batch 363 | will be smaller. (default: False) 364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch 365 | from workers. Should always be non-negative. 
(default: 0) 366 | worker_init_fn (callable, optional): If not None, this will be called on each 367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 368 | input, after seeding and before data loading. (default: None) 369 | 370 | .. note:: By default, each worker will have its PyTorch seed set to 371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated 372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access 373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds 374 | (e.g. NumPy) before data loading. 375 | 376 | .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an 377 | unpicklable object, e.g., a lambda function. 378 | """ 379 | 380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 382 | timeout=0, worker_init_fn=None): 383 | self.dataset = dataset 384 | self.batch_size = batch_size 385 | self.num_workers = num_workers 386 | self.collate_fn = collate_fn 387 | self.pin_memory = pin_memory 388 | self.drop_last = drop_last 389 | self.timeout = timeout 390 | self.worker_init_fn = worker_init_fn 391 | 392 | if timeout < 0: 393 | raise ValueError('timeout option should be non-negative') 394 | 395 | if batch_sampler is not None: 396 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 397 | raise ValueError('batch_sampler is mutually exclusive with ' 398 | 'batch_size, shuffle, sampler, and drop_last') 399 | 400 | if sampler is not None and shuffle: 401 | raise ValueError('sampler is mutually exclusive with shuffle') 402 | 403 | if self.num_workers < 0: 404 | raise ValueError('num_workers cannot be negative; ' 405 | 'use num_workers=0 to disable multiprocessing.') 406 | 407 | if batch_sampler is None: 408 | if sampler is None: 409 | if shuffle: 410 | sampler = RandomSampler(dataset) 411 | else: 412 | sampler = SequentialSampler(dataset) 413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 414 | 415 | self.sampler = sampler 416 | self.batch_sampler = batch_sampler 417 | 418 | def __iter__(self): 419 | return DataLoaderIter(self) 420 | 421 | def __len__(self): 422 | return len(self.batch_sampler) 423 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 
35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 
22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterator. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list): a list of weights, not necessarily summing to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /Segmentation/lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /Segmentation/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /Segmentation/models/mobilenet.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import os 7 | import sys 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | from lib.nn import SynchronizedBatchNorm2d 12 | 13 | try: 14 | from urllib import urlretrieve 15 | except ImportError: 16 | from urllib.request import urlretrieve 17 | 18 | 19 | __all__ = ['mobilenetv2'] 20 | 21 | 22 | model_urls = { 23 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 24 | } 25 | 26 | 27 | def conv_bn(inp, oup, stride): 28 | return nn.Sequential( 29 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 30 | SynchronizedBatchNorm2d(oup), 31 | nn.ReLU6(inplace=True) 32 | ) 33 | 34 | 35 | def conv_1x1_bn(inp, oup): 36 | return nn.Sequential( 37 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 38 | SynchronizedBatchNorm2d(oup), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | 43 | class InvertedResidual(nn.Module): 44 | def __init__(self, inp, oup, stride, expand_ratio): 45 | super(InvertedResidual, self).__init__() 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | hidden_dim = round(inp * expand_ratio) 50 | self.use_res_connect = self.stride == 1 and inp == oup 51 | 52 | if expand_ratio == 1: 53 | self.conv = nn.Sequential( 54 | # dw 55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 56 | SynchronizedBatchNorm2d(hidden_dim), 57 | nn.ReLU6(inplace=True), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 60 | SynchronizedBatchNorm2d(oup), 61 | ) 62 | else: 63 | self.conv = nn.Sequential( 64 | # pw 65 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 66 | SynchronizedBatchNorm2d(hidden_dim), 67 | nn.ReLU6(inplace=True), 68 | # dw 69 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 70 | SynchronizedBatchNorm2d(hidden_dim), 71 | nn.ReLU6(inplace=True), 72 | # pw-linear 73 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 74 | SynchronizedBatchNorm2d(oup), 75 | ) 76 | 77 | def forward(self, x): 78 | if self.use_res_connect: 79 | return x + self.conv(x) 80 | else: 81 | return self.conv(x) 82 | 83 | 84 | class MobileNetV2(nn.Module): 85 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 86 | super(MobileNetV2, self).__init__() 87 | block = InvertedResidual 88 | input_channel = 32 89 | last_channel = 1280 90 | interverted_residual_setting = [ 91 | # t, c, n, s 92 | [1, 16, 1, 1], 93 | [6, 24, 2, 2], 94 | [6, 32, 3, 2], 95 | [6, 64, 4, 2], 96 | [6, 96, 3, 1], 97 | [6, 160, 3, 2], 98 | [6, 320, 1, 1], 99 | ] 100 | 101 | # building first layer 102 | assert input_size % 32 == 0 103 | input_channel = int(input_channel * width_mult) 104 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 105 | self.features = [conv_bn(3, input_channel, 2)] 106 | # building inverted residual blocks 107 | for t, c, n, s in interverted_residual_setting: 108 | output_channel = int(c * width_mult) 109 | for i in range(n): 110 | if i == 0: 111 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 112 | else: 113 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 114 | input_channel = output_channel 115 | # building last several layers 116 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 117 | # make it nn.Sequential 118 | 
self.features = nn.Sequential(*self.features) 119 | 120 | # building classifier 121 | self.classifier = nn.Sequential( 122 | nn.Dropout(0.2), 123 | nn.Linear(self.last_channel, n_class), 124 | ) 125 | 126 | self._initialize_weights() 127 | 128 | def forward(self, x): 129 | x = self.features(x) 130 | x = x.mean(3).mean(2) 131 | x = self.classifier(x) 132 | return x 133 | 134 | def _initialize_weights(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | if m.bias is not None: 140 | m.bias.data.zero_() 141 | elif isinstance(m, SynchronizedBatchNorm2d): 142 | m.weight.data.fill_(1) 143 | m.bias.data.zero_() 144 | elif isinstance(m, nn.Linear): 145 | n = m.weight.size(1) 146 | m.weight.data.normal_(0, 0.01) 147 | m.bias.data.zero_() 148 | 149 | 150 | def mobilenetv2(pretrained=False, **kwargs): 151 | """Constructs a MobileNet_V2 model. 152 | 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | """ 156 | model = MobileNetV2(n_class=1000, **kwargs) 157 | if pretrained: 158 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 159 | return model 160 | 161 | 162 | def load_url(url, model_dir='./pretrained', map_location=None): 163 | if not os.path.exists(model_dir): 164 | os.makedirs(model_dir) 165 | filename = url.split('/')[-1] 166 | cached_file = os.path.join(model_dir, filename) 167 | if not os.path.exists(cached_file): 168 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 169 | urlretrieve(url, cached_file) 170 | return torch.load(cached_file, map_location=map_location) 171 | 172 | -------------------------------------------------------------------------------- /Segmentation/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from . 
import resnet, resnext, mobilenet 5 | from lib.nn import SynchronizedBatchNorm2d 6 | 7 | 8 | class SegmentationModuleBase(nn.Module): 9 | def __init__(self): 10 | super(SegmentationModuleBase, self).__init__() 11 | 12 | def pixel_acc(self, pred, label): 13 | _, preds = torch.max(pred, dim=1) 14 | valid = (label >= 0).long() 15 | acc_sum = torch.sum(valid * (preds == label).long()) 16 | pixel_sum = torch.sum(valid) 17 | acc = acc_sum.float() / (pixel_sum.float() + 1e-10) 18 | return acc 19 | 20 | 21 | class SegmentationModule(SegmentationModuleBase): 22 | def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): 23 | super(SegmentationModule, self).__init__() 24 | self.encoder = net_enc 25 | self.decoder = net_dec 26 | self.crit = crit 27 | self.deep_sup_scale = deep_sup_scale 28 | 29 | def forward(self, feed_dict, segSize=None): 30 | # training 31 | if segSize is None: 32 | if self.deep_sup_scale is not None: # use deep supervision technique 33 | (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) 34 | else: 35 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) 36 | 37 | loss = self.crit(pred, feed_dict['seg_label']) 38 | if self.deep_sup_scale is not None: 39 | loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) 40 | loss = loss + loss_deepsup * self.deep_sup_scale 41 | 42 | acc = self.pixel_acc(pred, feed_dict['seg_label']) 43 | return loss, acc 44 | # inference 45 | else: 46 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) 47 | return pred 48 | 49 | 50 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False): 51 | "3x3 convolution with padding" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 53 | padding=1, bias=has_bias) 54 | 55 | 56 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 57 | return nn.Sequential( 58 | conv3x3(in_planes, out_planes, stride), 59 | SynchronizedBatchNorm2d(out_planes), 60 | nn.ReLU(inplace=True), 61 | ) 62 | 63 | 64 | class ModelBuilder(): 65 | # custom weights initialization 66 | def weights_init(self, m): 67 | classname = m.__class__.__name__ 68 | if classname.find('Conv') != -1: 69 | nn.init.kaiming_normal_(m.weight.data) 70 | elif classname.find('BatchNorm') != -1: 71 | m.weight.data.fill_(1.) 
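# note: BatchNorm layers get unit scale here and, on the next line, a small positive bias (1e-4)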
72 | m.bias.data.fill_(1e-4) 73 | #elif classname.find('Linear') != -1: 74 | # m.weight.data.normal_(0.0, 0.0001) 75 | 76 | def build_encoder(self, arch='resnet50dilated', fc_dim=512, weights=''): 77 | pretrained = True if len(weights) == 0 else False 78 | arch = arch.lower() 79 | if arch == 'mobilenetv2dilated': 80 | orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained) 81 | net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) 82 | elif arch == 'resnet18': 83 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 84 | net_encoder = Resnet(orig_resnet) 85 | elif arch == 'resnet18dilated': 86 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 87 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 88 | elif arch == 'resnet34': 89 | raise NotImplementedError 90 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 91 | net_encoder = Resnet(orig_resnet) 92 | elif arch == 'resnet34dilated': 93 | raise NotImplementedError 94 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 95 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 96 | elif arch == 'resnet50': 97 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 98 | net_encoder = Resnet(orig_resnet) 99 | elif arch == 'resnet50dilated': 100 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 101 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 102 | elif arch == 'resnet101': 103 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 104 | net_encoder = Resnet(orig_resnet) 105 | elif arch == 'resnet101dilated': 106 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 107 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 108 | elif arch == 'resnext101': 109 | orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) 110 | net_encoder = Resnet(orig_resnext) # we can still use class Resnet 111 | else: 112 | raise Exception('Architecture undefined!') 113 | 114 | # net_encoder.apply(self.weights_init) 115 | if len(weights) > 0: 116 | print('Loading weights for net_encoder') 117 | net_encoder.load_state_dict( 118 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False) 119 | return net_encoder 120 | 121 | def build_decoder(self, arch='ppm_deepsup', 122 | fc_dim=512, num_class=150, 123 | weights='', use_softmax=False): 124 | arch = arch.lower() 125 | if arch == 'c1_deepsup': 126 | net_decoder = C1DeepSup( 127 | num_class=num_class, 128 | fc_dim=fc_dim, 129 | use_softmax=use_softmax) 130 | elif arch == 'c1': 131 | net_decoder = C1( 132 | num_class=num_class, 133 | fc_dim=fc_dim, 134 | use_softmax=use_softmax) 135 | elif arch == 'ppm': 136 | net_decoder = PPM( 137 | num_class=num_class, 138 | fc_dim=fc_dim, 139 | use_softmax=use_softmax) 140 | elif arch == 'ppm_deepsup': 141 | net_decoder = PPMDeepsup( 142 | num_class=num_class, 143 | fc_dim=fc_dim, 144 | use_softmax=use_softmax) 145 | elif arch == 'upernet_lite': 146 | net_decoder = UPerNet( 147 | num_class=num_class, 148 | fc_dim=fc_dim, 149 | use_softmax=use_softmax, 150 | fpn_dim=256) 151 | elif arch == 'upernet': 152 | net_decoder = UPerNet( 153 | num_class=num_class, 154 | fc_dim=fc_dim, 155 | use_softmax=use_softmax, 156 | fpn_dim=512) 157 | else: 158 | raise Exception('Architecture undefined!') 159 | 160 | net_decoder.apply(self.weights_init) 161 | if len(weights) > 0: 162 | print('Loading weights for net_decoder') 163 | net_decoder.load_state_dict( 164 | torch.load(weights, 
map_location=lambda storage, loc: storage), strict=False) 165 | return net_decoder 166 | 167 | 168 | class Resnet(nn.Module): 169 | def __init__(self, orig_resnet): 170 | super(Resnet, self).__init__() 171 | 172 | # take pretrained resnet, except AvgPool and FC 173 | self.conv1 = orig_resnet.conv1 174 | self.bn1 = orig_resnet.bn1 175 | self.relu1 = orig_resnet.relu1 176 | self.conv2 = orig_resnet.conv2 177 | self.bn2 = orig_resnet.bn2 178 | self.relu2 = orig_resnet.relu2 179 | self.conv3 = orig_resnet.conv3 180 | self.bn3 = orig_resnet.bn3 181 | self.relu3 = orig_resnet.relu3 182 | self.maxpool = orig_resnet.maxpool 183 | self.layer1 = orig_resnet.layer1 184 | self.layer2 = orig_resnet.layer2 185 | self.layer3 = orig_resnet.layer3 186 | self.layer4 = orig_resnet.layer4 187 | 188 | def forward(self, x, return_feature_maps=False): 189 | conv_out = [] 190 | 191 | x = self.relu1(self.bn1(self.conv1(x))) 192 | x = self.relu2(self.bn2(self.conv2(x))) 193 | x = self.relu3(self.bn3(self.conv3(x))) 194 | x = self.maxpool(x) 195 | 196 | x = self.layer1(x); conv_out.append(x); 197 | x = self.layer2(x); conv_out.append(x); 198 | x = self.layer3(x); conv_out.append(x); 199 | x = self.layer4(x); conv_out.append(x); 200 | 201 | if return_feature_maps: 202 | return conv_out 203 | return [x] 204 | 205 | 206 | class ResnetDilated(nn.Module): 207 | def __init__(self, orig_resnet, dilate_scale=8): 208 | super(ResnetDilated, self).__init__() 209 | from functools import partial 210 | 211 | if dilate_scale == 8: 212 | orig_resnet.layer3.apply( 213 | partial(self._nostride_dilate, dilate=2)) 214 | orig_resnet.layer4.apply( 215 | partial(self._nostride_dilate, dilate=4)) 216 | elif dilate_scale == 16: 217 | orig_resnet.layer4.apply( 218 | partial(self._nostride_dilate, dilate=2)) 219 | 220 | # take pretrained resnet, except AvgPool and FC 221 | self.conv1 = orig_resnet.conv1 222 | self.bn1 = orig_resnet.bn1 223 | self.relu1 = orig_resnet.relu1 224 | self.conv2 = orig_resnet.conv2 225 | self.bn2 = orig_resnet.bn2 226 | self.relu2 = orig_resnet.relu2 227 | self.conv3 = orig_resnet.conv3 228 | self.bn3 = orig_resnet.bn3 229 | self.relu3 = orig_resnet.relu3 230 | self.maxpool = orig_resnet.maxpool 231 | self.layer1 = orig_resnet.layer1 232 | self.layer2 = orig_resnet.layer2 233 | self.layer3 = orig_resnet.layer3 234 | self.layer4 = orig_resnet.layer4 235 | 236 | def _nostride_dilate(self, m, dilate): 237 | classname = m.__class__.__name__ 238 | if classname.find('Conv') != -1: 239 | # strided convolutions: drop the stride and dilate instead 240 | if m.stride == (2, 2): 241 | m.stride = (1, 1) 242 | if m.kernel_size == (3, 3): 243 | m.dilation = (dilate//2, dilate//2) 244 | m.padding = (dilate//2, dilate//2) 245 | # other convolutions 246 | else: 247 | if m.kernel_size == (3, 3): 248 | m.dilation = (dilate, dilate) 249 | m.padding = (dilate, dilate) 250 | 251 | def forward(self, x, return_feature_maps=False): 252 | conv_out = [] 253 | 254 | x = self.relu1(self.bn1(self.conv1(x))) 255 | x = self.relu2(self.bn2(self.conv2(x))) 256 | x = self.relu3(self.bn3(self.conv3(x))) 257 | x = self.maxpool(x) 258 | 259 | x = self.layer1(x); conv_out.append(x); 260 | x = self.layer2(x); conv_out.append(x); 261 | x = self.layer3(x); conv_out.append(x); 262 | x = self.layer4(x); conv_out.append(x); 263 | 264 | if return_feature_maps: 265 | return conv_out 266 | return [x] 267 | 268 | 269 | class MobileNetV2Dilated(nn.Module): 270 | def __init__(self, orig_net, dilate_scale=8): 271 | super(MobileNetV2Dilated, self).__init__() 272 | from functools import partial 273
| 274 | # take pretrained mobilenet features 275 | self.features = orig_net.features[:-1] 276 | 277 | self.total_idx = len(self.features) 278 | self.down_idx = [2, 4, 7, 14] 279 | 280 | if dilate_scale == 8: 281 | for i in range(self.down_idx[-2], self.down_idx[-1]): 282 | self.features[i].apply( 283 | partial(self._nostride_dilate, dilate=2) 284 | ) 285 | for i in range(self.down_idx[-1], self.total_idx): 286 | self.features[i].apply( 287 | partial(self._nostride_dilate, dilate=4) 288 | ) 289 | elif dilate_scale == 16: 290 | for i in range(self.down_idx[-1], self.total_idx): 291 | self.features[i].apply( 292 | partial(self._nostride_dilate, dilate=2) 293 | ) 294 | 295 | def _nostride_dilate(self, m, dilate): 296 | classname = m.__class__.__name__ 297 | if classname.find('Conv') != -1: 298 | # strided convolutions: drop the stride and dilate instead 299 | if m.stride == (2, 2): 300 | m.stride = (1, 1) 301 | if m.kernel_size == (3, 3): 302 | m.dilation = (dilate//2, dilate//2) 303 | m.padding = (dilate//2, dilate//2) 304 | # other convolutions 305 | else: 306 | if m.kernel_size == (3, 3): 307 | m.dilation = (dilate, dilate) 308 | m.padding = (dilate, dilate) 309 | 310 | def forward(self, x, return_feature_maps=False): 311 | if return_feature_maps: 312 | conv_out = [] 313 | for i in range(self.total_idx): 314 | x = self.features[i](x) 315 | if i in self.down_idx: 316 | conv_out.append(x) 317 | conv_out.append(x) 318 | return conv_out 319 | 320 | else: 321 | return [self.features(x)] 322 | 323 | 324 | # last conv, deep supervision 325 | class C1DeepSup(nn.Module): 326 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): 327 | super(C1DeepSup, self).__init__() 328 | self.use_softmax = use_softmax 329 | 330 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) 331 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 332 | 333 | # last conv 334 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 335 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 336 | 337 | def forward(self, conv_out, segSize=None): 338 | conv5 = conv_out[-1] 339 | 340 | x = self.cbr(conv5) 341 | x = self.conv_last(x) 342 | 343 | if self.use_softmax: # is True during inference 344 | x = nn.functional.interpolate( 345 | x, size=segSize, mode='bilinear', align_corners=False) 346 | x = nn.functional.softmax(x, dim=1) 347 | return x 348 | 349 | # deep sup 350 | conv4 = conv_out[-2] 351 | _ = self.cbr_deepsup(conv4) 352 | _ = self.conv_last_deepsup(_) 353 | 354 | x = nn.functional.log_softmax(x, dim=1) 355 | _ = nn.functional.log_softmax(_, dim=1) 356 | 357 | return (x, _) 358 | 359 | 360 | # last conv 361 | class C1(nn.Module): 362 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): 363 | super(C1, self).__init__() 364 | self.use_softmax = use_softmax 365 | 366 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) 367 | 368 | # last conv 369 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 370 | 371 | def forward(self, conv_out, segSize=None): 372 | conv5 = conv_out[-1] 373 | x = self.cbr(conv5) 374 | x = self.conv_last(x) 375 | 376 | if self.use_softmax: # is True during inference 377 | x = nn.functional.interpolate( 378 | x, size=segSize, mode='bilinear', align_corners=False) 379 | x = nn.functional.softmax(x, dim=1) 380 | else: 381 | x = nn.functional.log_softmax(x, dim=1) 382 | 383 | return x 384 | 385 | 386 | # pyramid pooling 387 | class PPM(nn.Module): 388 | def __init__(self, num_class=150, fc_dim=4096, 389 | use_softmax=False, pool_scales=(1, 2, 3, 6)): 390 |
super(PPM, self).__init__() 391 | self.use_softmax = use_softmax 392 | 393 | self.ppm = [] 394 | for scale in pool_scales: 395 | self.ppm.append(nn.Sequential( 396 | nn.AdaptiveAvgPool2d(scale), 397 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 398 | SynchronizedBatchNorm2d(512), 399 | nn.ReLU(inplace=True) 400 | )) 401 | self.ppm = nn.ModuleList(self.ppm) 402 | 403 | self.conv_last = nn.Sequential( 404 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 405 | kernel_size=3, padding=1, bias=False), 406 | SynchronizedBatchNorm2d(512), 407 | nn.ReLU(inplace=True), 408 | nn.Dropout2d(0.1), 409 | nn.Conv2d(512, num_class, kernel_size=1) 410 | ) 411 | 412 | def forward(self, conv_out, segSize=None): 413 | conv5 = conv_out[-1] 414 | 415 | input_size = conv5.size() 416 | ppm_out = [conv5] 417 | for pool_scale in self.ppm: 418 | ppm_out.append(nn.functional.interpolate( 419 | pool_scale(conv5), 420 | (input_size[2], input_size[3]), 421 | mode='bilinear', align_corners=False)) 422 | ppm_out = torch.cat(ppm_out, 1) 423 | 424 | x = self.conv_last(ppm_out) 425 | 426 | if self.use_softmax: # is True during inference 427 | x = nn.functional.interpolate( 428 | x, size=segSize, mode='bilinear', align_corners=False) 429 | x = nn.functional.softmax(x, dim=1) 430 | else: 431 | x = nn.functional.log_softmax(x, dim=1) 432 | return x 433 | 434 | 435 | # pyramid pooling, deep supervision 436 | class PPMDeepsup(nn.Module): 437 | def __init__(self, num_class=150, fc_dim=4096, 438 | use_softmax=False, pool_scales=(1, 2, 3, 6)): 439 | super(PPMDeepsup, self).__init__() 440 | self.use_softmax = use_softmax 441 | 442 | self.ppm = [] 443 | for scale in pool_scales: 444 | self.ppm.append(nn.Sequential( 445 | nn.AdaptiveAvgPool2d(scale), 446 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 447 | SynchronizedBatchNorm2d(512), 448 | nn.ReLU(inplace=True) 449 | )) 450 | self.ppm = nn.ModuleList(self.ppm) 451 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 452 | 453 | self.conv_last = nn.Sequential( 454 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 455 | kernel_size=3, padding=1, bias=False), 456 | SynchronizedBatchNorm2d(512), 457 | nn.ReLU(inplace=True), 458 | nn.Dropout2d(0.1), 459 | nn.Conv2d(512, num_class, kernel_size=1) 460 | ) 461 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 462 | self.dropout_deepsup = nn.Dropout2d(0.1) 463 | 464 | def forward(self, conv_out, segSize=None): 465 | conv5 = conv_out[-1] 466 | 467 | input_size = conv5.size() 468 | ppm_out = [conv5] 469 | for pool_scale in self.ppm: 470 | ppm_out.append(nn.functional.interpolate( 471 | pool_scale(conv5), 472 | (input_size[2], input_size[3]), 473 | mode='bilinear', align_corners=False)) 474 | ppm_out = torch.cat(ppm_out, 1) 475 | 476 | x = self.conv_last(ppm_out) 477 | 478 | if self.use_softmax: # is True during inference 479 | x = nn.functional.interpolate( 480 | x, size=segSize, mode='bilinear', align_corners=False) 481 | x = nn.functional.softmax(x, dim=1) 482 | return x 483 | 484 | # deep sup 485 | conv4 = conv_out[-2] 486 | _ = self.cbr_deepsup(conv4) 487 | _ = self.dropout_deepsup(_) 488 | _ = self.conv_last_deepsup(_) 489 | 490 | x = nn.functional.log_softmax(x, dim=1) 491 | _ = nn.functional.log_softmax(_, dim=1) 492 | 493 | return (x, _) 494 | 495 | 496 | # upernet 497 | class UPerNet(nn.Module): 498 | def __init__(self, num_class=150, fc_dim=4096, 499 | use_softmax=False, pool_scales=(1, 2, 3, 6), 500 | fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256): 501 | super(UPerNet, self).__init__() 502 | 
self.use_softmax = use_softmax 503 | 504 | # PPM Module 505 | self.ppm_pooling = [] 506 | self.ppm_conv = [] 507 | 508 | for scale in pool_scales: 509 | self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) 510 | self.ppm_conv.append(nn.Sequential( 511 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 512 | SynchronizedBatchNorm2d(512), 513 | nn.ReLU(inplace=True) 514 | )) 515 | self.ppm_pooling = nn.ModuleList(self.ppm_pooling) 516 | self.ppm_conv = nn.ModuleList(self.ppm_conv) 517 | self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) 518 | 519 | # FPN Module 520 | self.fpn_in = [] 521 | for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer 522 | self.fpn_in.append(nn.Sequential( 523 | nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), 524 | SynchronizedBatchNorm2d(fpn_dim), 525 | nn.ReLU(inplace=True) 526 | )) 527 | self.fpn_in = nn.ModuleList(self.fpn_in) 528 | 529 | self.fpn_out = [] 530 | for i in range(len(fpn_inplanes) - 1): # skip the top layer 531 | self.fpn_out.append(nn.Sequential( 532 | conv3x3_bn_relu(fpn_dim, fpn_dim, 1), 533 | )) 534 | self.fpn_out = nn.ModuleList(self.fpn_out) 535 | 536 | self.conv_last = nn.Sequential( 537 | conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), 538 | nn.Conv2d(fpn_dim, num_class, kernel_size=1) 539 | ) 540 | 541 | def forward(self, conv_out, segSize=None): 542 | conv5 = conv_out[-1] 543 | 544 | input_size = conv5.size() 545 | ppm_out = [conv5] 546 | for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): 547 | ppm_out.append(pool_conv(nn.functional.interpolate( 548 | pool_scale(conv5), 549 | (input_size[2], input_size[3]), 550 | mode='bilinear', align_corners=False))) 551 | ppm_out = torch.cat(ppm_out, 1) 552 | f = self.ppm_last_conv(ppm_out) 553 | 554 | fpn_feature_list = [f] 555 | for i in reversed(range(len(conv_out) - 1)): 556 | conv_x = conv_out[i] 557 | conv_x = self.fpn_in[i](conv_x) # lateral branch 558 | 559 | f = nn.functional.interpolate( 560 | f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch 561 | f = conv_x + f 562 | 563 | fpn_feature_list.append(self.fpn_out[i](f)) 564 | 565 | fpn_feature_list.reverse() # [P2 - P5] 566 | output_size = fpn_feature_list[0].size()[2:] 567 | fusion_list = [fpn_feature_list[0]] 568 | for i in range(1, len(fpn_feature_list)): 569 | fusion_list.append(nn.functional.interpolate( 570 | fpn_feature_list[i], 571 | output_size, 572 | mode='bilinear', align_corners=False)) 573 | fusion_out = torch.cat(fusion_list, 1) 574 | x = self.conv_last(fusion_out) 575 | 576 | if self.use_softmax: # is True during inference 577 | x = nn.functional.interpolate( 578 | x, size=segSize, mode='bilinear', align_corners=False) 579 | x = nn.functional.softmax(x, dim=1) 580 | return x 581 | 582 | x = nn.functional.log_softmax(x, dim=1) 583 | 584 | return x 585 | -------------------------------------------------------------------------------- /Segmentation/models/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
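# Usage sketch (illustrative, not part of the original file):
#   encoder = resnet50(pretrained=True)  # builds the SynchronizedBatchNorm variant below;
#   load_url() caches the ImageNet weights under ./pretrained on first use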
15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | "3x3 convolution with padding" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = SynchronizedBatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = SynchronizedBatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = SynchronizedBatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = SynchronizedBatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 128 105 | super(ResNet, self).__init__() 106 | self.conv1 = conv3x3(3, 64, stride=2) 107 | self.bn1 = SynchronizedBatchNorm2d(64) 108 | self.relu1 = nn.ReLU(inplace=True) 109 | self.conv2 = conv3x3(64, 64) 110 | self.bn2 = SynchronizedBatchNorm2d(64) 111 | self.relu2 = nn.ReLU(inplace=True) 112 | self.conv3 = conv3x3(64, 128) 113 | self.bn3 = SynchronizedBatchNorm2d(128) 114 | self.relu3 = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 121 | self.avgpool = nn.AvgPool2d(7, stride=1) 122 | self.fc = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | n = m.kernel_size[0] * m.kernel_size[1] 
* m.out_channels 127 | m.weight.data.normal_(0, math.sqrt(2. / n)) 128 | elif isinstance(m, SynchronizedBatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | 132 | def _make_layer(self, block, planes, blocks, stride=1): 133 | downsample = None 134 | if stride != 1 or self.inplanes != planes * block.expansion: 135 | downsample = nn.Sequential( 136 | nn.Conv2d(self.inplanes, planes * block.expansion, 137 | kernel_size=1, stride=stride, bias=False), 138 | SynchronizedBatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for i in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.relu1(self.bn1(self.conv1(x))) 151 | x = self.relu2(self.bn2(self.conv2(x))) 152 | x = self.relu3(self.bn3(self.conv3(x))) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | def resnet18(pretrained=False, **kwargs): 167 | """Constructs a ResNet-18 model. 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(load_url(model_urls['resnet18'])) 175 | return model 176 | 177 | ''' 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(load_url(model_urls['resnet34'])) 187 | return model 188 | ''' 189 | 190 | def resnet50(pretrained=False, **kwargs): 191 | """Constructs a ResNet-50 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 199 | return model 200 | 201 | 202 | def resnet101(pretrained=False, **kwargs): 203 | """Constructs a ResNet-101 model. 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 211 | return model 212 | 213 | # def resnet152(pretrained=False, **kwargs): 214 | # """Constructs a ResNet-152 model. 
215 | # 216 | # Args: 217 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | # """ 219 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 220 | # if pretrained: 221 | # model.load_state_dict(load_url(model_urls['resnet152'])) 222 | # return model 223 | 224 | def load_url(url, model_dir='./pretrained', map_location=None): 225 | if not os.path.exists(model_dir): 226 | os.makedirs(model_dir) 227 | filename = url.split('/')[-1] 228 | cached_file = os.path.join(model_dir, filename) 229 | if not os.path.exists(cached_file): 230 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 231 | urlretrieve(url, cached_file) 232 | return torch.load(cached_file, map_location=map_location) 233 | -------------------------------------------------------------------------------- /Segmentation/models/resnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 15 | 16 | 17 | model_urls = { 18 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 19 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | "3x3 convolution with padding" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class GroupBottleneck(nn.Module): 30 | expansion = 2 31 | 32 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 33 | super(GroupBottleneck, self).__init__() 34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 35 | self.bn1 = SynchronizedBatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 37 | padding=1, groups=groups, bias=False) 38 | self.bn2 = SynchronizedBatchNorm2d(planes) 39 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 40 | self.bn3 = SynchronizedBatchNorm2d(planes * 2) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class ResNeXt(nn.Module): 69 | 70 | def __init__(self, block, layers, groups=32, num_classes=1000): 71 | self.inplanes = 128 72 | super(ResNeXt, self).__init__() 73 | self.conv1 = conv3x3(3, 64, stride=2) 74 | self.bn1 = SynchronizedBatchNorm2d(64) 75 | self.relu1 = nn.ReLU(inplace=True) 76 | self.conv2 = conv3x3(64, 64) 77 | self.bn2 = SynchronizedBatchNorm2d(64) 78 | self.relu2 = nn.ReLU(inplace=True) 79 | self.conv3 = conv3x3(64, 128) 80 | self.bn3 = SynchronizedBatchNorm2d(128) 81 | self.relu3 = nn.ReLU(inplace=True) 82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 83 | 84 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 85 | self.layer2 = 
self._make_layer(block, 256, layers[1], stride=2, groups=groups) 86 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 87 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 88 | self.avgpool = nn.AvgPool2d(7, stride=1) 89 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 94 | m.weight.data.normal_(0, math.sqrt(2. / n)) 95 | elif isinstance(m, SynchronizedBatchNorm2d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | downsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes * block.expansion, 104 | kernel_size=1, stride=stride, bias=False), 105 | SynchronizedBatchNorm2d(planes * block.expansion), 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes, groups=groups)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.relu1(self.bn1(self.conv1(x))) 118 | x = self.relu2(self.bn2(self.conv2(x))) 119 | x = self.relu3(self.bn3(self.conv3(x))) 120 | x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | x = self.layer4(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc(x) 130 | 131 | return x 132 | 133 | 134 | ''' 135 | def resnext50(pretrained=False, **kwargs): 136 | """Constructs a ResNeXt-50 model. 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on Places 140 | """ 141 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 142 | if pretrained: 143 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 144 | return model 145 | ''' 146 | 147 | 148 | def resnext101(pretrained=False, **kwargs): 149 | """Constructs a ResNeXt-101 model. 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on Places 153 | """ 154 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 155 | if pretrained: 156 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 157 | return model 158 | 159 | 160 | # def resnext152(pretrained=False, **kwargs): 161 | # """Constructs a ResNeXt-152 model.
162 | # 163 | # Args: 164 | # pretrained (bool): If True, returns a model pre-trained on Places 165 | # """ 166 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 167 | # if pretrained: 168 | # model.load_state_dict(load_url(model_urls['resnext152'])) 169 | # return model 170 | 171 | 172 | def load_url(url, model_dir='./pretrained', map_location=None): 173 | if not os.path.exists(model_dir): 174 | os.makedirs(model_dir) 175 | filename = url.split('/')[-1] 176 | cached_file = os.path.join(model_dir, filename) 177 | if not os.path.exists(cached_file): 178 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 179 | urlretrieve(url, cached_file) 180 | return torch.load(cached_file, map_location=map_location) 181 | -------------------------------------------------------------------------------- /Segmentation/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | clear 4 | 5 | # Path to images and results 6 | DATASET=../Dataset/ 7 | RESULT_PATH=./SegmentationResults/ 8 | 9 | # Segmentation model 10 | MODEL_PATH=models 11 | MASKTYPE=smooth 12 | 13 | # Inference 14 | python -u SemanticMasks.py \ 15 | --model_path $MODEL_PATH \ 16 | --dataset $DATASET \ 17 | --arch_encoder resnet50dilated \ 18 | --arch_decoder ppm_deepsup \ 19 | --fc_dim 2048 \ 20 | --result $RESULT_PATH \ 21 | --mask_type $MASKTYPE \ 22 | --gpu 0 23 | -------------------------------------------------------------------------------- /Segmentation/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import functools 4 | import fnmatch 5 | import numpy as np 6 | 7 | 8 | def find_recursive(root_dir, ext='.jpg'): 9 | files = [] 10 | for root, dirnames, filenames in os.walk(root_dir): 11 | for filename in fnmatch.filter(filenames, '*' + ext): 12 | files.append(os.path.join(root, filename)) 13 | return files 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self): 19 | self.initialized = False 20 | self.val = None 21 | self.avg = None 22 | self.sum = None 23 | self.count = None 24 | 25 | def initialize(self, val, weight): 26 | self.val = val 27 | self.avg = val 28 | self.sum = val * weight 29 | self.count = weight 30 | self.initialized = True 31 | 32 | def update(self, val, weight=1): 33 | if not self.initialized: 34 | self.initialize(val, weight) 35 | else: 36 | self.add(val, weight) 37 | 38 | def add(self, val, weight): 39 | self.val = val 40 | self.sum += val * weight 41 | self.count += weight 42 | self.avg = self.sum / self.count 43 | 44 | def value(self): 45 | return self.val 46 | 47 | def average(self): 48 | return self.avg 49 | 50 | 51 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 52 | ar = np.asanyarray(ar).flatten() 53 | 54 | optional_indices = return_index or return_inverse 55 | optional_returns = optional_indices or return_counts 56 | 57 | if ar.size == 0: 58 | if not optional_returns: 59 | ret = ar 60 | else: 61 | ret = (ar,) 62 | if return_index: 63 | ret += (np.empty(0, np.bool),) 64 | if return_inverse: 65 | ret += (np.empty(0, np.bool),) 66 | if return_counts: 67 | ret += (np.empty(0, np.intp),) 68 | return ret 69 | if optional_indices: 70 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 71 | aux = ar[perm] 72 | else: 73 | ar.sort() 74 | aux = ar 75 | flag = np.concatenate(([True], aux[1:] != aux[:-1])) 76 | 77 | if not optional_returns: 
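# flag is True at the first occurrence of each value in the sorted array aux,
# so aux[flag] keeps exactly one copy of every distinct element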
78 | ret = aux[flag] 79 | else: 80 | ret = (aux[flag],) 81 | if return_index: 82 | ret += (perm[flag],) 83 | if return_inverse: 84 | iflag = np.cumsum(flag) - 1 85 | inv_idx = np.empty(ar.shape, dtype=np.intp) 86 | inv_idx[perm] = iflag 87 | ret += (inv_idx,) 88 | if return_counts: 89 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 90 | ret += (np.diff(idx),) 91 | return ret 92 | 93 | 94 | def colorEncode(labelmap, colors, mode='BGR'): 95 | labelmap = labelmap.astype('int') 96 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 97 | dtype=np.uint8) 98 | for label in unique(labelmap): 99 | if label < 0: 100 | continue 101 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 102 | np.tile(colors[label], 103 | (labelmap.shape[0], labelmap.shape[1], 1)) 104 | 105 | if mode == 'BGR': 106 | return labelmap_rgb[:, :, ::-1] 107 | else: 108 | return labelmap_rgb 109 | 110 | 111 | def accuracy(preds, label): 112 | valid = (label >= 0) 113 | acc_sum = (valid * (preds == label)).sum() 114 | valid_sum = valid.sum() 115 | acc = float(acc_sum) / (valid_sum + 1e-10) 116 | return acc, valid_sum 117 | 118 | 119 | def intersectionAndUnion(imPred, imLab, numClass): 120 | imPred = np.asarray(imPred).copy() 121 | imLab = np.asarray(imLab).copy() 122 | 123 | imPred += 1 124 | imLab += 1 125 | # Remove classes from unlabeled pixels in gt image. 126 | # We should not penalize detections in unlabeled portions of the image. 127 | imPred = imPred * (imLab > 0) 128 | 129 | # Compute area intersection: 130 | intersection = imPred * (imPred == imLab) 131 | (area_intersection, _) = np.histogram( 132 | intersection, bins=numClass, range=(1, numClass)) 133 | 134 | # Compute area union: 135 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 136 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 137 | area_union = area_pred + area_lab - area_intersection 138 | 139 | return (area_intersection, area_union) 140 | 141 | 142 | class NotSupportedCliException(Exception): 143 | pass 144 | 145 | 146 | def process_range(xpu, inp): 147 | start, end = map(int, inp) 148 | if start > end: 149 | end, start = start, end 150 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 151 | 152 | 153 | REGEX = [ 154 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 155 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 156 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 157 | functools.partial(process_range, 'gpu')), 158 | (re.compile(r'^(\d+)-(\d+)$'), 159 | functools.partial(process_range, 'gpu')), 160 | ] 161 | 162 | 163 | def parse_devices(input_devices): 164 | 165 | """Parse user's devices input str to standard format. 166 | e.g. [gpu0, gpu1, ...] 
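Illustrative examples (hypothetical inputs):
    'gpu0'      -> ['gpu0']
    '0,1'       -> ['gpu0', 'gpu1']
    'gpu0-gpu2' -> ['gpu0', 'gpu1', 'gpu2']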
167 | 168 | """ 169 | ret = [] 170 | for d in input_devices.split(','): 171 | for regex, func in REGEX: 172 | m = regex.match(d.lower().strip()) 173 | if m: 174 | tmp = func(m.groups()) 175 | # prevent duplicate 176 | for x in tmp: 177 | if x not in ret: 178 | ret.append(x) 179 | break 180 | else: 181 | raise NotSupportedCliException( 182 | 'Can not recognize device: "{}"'.format(d)) 183 | return ret 184 | -------------------------------------------------------------------------------- /TutorialDemoColorFool/Image/ILSVRC2012_val_00003533.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Image/ILSVRC2012_val_00003533.JPEG -------------------------------------------------------------------------------- /TutorialDemoColorFool/Masks/Person/ILSVRC2012_val_00003533.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Person/ILSVRC2012_val_00003533.JPEG -------------------------------------------------------------------------------- /TutorialDemoColorFool/Masks/Sky/ILSVRC2012_val_00003533.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Sky/ILSVRC2012_val_00003533.JPEG -------------------------------------------------------------------------------- /TutorialDemoColorFool/Masks/Vegetation/ILSVRC2012_val_00003533.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Vegetation/ILSVRC2012_val_00003533.JPEG -------------------------------------------------------------------------------- /TutorialDemoColorFool/Masks/Water/ILSVRC2012_val_00003533.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Water/ILSVRC2012_val_00003533.JPEG -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchvision 4 | opencv-python 5 | tqdm 6 | future 7 | scikit-image 8 | tensorboardX 9 | --------------------------------------------------------------------------------