├── .gitignore
├── Adversarial
│   ├── ColorFool.py
│   ├── Samples
│   │   ├── ILSVRC2012_val_00003533_alexnet.png
│   │   ├── ILSVRC2012_val_00003533_resnet18.png
│   │   └── ILSVRC2012_val_00003533_resnet50.png
│   ├── misc_functions.py
│   └── script.sh
├── ColorFool.gif
├── Dataset
│   └── ILSVRC2012_val_00003533.JPEG
├── License.txt
├── README.md
├── Sample_results
│   ├── ILSVRC2012_val_00003533_alexnet.png
│   ├── ILSVRC2012_val_00003533_resnet18.png
│   └── ILSVRC2012_val_00003533_resnet50.png
├── Segmentation
│   ├── SemanticMasks.py
│   ├── data
│   │   ├── ADE20K_object150_train.txt
│   │   ├── ADE20K_object150_val.txt
│   │   ├── color150.mat
│   │   ├── object150_info.csv
│   │   ├── train.odgt
│   │   └── validation.odgt
│   ├── dataset.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── batchnorm.py
│   │   │   │   ├── comm.py
│   │   │   │   ├── replicate.py
│   │   │   │   ├── tests
│   │   │   │   │   ├── test_numeric_batchnorm.py
│   │   │   │   │   └── test_sync_batchnorm.py
│   │   │   │   └── unittest.py
│   │   │   └── parallel
│   │   │       ├── __init__.py
│   │   │       └── data_parallel.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── data
│   │       │   ├── __init__.py
│   │       │   ├── dataloader.py
│   │       │   ├── dataset.py
│   │       │   ├── distributed.py
│   │       │   └── sampler.py
│   │       └── th.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── mobilenet.py
│   │   ├── models.py
│   │   ├── resnet.py
│   │   └── resnext.py
│   ├── script.sh
│   └── utils.py
├── TutorialDemoColorFool
│   ├── ColorFool.ipynb
│   ├── Image
│   │   └── ILSVRC2012_val_00003533.JPEG
│   └── Masks
│       ├── Person
│       │   └── ILSVRC2012_val_00003533.JPEG
│       ├── Sky
│       │   └── ILSVRC2012_val_00003533.JPEG
│       ├── Vegetation
│       │   └── ILSVRC2012_val_00003533.JPEG
│       └── Water
│           └── ILSVRC2012_val_00003533.JPEG
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Latex files
2 |
3 | *.aux
4 | *.glo
5 | *.idx
6 | *.log
7 | *.toc
8 | *.ist
9 | *.acn
10 | *.acr
11 | *.alg
12 | *.bbl
13 | *.blg
14 | *.tui
15 | *.top
16 | *.tmp
17 | *.mp
18 | *.dvi
19 | *.glg
20 | *.gls
21 | *.ilg
22 | *.ind
23 | *.lof
24 | *.lot
25 | *.maf
26 | *.mtc
27 | *.mtc1
28 | *.out
29 | *.gz
30 | *.pyc
31 |
32 | # Mac IDE files
33 | *.swp
34 | *~
35 | *(Autosaved).rtfd/
36 | Backup[ ]of[ ]*.pages/
37 | Backup[ ]of[ ]*.key/
38 | Backup[ ]of[ ]*.numbers/
39 |
40 | # Mac finder files and hidden folders
41 | .DS_Store
42 | *.fdb_latexmk
43 |
44 | *.fls
45 |
46 | *.sublime-workspace
47 | paper.pdf
48 | changelog.txt
49 | *.sublime-project
50 |
--------------------------------------------------------------------------------
/Adversarial/ColorFool.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import scipy
5 | from skimage import io, color
6 |
7 | import torch
8 | from torch import nn
9 | from torch.autograd import Variable
10 | from torch.nn import functional as F
11 |
12 | import glob, os
13 | from os.path import join,isfile
14 |
15 | from os import listdir
16 |
17 | from PIL import Image
18 | from tqdm import tqdm
19 | from torchvision import models
20 | from numpy import pi
21 | from numpy import sin
22 | from numpy import zeros
23 | from numpy import r_
24 |
25 | from scipy import signal
26 | from scipy import misc
27 | import torchvision.transforms as T
28 | from skimage.filters import rank
29 | from skimage.morphology import disk
30 |
31 | import argparse
32 | import pdb
33 | from copy import copy as copy
34 |
35 |
36 | from misc_functions import prepareImageMasks, initialise, createLogFiles, createDirectories
37 |
38 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
39 |
40 | selem = disk(20)
41 |
42 | # Normalization values for ImageNet
43 | trf = T.Compose([T.ToPILImage(),
44 | T.ToTensor(),
45 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
46 |
47 | class attack():
48 |
49 | def __init__(self, model, args):
50 |
51 | self.model = model
52 |
53 | # Create the folder to export adversarial images if not exists
54 | self.adv_path = createDirectories(args)
55 |
56 | def generate(self, original_image, sky_mask, water_mask, green_mask, person_mask, img_name, org_class, args):
57 |
58 | misclassified=0
59 | maxTrials = 1000
60 |
61 | # Convert the clean image from RGB to the Lab color space
62 | original_image_lab=color.rgb2lab(original_image)
63 |
64 | # Start iteration
65 | for trial in range(maxTrials):
66 |
67 | X_lab = original_image_lab.copy()
68 |
69 | margin = 127
70 | mult = float(trial+1) / float(maxTrials)
71 |
72 | # Adversarial color perturbation for Water regions
73 | water_mask_binary = copy(water_mask)
74 | water_mask_binary[water_mask_binary>0] = 1
75 | water = X_lab[water_mask_binary == 1]
76 | if water.size != 0:
77 | a_min = water[:,1].min()
78 | a_max = np.clip(water[:,1].max(), a_min=None, a_max = 0)
79 | b_min = water[:,2].min()
80 | b_max = np.clip(water[:,2].max(), a_min=None, a_max = 0)
81 | a_blue = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-a_min), mult*(-a_max), size=(1))) * water_mask
82 | b_blue = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-b_min), mult*(-b_max), size=(1))) * water_mask
83 | else:
84 | a_blue = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
85 | b_blue = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
86 |
87 | # Adversarial color perturbation for Vegetation regions
88 | green_mask_binary = copy(green_mask)
89 | green_mask_binary[green_mask_binary>0] = 1
90 | green = X_lab[green_mask_binary == 1]
91 | if green.size != 0:
92 | a_min = green[:,1].min()
93 | a_max = np.clip(green[:,1].max(), a_min=None, a_max = 0)
94 | b_min = np.clip(green[:,2].min(), a_min=0, a_max = None)
95 | b_max = green[:,2].max()
96 | a_green = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-a_min), mult*(-a_max), size=(1))) * green_mask
97 | b_green = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-b_min), mult*(margin-b_max), size=(1))) * green_mask
98 | else:
99 | a_green = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
100 | b_green = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
101 |
102 | # Adversarial color perturbation for Sky regions
103 | sky_mask_binary = copy(sky_mask)
104 | sky_mask_binary[sky_mask_binary>0] = 1
105 | sky = X_lab[sky_mask_binary == 1]
106 | if sky.size != 0:
107 | a_min = sky[:,1].min()
108 | a_max = np.clip(sky[:,1].max(), a_min=None, a_max = 0)
109 | b_min = sky[:,2].min()
110 | b_max = np.clip(sky[:,2].max(), a_min=None, a_max = 0)
111 | a_sky = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-a_min), mult*(-a_max), size=(1))) * sky_mask
112 | b_sky = np.full((X_lab.shape[0], X_lab.shape[1]), np.random.uniform(mult*(-margin-b_min), mult*(-b_max), size=(1))) * sky_mask
113 | else:
114 | a_sky = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
115 | b_sky = np.full((X_lab.shape[0], X_lab.shape[1]), 0.)
116 |
117 |
118 | mask = (person_mask + water_mask + green_mask + sky_mask)
119 | mask[mask>1] = 1
120 |
121 | # Smooth boundaries between sensitive regions
122 | kernel = np.ones((5, 5), np.uint8)
123 | mask = cv2.blur(mask,(10,10))
124 |
125 | # Adversarial color perturbation for non-sensitive regions
126 | random_mask = 1 - mask
127 | a_random = np.full((X_lab.shape[0],X_lab.shape[1]), np.random.uniform(mult*(-margin), mult*(margin), size=(1)))
128 | b_random = np.full((X_lab.shape[0],X_lab.shape[1]), np.random.uniform(mult*(-margin), mult*(margin), size=(1)))
129 | a_random_mask = a_random * random_mask
130 | b_random_mask = b_random * random_mask
131 |
132 |
133 | # Adversarially perturb the color (i.e. the a and b channels of the Lab color space) of the clean image
134 | noise_mask = np.zeros((X_lab.shape), dtype=float)
135 | noise_mask[:,:,1] = a_blue + a_green + a_sky + a_random_mask
136 | noise_mask[:,:,2] = b_blue + b_green + b_sky + b_random_mask
137 | X_lab_mask = np.zeros((X_lab.shape), dtype=float)
138 | X_lab_mask [:,:,0] = X_lab [:,:,0]
139 | X_lab_mask [:,:,1] = np.clip(X_lab [:,:,1] + noise_mask[:,:,1], -margin, margin)
140 | X_lab_mask [:,:,2] = np.clip(X_lab [:,:,2] + noise_mask[:,:,2], -margin, margin)
141 |
142 | # Convert from Lab back to RGB
143 | X_rgb_mask = np.uint8(color.lab2rgb(X_lab_mask)*255.)
144 |
145 | # Predict the label of the adversarial image
146 | logit = self.model(trf(cv2.resize(X_rgb_mask, (224, 224), interpolation=cv2.INTER_LINEAR)).to(device).unsqueeze_(0))
147 | h_x = F.softmax(logit, dim=1).data.squeeze()
148 | probs, idx = h_x.sort(0, True)
149 |
150 | current_class = idx[0]
151 | current_class_prob = probs[0]
152 | org_class_prob = h_x[org_class]
153 |
154 | # Check if the generated adversarial image misleads the model
155 | if (current_class != org_class):
156 | misclassified=1
157 | break
158 |
159 | # Transfer the adversarial image from RGB to BGR to save with opencv
160 | X_bgr = X_rgb_mask[:, :, (2, 1, 0)]
161 | cv2.imwrite('{}/{}.png'.format(self.adv_path, img_name.split('.')[0]), X_bgr)
162 | return misclassified, trial, current_class, current_class_prob
163 |
164 |
165 | if __name__ == '__main__':
166 |
167 | # Parse arguments
168 | parser = argparse.ArgumentParser()
169 | parser.add_argument('--model', type=str, required=True)
170 | parser.add_argument('--dataset', type=str, default='../Dataset/')
171 | args = parser.parse_args()
172 |
173 | # Initialization: load the model under attack, the dataset path and the list of all clean images inside it
174 | model, image_list = initialise(args)
175 |
176 | # Log files to save numerical results
177 | f1, f1_name = createLogFiles(args)
178 |
179 | # Number of successful adversarial images
180 | misleads=0
181 |
182 | # Generate adversarial images for all clean images in the image_list
183 | NumImg=len(image_list)
184 | for idx in tqdm(range(NumImg)):
185 |
186 | # Load the clean image and predict its label using the model
187 | original_image, sky_mask, water_mask, grass_mask, person_mask, img_name, org_class, org_class_prob = prepareImageMasks(args, image_list, idx, model)
188 |
189 | f1 = open(f1_name, 'a+')
190 |
191 | # Perform the ColorFool attack
192 | LAB = attack(model, args)
193 | mislead, numTrials, current_class, current_class_prob= LAB.generate(original_image, sky_mask, water_mask, grass_mask, person_mask, img_name, org_class, args)
194 | misleads += mislead
195 | text = '{}\t{}\t{}\t{:.5f}\t{}\t{:.5f}\n'.format(img_name, numTrials+1, org_class, org_class_prob, current_class, current_class_prob)
196 |
197 | f1.write(text)
198 | f1.close()
199 | print('Success rate {:.1f}%'.format(100 * float(misleads) / NumImg))
200 |
--------------------------------------------------------------------------------
/Adversarial/Samples/ILSVRC2012_val_00003533_alexnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Adversarial/Samples/ILSVRC2012_val_00003533_alexnet.png
--------------------------------------------------------------------------------
/Adversarial/Samples/ILSVRC2012_val_00003533_resnet18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Adversarial/Samples/ILSVRC2012_val_00003533_resnet18.png
--------------------------------------------------------------------------------
/Adversarial/Samples/ILSVRC2012_val_00003533_resnet50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Adversarial/Samples/ILSVRC2012_val_00003533_resnet50.png
--------------------------------------------------------------------------------
/Adversarial/misc_functions.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import cv2
3 | import numpy as np
4 | from skimage import io, color
5 | import csv
6 | import os
7 | from os import listdir
8 | from os.path import isfile,join
9 | import torch
10 | import torchvision
11 | from torch.autograd import Variable
12 | from torchvision import models
13 | import torchvision.transforms as T
14 | from torch.nn import functional as F
15 | from torch.autograd import Variable as V
16 | import torch.nn as nn
17 | import scipy.sparse
18 |
19 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
20 |
21 | def initialise(args):
22 |
23 | image_list = [f for f in listdir(args.dataset) if isfile(join(args.dataset,f))] # ImageNet images are .JPEG; adapt for other datasets
24 |
25 | # Load model
26 | if args.model == 'resnet18':
27 | model = models.resnet18(pretrained=True)
28 | elif args.model == 'resnet50':
29 | model = models.resnet50(pretrained=True)
30 | elif args.model == 'alexnet':
31 | model = models.alexnet(pretrained=True)
32 | model.eval()
33 | model.to(device)
34 | return model, image_list
35 |
36 |
37 | def createLogFiles(args):
38 | log_path = 'Results/Logs/'
39 |
40 | if not os.path.exists(log_path):
41 | os.makedirs(log_path)
42 | f1_name = log_path+'log_{}.txt'.format(args.model)
43 | f1 = open(f1_name,"w")
44 | return f1, f1_name
45 |
46 |
47 | def createDirectories(args):
48 | main_path = 'Results/ColorFoolImgs/'
49 | adv_path = main_path+ 'adv_{}'.format(args.model)
50 |
51 | if not os.path.exists(adv_path):
52 | os.makedirs(adv_path)
53 |
54 | return adv_path
55 |
56 | #for ImageNet the mean and std are:
57 | mean = np.asarray([ 0.485, 0.456, 0.406 ])
58 | std = np.asarray([ 0.229, 0.224, 0.225 ])
59 |
60 | trf = T.Compose([T.ToPILImage(),
61 | T.ToTensor(),
62 | T.Normalize(mean=mean, std=std)])
63 |
64 | def prepareImageMasks(args, image_list, index, model):
65 |
66 | # Paths to segmentation outputs done in the prior step
67 | sky_mask_path = '../Segmentation/SegmentationResults/sky/'
68 | water_mask_path = '../Segmentation/SegmentationResults/water/'
69 | grass_mask_path = '../Segmentation/SegmentationResults/grass/'
70 | person_mask_path = '../Segmentation/SegmentationResults/person/'
71 |
72 | # Read images
73 | img_name = image_list[index]
74 |
75 | # Load the clean image with its four corresponding masks that represent Sky, Person, Vegetation and Water
76 | original_image = cv2.imread(args.dataset+img_name, 1)
77 | person_mask = cv2.imread('{}.png'.format(person_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255.
78 | water_mask = cv2.imread('{}.png'.format(water_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255.
79 | grass_mask = cv2.imread('{}.png'.format(grass_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255.
80 | sky_mask = cv2.imread('{}.png'.format(sky_mask_path+img_name.split('.')[0]), cv2.IMREAD_GRAYSCALE) / 255.
81 |
82 |
83 | # Convert from BGR (OpenCV default) to RGB
84 | original_image = original_image[:, :, (2, 1, 0)]
85 |
86 | # Resize image to the input size of the model
87 | image = cv2.resize(original_image, (224, 224), interpolation=cv2.INTER_LINEAR)
88 | # forward pass
89 | logit = model(trf(image).to(device).unsqueeze_(0))
90 | h_x = F.softmax(logit, dim=1).data.squeeze()
91 | probs, idx = h_x.sort(0, True)
92 |
93 | probs = np.array(probs.cpu())
94 | idx = np.array(idx.cpu())
95 |
96 | org_class= idx[0]
97 | org_class_prob = probs[0]
98 |
99 | return original_image, sky_mask, water_mask, grass_mask, person_mask, img_name, org_class, org_class_prob
100 |
101 |
--------------------------------------------------------------------------------
/Adversarial/script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODELS=(alexnet resnet18 resnet50)
4 |
5 | clear
6 | for model in "${MODELS[@]}"
7 | do
8 |
9 | echo "ColorFool attacking ${model}"
10 | python -W ignore ColorFool.py --model=$model
11 |
12 | done
13 |
--------------------------------------------------------------------------------
/ColorFool.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/ColorFool.gif
--------------------------------------------------------------------------------
/Dataset/ILSVRC2012_val_00003533.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Dataset/ILSVRC2012_val_00003533.JPEG
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
1 | # License
2 | This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
3 |
4 | Creative Commons Legal Code
5 | Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
6 |
7 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
8 | Creative Commons Attribution-NonCommercial 4.0 International Public License
9 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
10 | Section 1 – Definitions.
11 | a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
12 | b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
13 | c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
14 | d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
15 | e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
16 | f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
17 | g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
18 | h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
19 | i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
20 | j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
21 | k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
22 | l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
23 | Section 2 – Scope.
24 | a. License grant.
25 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
26 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
27 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
28 | 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
29 | 3. Term. The term of this Public License is specified in Section 6(a).
30 | 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
31 | 5. Downstream recipients.
32 | A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
33 | B. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
34 | 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
35 | b. Other rights.
36 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
37 | 2. Patent and trademark rights are not licensed under this Public License.
38 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
39 | Section 3 – License Conditions.
40 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
41 | a. Attribution.
42 | 1. If You Share the Licensed Material (including in modified form), You must:
43 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
44 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
45 | ii. a copyright notice;
46 | iii. a notice that refers to this Public License;
47 | iv. a notice that refers to the disclaimer of warranties;
48 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
49 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
50 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
51 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
52 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
53 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
54 | Section 4 – Sui Generis Database Rights.
55 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
56 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
57 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
58 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
59 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
60 | Section 5 – Disclaimer of Warranties and Limitation of Liability.
61 | a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
62 | b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
63 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
64 | Section 6 – Term and Termination.
65 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
66 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
67 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
68 | 2. upon express reinstatement by the Licensor.
69 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
70 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
71 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
72 | Section 7 – Other Terms and Conditions.
73 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
74 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
75 | Section 8 – Interpretation.
76 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
77 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
78 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
79 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
80 |
81 |
82 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
83 |
84 | Creative Commons may be contacted at creativecommons.org.
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ColorFool
2 |
3 | This is the official repository of [ColorFool: Semantic Adversarial Colorization](https://arxiv.org/pdf/1911.10891.pdf), published at the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Seattle, Washington, USA, 14-19 June 2020.
4 |
5 | 
6 | Example of results
7 |
8 | | Original Image | Attack AlexNet | Attack ResNet18 | Attack ResNet50 |
9 | |---|---|---|---|
10 | |  |  | |  |
11 |
12 |
13 | ## Setup
14 | 1. Download source code from GitHub
15 | ```
16 | git clone https://github.com/smartcameras/ColorFool.git
17 | ```
18 | 2. Create a [conda](https://docs.conda.io/en/latest/miniconda.html) virtual environment
19 | ```
20 | conda create --name ColorFool python=3.5.6
21 | ```
22 | 3. Activate conda environment
23 | ```
24 | source activate ColorFool
25 | ```
26 | 4. Install requirements
27 | ```
28 | pip install -r requirements.txt
29 | ```
30 |
31 |
32 | ## Description
33 | The code works in two steps:
34 | 1. Identify sensitive image regions with a semantic segmentation model
35 | 2. Generate adversarial images by perturbing the color of semantic regions within their natural color ranges (see the sketch below)
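
For intuition, the color perturbation of step 2 can be sketched as follows. This is a simplified, single-trial sketch of the attack in Adversarial/ColorFool.py, not the full method (the real code draws region-specific ranges per semantic category and retries until misclassification); `image` is assumed to be an HxWx3 RGB uint8 array and `mask` an HxW weight map in [0, 1]:
```
# Minimal sketch of one ColorFool trial: perturb only the color
# channels (a, b) in the Lab color space, leaving lightness (L) intact.
import numpy as np
from skimage import color

def perturb_once(image, mask, mult, margin=127):
    lab = color.rgb2lab(image)
    # Draw one random shift per channel, scaled by the trial multiplier
    a_shift = np.random.uniform(-mult * margin, mult * margin)
    b_shift = np.random.uniform(-mult * margin, mult * margin)
    # Apply the shift only where the mask is non-zero, clipped to the valid range
    lab[:, :, 1] = np.clip(lab[:, :, 1] + a_shift * mask, -margin, margin)
    lab[:, :, 2] = np.clip(lab[:, :, 2] + b_shift * mask, -margin, margin)
    return np.uint8(color.lab2rgb(lab) * 255.)
```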
36 |
37 |
38 | ### Semantic Segmentation
39 |
40 | 1. Go to Segmentation directory
41 | ```
42 | cd Segmentation
43 | ```
44 | 2. Download the segmentation model (both encoder and decoder) from [here](https://drive.google.com/drive/folders/1FjZTweIsWWgxhXkzKHyIzEgBO5VTCe68) and place it in the "models" directory.
45 |
46 |
47 | 3. Run the segmentation for all images within the Dataset directory (requires a GPU)
48 | ```
49 | bash script.sh
50 | ```
51 |
52 | The semantic regions of the four categories (sky, water, grass and person) are saved in the Segmentation/SegmentationResults/$Dataset/ directory as smooth masks, each the same size as its image and named after the corresponding original image.
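
For example, a saved mask can be read back as a weight map in [0, 1] (a minimal sketch; the file name is illustrative):
```
# Sketch: load a smooth mask produced by SemanticMasks.py as a weight map
import cv2

sky_mask = cv2.imread('SegmentationResults/sky/ILSVRC2012_val_00003533.png',
                      cv2.IMREAD_GRAYSCALE) / 255.
```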
53 |
54 | ### Generate ColorFool Adversarial Images
55 |
56 | 1. Go to Adversarial directory
57 | ```
58 | cd ../Adversarial
59 | ```
60 | 2. In script.sh, set
61 | (i) the names of the target models to attack, and (ii) the name of the dataset.
62 | The current implementation supports three classifiers (ResNet18, ResNet50 and AlexNet) trained on ImageNet; for example:
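```
# Excerpt from Adversarial/script.sh: loop over the target models
MODELS=(alexnet resnet18 resnet50)

for model in "${MODELS[@]}"
do
    python -W ignore ColorFool.py --model=$model
done
```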
63 | 3. Run ColorFool for all images within the Dataset directory (works on both GPU and CPU)
64 | ```
65 | bash script.sh
66 | ```
67 |
68 | ### Outputs
69 | * Adversarial images, saved with the same names as the clean images, in the Adversarial/Results/ColorFoolImgs directory;
70 | * Metadata in the Adversarial/Results/Logs directory with the following structure: filename, number of trials, predicted class of the clean image with its probability, and predicted class of the adversarial image with its probability.
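
Each log line is tab-separated in that order; a hypothetical example (the class indices and probabilities below are made up for illustration):
```
ILSVRC2012_val_00003533.JPEG	12	21	0.87342	145	0.65310
```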
71 |
72 |
73 | ## Authors
74 | * [Ali Shahin Shamsabadi](mailto:a.shahinshamsabadi@qmul.ac.uk)
75 | * [Ricardo Sanchez-Matilla](mailto:ricardo.sanchezmatilla@qmul.ac.uk)
76 | * [Andrea Cavallaro](mailto:a.cavallaro@qmul.ac.uk)
77 |
78 |
79 | ## References
80 | If you use our code, please cite the following paper:
81 |
82 | @InProceedings{shamsabadi2020colorfool,
83 | title = {ColorFool: Semantic Adversarial Colorization},
84 | author = {Shamsabadi, Ali Shahin and Sanchez-Matilla, Ricardo and Cavallaro, Andrea},
85 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
86 | year = {2020},
87 | address = {Seattle, Washington, USA},
88 | month = June
89 | }
90 |
91 | ## License
92 | This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
93 |
--------------------------------------------------------------------------------
/Sample_results/ILSVRC2012_val_00003533_alexnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Sample_results/ILSVRC2012_val_00003533_alexnet.png
--------------------------------------------------------------------------------
/Sample_results/ILSVRC2012_val_00003533_resnet18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Sample_results/ILSVRC2012_val_00003533_resnet18.png
--------------------------------------------------------------------------------
/Sample_results/ILSVRC2012_val_00003533_resnet50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Sample_results/ILSVRC2012_val_00003533_resnet50.png
--------------------------------------------------------------------------------
/Segmentation/SemanticMasks.py:
--------------------------------------------------------------------------------
1 | # System libs
2 | import os
3 | import argparse
4 | from distutils.version import LooseVersion
5 | # Numerical libs
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from scipy.io import loadmat
10 | # Our libs
11 | from dataset import TestDataset
12 | from models import ModelBuilder, SegmentationModule
13 | from utils import colorEncode, find_recursive
14 | from lib.nn import user_scattered_collate, async_copy_to
15 | from lib.utils import as_numpy
16 | import lib.utils.data as torchdata
17 | import cv2
18 | from tqdm import tqdm
19 |
20 | colors = loadmat('data/color150.mat')['colors']
21 |
22 |
23 | def visualize_result(data, pred, pred_prob, args):
24 | (img, info) = data
25 | img_name = info.split('/')[-1]
26 |
27 | ### water mask: water, sea, swimming pool, waterfalls, lake and river
28 | water_mask = (pred == 21)
29 | sea_mask = (pred == 26)
30 | river_mask = (pred == 60)
31 | pool_mask = (pred == 109)
32 | fall_mask = (pred == 113)
33 | lake_mask = (pred == 128)
34 | water_mask = (water_mask | sea_mask | river_mask | pool_mask | fall_mask | lake_mask).astype(int)
35 | if args.mask_type=='smooth':
36 | water_mask = water_mask.astype(float) * pred_prob
37 |
38 | water_mask = water_mask * 255.
39 | cv2.imwrite('{}/water/{}.png' .format(args.result,img_name.split('.')[0]), water_mask)
40 |
41 |
42 | ### Sky mask
43 | sky_mask = (pred == 2).astype(int)
44 | if args.mask_type=='smooth':
45 | sky_mask = sky_mask.astype(float) * pred_prob
46 | sky_mask = sky_mask * 255.
47 | cv2.imwrite('{}/sky/{}.png' .format(args.result,img_name.split('.')[0]), sky_mask)
48 |
49 |
50 | ### Grass mask
51 | grass_mask = (pred == 9).astype(int)
52 | if args.mask_type=='smooth':
53 | grass_mask = grass_mask.astype(float) * pred_prob
54 |
55 | grass_mask = grass_mask * 255.
56 | cv2.imwrite('{}/grass/{}.png' .format(args.result,img_name.split('.')[0]), grass_mask)
57 |
58 |
59 | ### Person mask
60 | person_mask = (pred == 12).astype(int)
61 | if args.mask_type=='smooth':
62 | person_mask = person_mask.astype(float) * pred_prob
63 | person_mask = person_mask * 255.
64 | cv2.imwrite('{}/person/{}.png' .format(args.result,img_name.split('.')[0]), person_mask)
65 |
66 |
67 | def test(segmentation_module, loader, args):
68 | segmentation_module.eval()
69 |
70 | pbar = tqdm(total=len(loader))
71 | for batch_data in loader:
72 | # process data
73 | batch_data = batch_data[0]
74 | segSize = (batch_data['img_ori'].shape[0],
75 | batch_data['img_ori'].shape[1])
76 | img_resized_list = batch_data['img_data']
77 |
78 | with torch.no_grad():
79 | scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
80 |
81 | for img in img_resized_list:
82 | feed_dict = batch_data.copy()
83 | feed_dict['img_data'] = img
84 | del feed_dict['img_ori']
85 | del feed_dict['info']
86 |
87 | # forward pass
88 | pred_tmp = segmentation_module(feed_dict, segSize=segSize)
89 | scores += (pred_tmp.cpu() / len(args.imgSize))
90 |
91 |
92 | pred_prob, pred = torch.max(scores, dim=1)
93 | pred = as_numpy(pred.squeeze(0).cpu())
94 | pred_prob = as_numpy(pred_prob.squeeze(0).cpu())
95 |
96 | # visualization
97 | visualize_result((batch_data['img_ori'], batch_data['info']), pred, pred_prob, args)
98 |
99 | pbar.update(1)
100 |
101 |
102 | def main(args):
103 |
104 | # Network Builders
105 | builder = ModelBuilder()
106 | net_encoder = builder.build_encoder(
107 | arch=args.arch_encoder,
108 | fc_dim=args.fc_dim,
109 | weights=args.weights_encoder)
110 | net_decoder = builder.build_decoder(
111 | arch=args.arch_decoder,
112 | fc_dim=args.fc_dim,
113 | num_class=args.num_class,
114 | weights=args.weights_decoder,
115 | use_softmax=True)
116 |
117 | crit = nn.NLLLoss(ignore_index=-1)
118 |
119 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
120 |
121 | # Dataset and Loader
122 | if len(args.dataset) == 1 and os.path.isdir(args.dataset[0]):
123 | test_imgs = find_recursive(args.dataset[0], ext='.*')
124 | else:
125 | test_imgs = args.dataset
126 |
127 | list_test = [{'fpath_img': x} for x in test_imgs]
128 | dataset_test = TestDataset(list_test, args, max_sample=args.num_val)
129 | loader_test = torchdata.DataLoader(
130 | dataset_test,
131 | batch_size=args.batch_size,
132 | shuffle=False,
133 | collate_fn=user_scattered_collate,
134 | num_workers=5,
135 | drop_last=True)
136 |
137 | # Main loop
138 | test(segmentation_module, loader_test, args)
139 |
140 | print('Segmentation completed')
141 |
142 |
143 | if __name__ == '__main__':
144 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
145 | 'PyTorch>=0.4.0 is required'
146 |
147 | parser = argparse.ArgumentParser()
148 | # Path related arguments
149 | parser.add_argument('--dataset', required=True, nargs='+', type=str,
150 | help='a list of image paths, or a directory name')
151 | parser.add_argument('--model_path', required=True,
152 | help='folder to model path')
153 | parser.add_argument('--suffix', default='_epoch_20.pth',
154 | help="which snapshot to load")
155 |
156 | # Model related arguments
157 | parser.add_argument('--arch_encoder', default='resnet50dilated',
158 | help="architecture of net_encoder")
159 | parser.add_argument('--arch_decoder', default='ppm_deepsup',
160 | help="architecture of net_decoder")
161 | parser.add_argument('--fc_dim', default=2048, type=int,
162 | help='number of features between encoder and decoder')
163 |
164 | # Data related arguments
165 | parser.add_argument('--num_val', default=-1, type=int,
166 | help='number of images to evaluate')
167 | parser.add_argument('--num_class', default=150, type=int,
168 | help='number of classes')
169 | parser.add_argument('--batch_size', default=1, type=int,
170 | help='batch size; currently only 1 is supported')
171 | parser.add_argument('--imgSize', default=[300, 400, 500, 600],
172 | nargs='+', type=int,
173 | help='list of input image sizes '
174 | 'for multiscale testing, e.g. 300 400 500')
175 | parser.add_argument('--imgMaxSize', default=1000, type=int,
176 | help='maximum input image size of long edge')
177 | parser.add_argument('--padding_constant', default=8, type=int,
178 | help='maximum downsampling rate of the network')
179 | parser.add_argument('--segm_downsampling_rate', default=8, type=int,
180 | help='downsampling rate of the segmentation label')
181 |
182 | # Misc arguments
183 | parser.add_argument('--result', default='.',
184 | help='folder to output visualization results')
185 | parser.add_argument('--mask_type', required=True,
186 | help='Type of mask: binary or smooth')
187 | parser.add_argument('--gpu', default=0, type=int,
188 | help='gpu id for evaluation')
189 |
190 | args = parser.parse_args()
191 |
192 | args.arch_encoder = args.arch_encoder.lower()
193 | args.arch_decoder = args.arch_decoder.lower()
194 | print("Input arguments:")
195 | for key, val in vars(args).items():
196 | print("{:16} {}".format(key, val))
197 |
198 | # absolute paths of model weights
199 | args.weights_encoder = os.path.join(args.model_path,
200 | 'encoder' + args.suffix)
201 | args.weights_decoder = os.path.join(args.model_path,
202 | 'decoder' + args.suffix)
203 |
204 | assert os.path.exists(args.weights_encoder) and \
205 | os.path.exists(args.weights_decoder), 'checkpoint does not exist!'
206 |
207 | if not os.path.isdir('{}/'.format(args.result)):
208 | os.makedirs('{}/'.format(args.result))
209 | if not os.path.isdir('{}/sky/'.format(args.result)):
210 | os.makedirs('{}/sky/'.format(args.result))
211 | if not os.path.isdir('{}/water/'.format(args.result)):
212 | os.makedirs('{}/water/'.format(args.result))
213 | if not os.path.isdir('{}/grass/'.format(args.result)):
214 | os.makedirs('{}/grass/'.format(args.result))
215 | if not os.path.isdir('{}/person/'.format(args.result)):
216 | os.makedirs('{}/person/'.format(args.result))
217 |
218 | main(args)
219 |
--------------------------------------------------------------------------------
/Segmentation/data/color150.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Segmentation/data/color150.mat
--------------------------------------------------------------------------------
/Segmentation/data/object150_info.csv:
--------------------------------------------------------------------------------
1 | Idx,Ratio,Train,Val,Stuff,Name
2 | 1,0.1576,11664,1172,1,wall
3 | 2,0.1072,6046,612,1,building;edifice
4 | 3,0.0878,8265,796,1,sky
5 | 4,0.0621,9336,917,1,floor;flooring
6 | 5,0.0480,6678,641,0,tree
7 | 6,0.0450,6604,643,1,ceiling
8 | 7,0.0398,4023,408,1,road;route
9 | 8,0.0231,1906,199,0,bed
10 | 9,0.0198,4688,460,0,windowpane;window
11 | 10,0.0183,2423,225,1,grass
12 | 11,0.0181,2874,294,0,cabinet
13 | 12,0.0166,3068,310,1,sidewalk;pavement
14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
15 | 14,0.0151,1804,190,1,earth;ground
16 | 15,0.0118,6666,796,0,door;double;door
17 | 16,0.0110,4269,411,0,table
18 | 17,0.0109,1691,160,1,mountain;mount
19 | 18,0.0104,3999,441,0,plant;flora;plant;life
20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
21 | 20,0.0103,3261,318,0,chair
22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
23 | 22,0.0074,709,75,1,water
24 | 23,0.0067,3296,315,0,painting;picture
25 | 24,0.0065,1191,106,0,sofa;couch;lounge
26 | 25,0.0061,1516,162,0,shelf
27 | 26,0.0060,667,69,1,house
28 | 27,0.0053,651,57,1,sea
29 | 28,0.0052,1847,224,0,mirror
30 | 29,0.0046,1158,128,1,rug;carpet;carpeting
31 | 30,0.0044,480,44,1,field
32 | 31,0.0044,1172,98,0,armchair
33 | 32,0.0044,1292,184,0,seat
34 | 33,0.0033,1386,138,0,fence;fencing
35 | 34,0.0031,698,61,0,desk
36 | 35,0.0030,781,73,0,rock;stone
37 | 36,0.0027,380,43,0,wardrobe;closet;press
38 | 37,0.0026,3089,302,0,lamp
39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
40 | 39,0.0024,804,99,0,railing;rail
41 | 40,0.0023,1453,153,0,cushion
42 | 41,0.0023,411,37,0,base;pedestal;stand
43 | 42,0.0022,1440,162,0,box
44 | 43,0.0022,800,77,0,column;pillar
45 | 44,0.0020,2650,298,0,signboard;sign
46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
47 | 46,0.0019,367,36,0,counter
48 | 47,0.0018,311,30,1,sand
49 | 48,0.0018,1181,122,0,sink
50 | 49,0.0018,287,23,1,skyscraper
51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace
52 | 51,0.0018,402,43,0,refrigerator;icebox
53 | 52,0.0018,130,12,1,grandstand;covered;stand
54 | 53,0.0018,561,64,1,path
55 | 54,0.0017,880,102,0,stairs;steps
56 | 55,0.0017,86,12,1,runway
57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine
58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
59 | 58,0.0017,930,109,0,pillow
60 | 59,0.0015,139,18,0,screen;door;screen
61 | 60,0.0015,564,52,1,stairway;staircase
62 | 61,0.0015,320,26,1,river
63 | 62,0.0015,261,29,1,bridge;span
64 | 63,0.0014,275,22,0,bookcase
65 | 64,0.0014,335,60,0,blind;screen
66 | 65,0.0014,792,75,0,coffee;table;cocktail;table
67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
68 | 67,0.0014,1309,138,0,flower
69 | 68,0.0013,1112,113,0,book
70 | 69,0.0013,266,27,1,hill
71 | 70,0.0013,659,66,0,bench
72 | 71,0.0012,331,31,0,countertop
73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
74 | 73,0.0012,369,36,0,palm;palm;tree
75 | 74,0.0012,144,9,0,kitchen;island
76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
77 | 76,0.0010,324,33,0,swivel;chair
78 | 77,0.0009,304,27,0,boat
79 | 78,0.0009,170,20,0,bar
80 | 79,0.0009,68,6,0,arcade;machine
81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
83 | 82,0.0008,492,49,0,towel
84 | 83,0.0008,2510,269,0,light;light;source
85 | 84,0.0008,440,39,0,truck;motortruck
86 | 85,0.0008,147,18,1,tower
87 | 86,0.0008,583,56,0,chandelier;pendant;pendent
88 | 87,0.0007,533,61,0,awning;sunshade;sunblind
89 | 88,0.0007,1989,239,0,streetlight;street;lamp
90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk
91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
92 | 91,0.0007,135,12,0,airplane;aeroplane;plane
93 | 92,0.0007,83,5,1,dirt;track
94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
95 | 94,0.0006,1003,104,0,pole
96 | 95,0.0006,182,12,1,land;ground;soil
97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
100 | 99,0.0006,965,114,0,bottle
101 | 100,0.0006,117,13,0,buffet;counter;sideboard
102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
103 | 102,0.0006,108,9,1,stage
104 | 103,0.0006,557,55,0,van
105 | 104,0.0006,52,4,0,ship
106 | 105,0.0005,99,5,0,fountain
107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
108 | 107,0.0005,292,31,0,canopy
109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine
110 | 109,0.0005,340,38,0,plaything;toy
111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
112 | 111,0.0005,465,49,0,stool
113 | 112,0.0005,50,4,0,barrel;cask
114 | 113,0.0005,622,75,0,basket;handbasket
115 | 114,0.0005,80,9,1,waterfall;falls
116 | 115,0.0005,59,3,0,tent;collapsible;shelter
117 | 116,0.0005,531,72,0,bag
118 | 117,0.0005,282,30,0,minibike;motorbike
119 | 118,0.0005,73,7,0,cradle
120 | 119,0.0005,435,44,0,oven
121 | 120,0.0005,136,25,0,ball
122 | 121,0.0005,116,24,0,food;solid;food
123 | 122,0.0004,266,31,0,step;stair
124 | 123,0.0004,58,12,0,tank;storage;tank
125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque
126 | 125,0.0004,319,43,0,microwave;microwave;oven
127 | 126,0.0004,1193,139,0,pot;flowerpot
128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle
130 | 129,0.0004,52,5,1,lake
131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen
133 | 132,0.0004,201,30,0,blanket;cover
134 | 133,0.0004,285,21,0,sculpture
135 | 134,0.0004,268,27,0,hood;exhaust;hood
136 | 135,0.0003,1020,108,0,sconce
137 | 136,0.0003,1282,122,0,vase
138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
139 | 138,0.0003,453,57,0,tray
140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
141 | 140,0.0003,397,44,0,fan
142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock
143 | 142,0.0003,228,18,0,crt;screen
144 | 143,0.0003,570,59,0,plate
145 | 144,0.0003,217,22,0,monitor;monitoring;device
146 | 145,0.0003,206,19,0,bulletin;board;notice;board
147 | 146,0.0003,130,14,0,shower
148 | 147,0.0003,178,28,0,radiator
149 | 148,0.0002,504,57,0,glass;drinking;glass
150 | 149,0.0002,775,96,0,clock
151 | 150,0.0002,421,56,0,flag
152 |
--------------------------------------------------------------------------------
/Segmentation/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import lib.utils.data as torchdata
5 | import cv2
6 | from torchvision import transforms
7 | import numpy as np
8 |
9 |
10 | class BaseDataset(torchdata.Dataset):
11 | def __init__(self, odgt, opt, **kwargs):
12 | # parse options
13 | self.imgSize = opt.imgSize
14 | self.imgMaxSize = opt.imgMaxSize
15 |
16 | # max down sampling rate of network to avoid rounding during conv or pooling
17 | self.padding_constant = opt.padding_constant
18 |
19 | # parse the input list
20 | self.parse_input_list(odgt, **kwargs)
21 |
22 | # mean and std
23 | self.normalize = transforms.Normalize(
24 | mean=[102.9801, 115.9465, 122.7717],
25 | std=[1., 1., 1.])
26 |
27 | def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
28 | if isinstance(odgt, list):
29 | self.list_sample = odgt
30 | elif isinstance(odgt, str):
31 | self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
32 |
33 | if max_sample > 0:
34 | self.list_sample = self.list_sample[0:max_sample]
35 | if start_idx >= 0 and end_idx >= 0: # divide file list
36 | self.list_sample = self.list_sample[start_idx:end_idx]
37 |
38 | self.num_sample = len(self.list_sample)
39 | assert self.num_sample > 0
40 | print('# samples: {}'.format(self.num_sample))
41 |
42 | def img_transform(self, img):
43 | # image to float
44 | img = img.astype(np.float32)
45 | img = img.transpose((2, 0, 1))
46 | img = self.normalize(torch.from_numpy(img.copy()))
47 | return img
48 |
49 | # Round x up to the nearest multiple of p (so the result is >= x)
50 | def round2nearest_multiple(self, x, p):
51 | return ((x - 1) // p + 1) * p
52 |
53 |
54 | class TrainDataset(BaseDataset):
55 | def __init__(self, odgt, opt, batch_per_gpu=1, **kwargs):
56 | super(TrainDataset, self).__init__(odgt, opt, **kwargs)
57 | self.root_dataset = opt.root_dataset
58 | self.random_flip = opt.random_flip
59 | # downsampling rate of the segmentation label
60 | self.segm_downsampling_rate = opt.segm_downsampling_rate
61 | self.batch_per_gpu = batch_per_gpu
62 |
63 | # classify images into two classes: 1. h > w and 2. h <= w
64 | self.batch_record_list = [[], []]
65 |
66 | # override dataset length when training with batch_per_gpu > 1
67 | self.cur_idx = 0
68 | self.if_shuffled = False
69 |
70 | def _get_sub_batch(self):
71 | while True:
72 | # get a sample record
73 | this_sample = self.list_sample[self.cur_idx]
74 | if this_sample['height'] > this_sample['width']:
75 | self.batch_record_list[0].append(this_sample) # h > w, go to 1st class
76 | else:
77 | self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class
78 |
79 | # update current sample pointer
80 | self.cur_idx += 1
81 | if self.cur_idx >= self.num_sample:
82 | self.cur_idx = 0
83 | np.random.shuffle(self.list_sample)
84 |
85 | if len(self.batch_record_list[0]) == self.batch_per_gpu:
86 | batch_records = self.batch_record_list[0]
87 | self.batch_record_list[0] = []
88 | break
89 | elif len(self.batch_record_list[1]) == self.batch_per_gpu:
90 | batch_records = self.batch_record_list[1]
91 | self.batch_record_list[1] = []
92 | break
93 | return batch_records
94 |
95 | def __getitem__(self, index):
96 | # NOTE: random shuffle for the first time. shuffle in __init__ is useless
97 | if not self.if_shuffled:
98 | np.random.shuffle(self.list_sample)
99 | self.if_shuffled = True
100 |
101 | # get sub-batch candidates
102 | batch_records = self._get_sub_batch()
103 |
104 | # resize all images' short edges to the chosen size
105 | if isinstance(self.imgSize, list):
106 | this_short_size = np.random.choice(self.imgSize)
107 | else:
108 | this_short_size = self.imgSize
109 |
110 | # calculate the BATCH's height and width
111 | # since we concatenate more than one sample, the batch's h and w must be at least as large as EACH sample's
112 | batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
113 | for i in range(self.batch_per_gpu):
114 | img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
115 | this_scale = min(
116 | this_short_size / min(img_height, img_width), \
117 | self.imgMaxSize / max(img_height, img_width))
118 | img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
119 | batch_resized_size[i, :] = img_resized_height, img_resized_width
120 | batch_resized_height = np.max(batch_resized_size[:, 0])
121 | batch_resized_width = np.max(batch_resized_size[:, 1])
122 |
123 | # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
124 | batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant))
125 | batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant))
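# Worked example (illustrative numbers): for a 600x800 (h x w) image with
# this_short_size=450, imgMaxSize=1000 and padding_constant=8:
#   this_scale = min(450/600, 1000/800) = 0.75  ->  resized to 450x600
#   450 -> round2nearest_multiple(450, 8) = 456; 600 is already a multiple of 8
# so the padded batch canvas would be at least 456x600.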
126 |
127 | assert self.padding_constant >= self.segm_downsampling_rate,\
128 |             'padding constant must be greater than or equal to the segm downsampling rate'
129 | batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
130 | batch_segms = torch.zeros(
131 | self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
132 | batch_resized_width // self.segm_downsampling_rate).long()
133 |
134 | for i in range(self.batch_per_gpu):
135 | this_record = batch_records[i]
136 |
137 | # load image and label
138 | image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
139 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
140 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
141 | segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
142 |
143 | assert(img.ndim == 3)
144 | assert(segm.ndim == 2)
145 | assert(img.shape[0] == segm.shape[0])
146 | assert(img.shape[1] == segm.shape[1])
147 |
148 | if self.random_flip is True:
149 | random_flip = np.random.choice([0, 1])
150 | if random_flip == 1:
151 | img = cv2.flip(img, 1)
152 | segm = cv2.flip(segm, 1)
153 |
154 | # note that each sample within a mini batch has different scale param
155 | img = cv2.resize(img, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_LINEAR)
156 | segm = cv2.resize(segm, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_NEAREST)
157 |
158 | # to avoid seg label misalignment
159 | segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
160 | segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
161 | segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
162 | segm_rounded[:segm.shape[0], :segm.shape[1]] = segm
163 |
164 | segm = cv2.resize(
165 | segm_rounded,
166 | (segm_rounded.shape[1] // self.segm_downsampling_rate, \
167 | segm_rounded.shape[0] // self.segm_downsampling_rate), \
168 | interpolation=cv2.INTER_NEAREST)
169 |
170 | # image transform
171 | img = self.img_transform(img)
172 |
173 | batch_images[i][:, :img.shape[1], :img.shape[2]] = img
174 |             batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int64)).long()
175 |
176 | batch_segms = batch_segms - 1 # label from -1 to 149
177 | output = dict()
178 | output['img_data'] = batch_images
179 | output['seg_label'] = batch_segms
180 | return output
181 |
182 | def __len__(self):
183 | return int(1e10) # It's a fake length due to the trick that every loader maintains its own list
184 |         # return self.num_sample
185 |
186 |
187 | class ValDataset(BaseDataset):
188 | def __init__(self, odgt, opt, **kwargs):
189 | super(ValDataset, self).__init__(odgt, opt, **kwargs)
190 | self.root_dataset = opt.root_dataset
191 |
192 | def __getitem__(self, index):
193 | this_record = self.list_sample[index]
194 | # load image and label
195 | image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
196 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
197 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
198 | segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
199 |
200 | ori_height, ori_width, _ = img.shape
201 |
202 | img_resized_list = []
203 | for this_short_size in self.imgSize:
204 | # calculate target height and width
205 | scale = min(this_short_size / float(min(ori_height, ori_width)),
206 | self.imgMaxSize / float(max(ori_height, ori_width)))
207 | target_height, target_width = int(ori_height * scale), int(ori_width * scale)
208 |
209 | # to avoid rounding in network
210 | target_height = self.round2nearest_multiple(target_height, self.padding_constant)
211 | target_width = self.round2nearest_multiple(target_width, self.padding_constant)
212 |
213 | # resize
214 | img_resized = cv2.resize(img.copy(), (target_width, target_height))
215 |
216 | # image transform
217 | img_resized = self.img_transform(img_resized)
218 |
219 | img_resized = torch.unsqueeze(img_resized, 0)
220 | img_resized_list.append(img_resized)
221 |
222 |         segm = torch.from_numpy(segm.astype(np.int64)).long()
223 | batch_segms = torch.unsqueeze(segm, 0)
224 |
225 | batch_segms = batch_segms - 1 # label from -1 to 149
226 | output = dict()
227 | output['img_ori'] = img.copy()
228 | output['img_data'] = [x.contiguous() for x in img_resized_list]
229 | output['seg_label'] = batch_segms.contiguous()
230 | output['info'] = this_record['fpath_img']
231 | return output
232 |
233 | def __len__(self):
234 | return self.num_sample
235 |
236 |
237 | class TestDataset(BaseDataset):
238 | def __init__(self, odgt, opt, **kwargs):
239 | super(TestDataset, self).__init__(odgt, opt, **kwargs)
240 |
241 | def __getitem__(self, index):
242 | this_record = self.list_sample[index]
243 | # load image and label
244 | image_path = this_record['fpath_img']
245 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
246 |
247 | ori_height, ori_width, _ = img.shape
248 |
249 | img_resized_list = []
250 | for this_short_size in self.imgSize:
251 | # calculate target height and width
252 | scale = min(this_short_size / float(min(ori_height, ori_width)),
253 | self.imgMaxSize / float(max(ori_height, ori_width)))
254 | target_height, target_width = int(ori_height * scale), int(ori_width * scale)
255 |
256 | # to avoid rounding in network
257 | target_height = self.round2nearest_multiple(target_height, self.padding_constant)
258 | target_width = self.round2nearest_multiple(target_width, self.padding_constant)
259 |
260 | # resize
261 | img_resized = cv2.resize(img.copy(), (target_width, target_height))
262 |
263 | # image transform
264 | img_resized = self.img_transform(img_resized)
265 | img_resized = torch.unsqueeze(img_resized, 0)
266 | img_resized_list.append(img_resized)
267 |
268 | output = dict()
269 | output['img_ori'] = img.copy()
270 | output['img_data'] = [x.contiguous() for x in img_resized_list]
271 | output['info'] = this_record['fpath_img']
272 | return output
273 |
274 | def __len__(self):
275 | return self.num_sample
276 |
--------------------------------------------------------------------------------
/Segmentation/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/Segmentation/lib/__init__.py
--------------------------------------------------------------------------------
/Segmentation/lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 |         # customized batch norm statistics
49 | self._moving_average_fraction = 1. - momentum
50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
52 | self.register_buffer('_running_iter', torch.ones(1))
53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter
54 | self._tmp_running_var = self.running_var.clone() * self._running_iter
55 |
56 | def forward(self, input):
57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
58 | if not (self._is_parallel and self.training):
59 | return F.batch_norm(
60 | input, self.running_mean, self.running_var, self.weight, self.bias,
61 | self.training, self.momentum, self.eps)
62 |
63 | # Resize the input to (B, C, -1).
64 | input_shape = input.size()
65 | input = input.view(input.size(0), self.num_features, -1)
66 |
67 | # Compute the sum and square-sum.
68 | sum_size = input.size(0) * input.size(2)
69 | input_sum = _sum_ft(input)
70 | input_ssum = _sum_ft(input ** 2)
71 |
72 | # Reduce-and-broadcast the statistics.
73 | if self._parallel_id == 0:
74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
75 | else:
76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
77 |
78 | # Compute the output.
79 | if self.affine:
80 | # MJY:: Fuse the multiplication for speed.
81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
82 | else:
83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
84 |
85 | # Reshape it.
86 | return output.view(input_shape)
87 |
88 | def __data_parallel_replicate__(self, ctx, copy_id):
89 | self._is_parallel = True
90 | self._parallel_id = copy_id
91 |
92 | # parallel_id == 0 means master device.
93 | if self._parallel_id == 0:
94 | ctx.sync_master = self._sync_master
95 | else:
96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
97 |
98 | def _data_parallel_master(self, intermediates):
99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101 |
102 | to_reduce = [i[1][:2] for i in intermediates]
103 | to_reduce = [j for i in to_reduce for j in i] # flatten
104 | target_gpus = [i[1].sum.get_device() for i in intermediates]
105 |
106 | sum_size = sum([i[1].sum_size for i in intermediates])
107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108 |
109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110 |
111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112 |
113 | outputs = []
114 | for i, rec in enumerate(intermediates):
115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
116 |
117 | return outputs
118 |
119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
121 | return dest * alpha + delta * beta + bias
122 |
123 | def _compute_mean_std(self, sum_, ssum, size):
124 | """Compute the mean and standard-deviation with sum and square-sum. This method
125 | also maintains the moving average on the master device."""
126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
127 | mean = sum_ / size
128 | sumvar = ssum - sum_ * mean
129 | unbias_var = sumvar / (size - 1)
130 | bias_var = sumvar / size
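# Note: sumvar == ssum - size * mean**2 == sum((x - mean)**2), so the two
# quantities above are the usual unbiased (n - 1) and biased (n) variance
# estimators.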
131 |
132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
135 |
136 | self.running_mean = self._tmp_running_mean / self._running_iter
137 | self.running_var = self._tmp_running_var / self._running_iter
138 |
139 | return mean, bias_var.clamp(self.eps) ** -0.5
140 |
141 |
142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
144 | mini-batch.
145 |
146 | .. math::
147 |
148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
149 |
150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
151 | standard-deviation are reduced across all devices during training.
152 |
153 | For example, when one uses `nn.DataParallel` to wrap the network during
154 |     training, PyTorch's implementation normalizes the tensor on each device using
155 |     only the statistics on that device, which accelerates the computation and
156 |     is easy to implement, but the statistics might be inaccurate.
157 | Instead, in this synchronized version, the statistics will be computed
158 | over all training samples distributed on multiple devices.
159 |
160 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the
161 |     same as the built-in PyTorch implementation.
162 |
163 | The mean and standard-deviation are calculated per-dimension over
164 | the mini-batches and gamma and beta are learnable parameter vectors
165 | of size C (where C is the input size).
166 |
167 | During training, this layer keeps a running estimate of its computed mean
168 | and variance. The running sum is kept with a default momentum of 0.1.
169 |
170 | During evaluation, this running mean/variance is used for normalization.
171 |
172 | Because the BatchNorm is done over the `C` dimension, computing statistics
173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
174 |
175 | Args:
176 | num_features: num_features from an expected input of size
177 | `batch_size x num_features [x width]`
178 | eps: a value added to the denominator for numerical stability.
179 | Default: 1e-5
180 | momentum: the value used for the running_mean and running_var
181 | computation. Default: 0.1
182 | affine: a boolean value that when set to ``True``, gives the layer learnable
183 | affine parameters. Default: ``True``
184 |
185 | Shape:
186 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
188 |
189 | Examples:
190 | >>> # With Learnable Parameters
191 | >>> m = SynchronizedBatchNorm1d(100)
192 | >>> # Without Learnable Parameters
193 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
194 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
195 | >>> output = m(input)
196 | """
197 |
198 | def _check_input_dim(self, input):
199 | if input.dim() != 2 and input.dim() != 3:
200 | raise ValueError('expected 2D or 3D input (got {}D input)'
201 | .format(input.dim()))
202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203 |
204 |
205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
207 | of 3d inputs
208 |
209 | .. math::
210 |
211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
212 |
213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
214 | standard-deviation are reduced across all devices during training.
215 |
216 | For example, when one uses `nn.DataParallel` to wrap the network during
217 |     training, PyTorch's implementation normalizes the tensor on each device using
218 |     only the statistics on that device, which accelerates the computation and
219 |     is easy to implement, but the statistics might be inaccurate.
220 | Instead, in this synchronized version, the statistics will be computed
221 | over all training samples distributed on multiple devices.
222 |
223 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the
224 |     same as the built-in PyTorch implementation.
225 |
226 | The mean and standard-deviation are calculated per-dimension over
227 | the mini-batches and gamma and beta are learnable parameter vectors
228 | of size C (where C is the input size).
229 |
230 | During training, this layer keeps a running estimate of its computed mean
231 | and variance. The running sum is kept with a default momentum of 0.1.
232 |
233 | During evaluation, this running mean/variance is used for normalization.
234 |
235 | Because the BatchNorm is done over the `C` dimension, computing statistics
236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
237 |
238 | Args:
239 | num_features: num_features from an expected input of
240 | size batch_size x num_features x height x width
241 | eps: a value added to the denominator for numerical stability.
242 | Default: 1e-5
243 | momentum: the value used for the running_mean and running_var
244 | computation. Default: 0.1
245 | affine: a boolean value that when set to ``True``, gives the layer learnable
246 | affine parameters. Default: ``True``
247 |
248 | Shape:
249 | - Input: :math:`(N, C, H, W)`
250 | - Output: :math:`(N, C, H, W)` (same shape as input)
251 |
252 | Examples:
253 | >>> # With Learnable Parameters
254 | >>> m = SynchronizedBatchNorm2d(100)
255 | >>> # Without Learnable Parameters
256 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
258 | >>> output = m(input)
259 | """
260 |
261 | def _check_input_dim(self, input):
262 | if input.dim() != 4:
263 | raise ValueError('expected 4D input (got {}D input)'
264 | .format(input.dim()))
265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266 |
267 |
268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
270 | of 4d inputs
271 |
272 | .. math::
273 |
274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
275 |
276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
277 | standard-deviation are reduced across all devices during training.
278 |
279 | For example, when one uses `nn.DataParallel` to wrap the network during
280 |     training, PyTorch's implementation normalizes the tensor on each device using
281 |     only the statistics on that device, which accelerates the computation and
282 |     is easy to implement, but the statistics might be inaccurate.
283 | Instead, in this synchronized version, the statistics will be computed
284 | over all training samples distributed on multiple devices.
285 |
286 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the
287 |     same as the built-in PyTorch implementation.
288 |
289 | The mean and standard-deviation are calculated per-dimension over
290 | the mini-batches and gamma and beta are learnable parameter vectors
291 | of size C (where C is the input size).
292 |
293 | During training, this layer keeps a running estimate of its computed mean
294 | and variance. The running sum is kept with a default momentum of 0.1.
295 |
296 | During evaluation, this running mean/variance is used for normalization.
297 |
298 | Because the BatchNorm is done over the `C` dimension, computing statistics
299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
300 | or Spatio-temporal BatchNorm
301 |
302 | Args:
303 | num_features: num_features from an expected input of
304 | size batch_size x num_features x depth x height x width
305 | eps: a value added to the denominator for numerical stability.
306 | Default: 1e-5
307 | momentum: the value used for the running_mean and running_var
308 | computation. Default: 0.1
309 | affine: a boolean value that when set to ``True``, gives the layer learnable
310 | affine parameters. Default: ``True``
311 |
312 | Shape:
313 | - Input: :math:`(N, C, D, H, W)`
314 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
315 |
316 | Examples:
317 | >>> # With Learnable Parameters
318 | >>> m = SynchronizedBatchNorm3d(100)
319 | >>> # Without Learnable Parameters
320 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
322 | >>> output = m(input)
323 | """
324 |
325 | def _check_input_dim(self, input):
326 | if input.dim() != 5:
327 | raise ValueError('expected 5D input (got {}D input)'
328 | .format(input.dim()))
329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
330 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as the data parallel will trigger a callback on each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 |         Register a slave device.
81 |
82 | Args:
83 |             identifier: an identifier, usually the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback is invoked to compute the message to be sent back to each device
101 | (including the master device).
102 |
103 | Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
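# A minimal sketch of the protocol (hypothetical callback, illustrative values
# only): the master callback receives a list of (id, msg) pairs, master first,
# and must return a matching list of (id, reply) pairs.
#
#   import threading
#   master = SyncMaster(lambda msgs: [(i, sum(m for _, m in msgs)) for i, _ in msgs])
#   pipe = master.register_slave(1)
#   t = threading.Thread(target=lambda: pipe.run_slave(2))  # slave sends 2
#   t.start()
#   total = master.run_master(1)  # master sends 1; both sides receive 1 + 2 = 3
#   t.join()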
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all replicated modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 |     We guarantee that the callback on the master copy (the first copy) will be invoked before the callback
38 |     of any slave copy.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` on each module will be invoked after the replicas are created by
55 |     the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 |     Monkey-patch an existing `DataParallel` object to add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 |         bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # the 2D (N, C) input above needs BatchNorm1d
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/modules/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/Segmentation/lib/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | import collections
7 | from torch.nn.parallel._functions import Gather
8 |
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 |
13 | def async_copy_to(obj, dev, main_stream=None):
14 | if torch.is_tensor(obj):
15 | v = obj.cuda(dev, non_blocking=True)
16 | if main_stream is not None:
17 | v.data.record_stream(main_stream)
18 | return v
19 | elif isinstance(obj, collections.Mapping):
20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21 | elif isinstance(obj, collections.Sequence):
22 | return [async_copy_to(o, dev, main_stream) for o in obj]
23 | else:
24 | return obj
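# async_copy_to recurses through nested containers, e.g. (illustrative):
#   async_copy_to({'img': t1, 'metas': [t2, t3]}, dev=0)
# returns the same structure with every tensor copied (non-blocking) to GPU 0
# and non-tensor leaves returned untouched.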
25 |
26 |
27 | def dict_gather(outputs, target_device, dim=0):
28 | """
29 | Gathers variables from different GPUs on a specified device
30 | (-1 means the CPU), with dictionary support.
31 | """
32 | def gather_map(outputs):
33 | out = outputs[0]
34 | if torch.is_tensor(out):
35 | # MJY(20180330) HACK:: force nr_dims > 0
36 | if out.dim() == 0:
37 | outputs = [o.unsqueeze(0) for o in outputs]
38 | return Gather.apply(target_device, dim, *outputs)
39 | elif out is None:
40 | return None
41 | elif isinstance(out, collections.Mapping):
42 | return {k: gather_map([o[k] for o in outputs]) for k in out}
43 | elif isinstance(out, collections.Sequence):
44 | return type(out)(map(gather_map, zip(*outputs)))
45 | return gather_map(outputs)
46 |
47 |
48 | class DictGatherDataParallel(nn.DataParallel):
49 | def gather(self, outputs, output_device):
50 | return dict_gather(outputs, output_device, dim=self.dim)
51 |
52 |
53 | class UserScatteredDataParallel(DictGatherDataParallel):
54 | def scatter(self, inputs, kwargs, device_ids):
55 | assert len(inputs) == 1
56 | inputs = inputs[0]
57 | inputs = _async_copy_stream(inputs, device_ids)
58 | inputs = [[i] for i in inputs]
59 | assert len(kwargs) == 0
60 | kwargs = [{} for _ in range(len(inputs))]
61 |
62 | return inputs, kwargs
63 |
64 |
65 | def user_scattered_collate(batch):
66 | return batch
67 |
68 |
69 | def _async_copy(inputs, device_ids):
70 | nr_devs = len(device_ids)
71 | assert type(inputs) in (tuple, list)
72 | assert len(inputs) == nr_devs
73 |
74 | outputs = []
75 | for i, dev in zip(inputs, device_ids):
76 | with cuda.device(dev):
77 | outputs.append(async_copy_to(i, dev))
78 |
79 | return tuple(outputs)
80 |
81 |
82 | def _async_copy_stream(inputs, device_ids):
83 | nr_devs = len(device_ids)
84 | assert type(inputs) in (tuple, list)
85 | assert len(inputs) == nr_devs
86 |
87 | outputs = []
88 | streams = [_get_stream(d) for d in device_ids]
89 | for i, dev, stream in zip(inputs, device_ids, streams):
90 | with cuda.device(dev):
91 | main_stream = cuda.current_stream()
92 | with cuda.stream(stream):
93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94 | main_stream.wait_stream(stream)
95 |
96 | return outputs
97 |
98 |
99 | """Adapted from: torch/nn/parallel/_functions.py"""
100 | # background streams used for copying
101 | _streams = None
102 |
103 |
104 | def _get_stream(device):
105 | """Gets a background stream for copying between CPU and GPU"""
106 | global _streams
107 | if device == -1:
108 | return None
109 | if _streams is None:
110 | _streams = [None] * cuda.device_count()
111 | if _streams[device] is None: _streams[device] = cuda.Stream(device)
112 | return _streams[device]
113 |
--------------------------------------------------------------------------------
/Segmentation/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/Segmentation/lib/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/Segmentation/lib/utils/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as multiprocessing
3 | from torch._C import _set_worker_signal_handlers, _set_worker_pids, \
4 | _remove_worker_pids, _error_if_any_worker_fails
5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler
6 | import signal
7 | import functools
8 | import collections
9 | import re
10 | import sys
11 | import threading
12 | import traceback
13 | from torch._six import string_classes, int_classes
14 | import numpy as np
15 |
16 | if sys.version_info[0] == 2:
17 | import Queue as queue
18 | else:
19 | import queue
20 |
21 |
22 | class ExceptionWrapper(object):
23 | r"Wraps an exception plus traceback to communicate across threads"
24 |
25 | def __init__(self, exc_info):
26 | self.exc_type = exc_info[0]
27 | self.exc_msg = "".join(traceback.format_exception(*exc_info))
28 |
29 |
30 | _use_shared_memory = False
31 | """Whether to use shared memory in default_collate"""
32 |
33 |
34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
35 | global _use_shared_memory
36 | _use_shared_memory = True
37 |
38 |     # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
39 | # module's handlers are executed after Python returns from C low-level
40 | # handlers, likely when the same fatal signal happened again already.
41 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
42 | _set_worker_signal_handlers()
43 |
44 | torch.set_num_threads(1)
45 | torch.manual_seed(seed)
46 | np.random.seed(seed)
47 |
48 | if init_fn is not None:
49 | init_fn(worker_id)
50 |
51 | while True:
52 | r = index_queue.get()
53 | if r is None:
54 | break
55 | idx, batch_indices = r
56 | try:
57 | samples = collate_fn([dataset[i] for i in batch_indices])
58 | except Exception:
59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
60 | else:
61 | data_queue.put((idx, samples))
62 |
63 |
64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
65 | if pin_memory:
66 | torch.cuda.set_device(device_id)
67 |
68 | while True:
69 | try:
70 | r = in_queue.get()
71 | except Exception:
72 | if done_event.is_set():
73 | return
74 | raise
75 | if r is None:
76 | break
77 | if isinstance(r[1], ExceptionWrapper):
78 | out_queue.put(r)
79 | continue
80 | idx, batch = r
81 | try:
82 | if pin_memory:
83 | batch = pin_memory_batch(batch)
84 | except Exception:
85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
86 | else:
87 | out_queue.put((idx, batch))
88 |
89 | numpy_type_map = {
90 | 'float64': torch.DoubleTensor,
91 | 'float32': torch.FloatTensor,
92 | 'float16': torch.HalfTensor,
93 | 'int64': torch.LongTensor,
94 | 'int32': torch.IntTensor,
95 | 'int16': torch.ShortTensor,
96 | 'int8': torch.CharTensor,
97 | 'uint8': torch.ByteTensor,
98 | }
99 |
100 |
101 | def default_collate(batch):
102 | "Puts each data field into a tensor with outer dimension batch size"
103 |
104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
105 | elem_type = type(batch[0])
106 | if torch.is_tensor(batch[0]):
107 | out = None
108 | if _use_shared_memory:
109 | # If we're in a background process, concatenate directly into a
110 | # shared memory tensor to avoid an extra copy
111 | numel = sum([x.numel() for x in batch])
112 | storage = batch[0].storage()._new_shared(numel)
113 | out = batch[0].new(storage)
114 | return torch.stack(batch, 0, out=out)
115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
116 | and elem_type.__name__ != 'string_':
117 | elem = batch[0]
118 | if elem_type.__name__ == 'ndarray':
119 | # array of string classes and object
120 | if re.search('[SaUO]', elem.dtype.str) is not None:
121 | raise TypeError(error_msg.format(elem.dtype))
122 |
123 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
124 | if elem.shape == (): # scalars
125 | py_type = float if elem.dtype.name.startswith('float') else int
126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
127 | elif isinstance(batch[0], int_classes):
128 | return torch.LongTensor(batch)
129 | elif isinstance(batch[0], float):
130 | return torch.DoubleTensor(batch)
131 | elif isinstance(batch[0], string_classes):
132 | return batch
133 | elif isinstance(batch[0], collections.Mapping):
134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
135 | elif isinstance(batch[0], collections.Sequence):
136 | transposed = zip(*batch)
137 | return [default_collate(samples) for samples in transposed]
138 |
139 | raise TypeError((error_msg.format(type(batch[0]))))
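# For example (illustrative), default_collate([{'x': t1, 'y': 0}, {'x': t2, 'y': 1}])
# recurses into the dicts and returns
# {'x': torch.stack([t1, t2], 0), 'y': torch.LongTensor([0, 1])}.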
140 |
141 |
142 | def pin_memory_batch(batch):
143 | if torch.is_tensor(batch):
144 | return batch.pin_memory()
145 | elif isinstance(batch, string_classes):
146 | return batch
147 | elif isinstance(batch, collections.Mapping):
148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()}
149 | elif isinstance(batch, collections.Sequence):
150 | return [pin_memory_batch(sample) for sample in batch]
151 | else:
152 | return batch
153 |
154 |
155 | _SIGCHLD_handler_set = False
156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
157 | handler needs to be set for all DataLoaders in a process."""
158 |
159 |
160 | def _set_SIGCHLD_handler():
161 | # Windows doesn't support SIGCHLD handler
162 | if sys.platform == 'win32':
163 | return
164 | # can't set signal in child threads
165 | if not isinstance(threading.current_thread(), threading._MainThread):
166 | return
167 | global _SIGCHLD_handler_set
168 | if _SIGCHLD_handler_set:
169 | return
170 | previous_handler = signal.getsignal(signal.SIGCHLD)
171 | if not callable(previous_handler):
172 | previous_handler = None
173 |
174 | def handler(signum, frame):
175 |         # The following call uses `waitid` with WNOHANG from the C side. Therefore,
176 | # Python can still get and update the process status successfully.
177 | _error_if_any_worker_fails()
178 | if previous_handler is not None:
179 | previous_handler(signum, frame)
180 |
181 | signal.signal(signal.SIGCHLD, handler)
182 | _SIGCHLD_handler_set = True
183 |
184 |
185 | class DataLoaderIter(object):
186 | "Iterates once over the DataLoader's dataset, as specified by the sampler"
187 |
188 | def __init__(self, loader):
189 | self.dataset = loader.dataset
190 | self.collate_fn = loader.collate_fn
191 | self.batch_sampler = loader.batch_sampler
192 | self.num_workers = loader.num_workers
193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
194 | self.timeout = loader.timeout
195 | self.done_event = threading.Event()
196 |
197 | self.sample_iter = iter(self.batch_sampler)
198 |
199 | if self.num_workers > 0:
200 | self.worker_init_fn = loader.worker_init_fn
201 | self.index_queue = multiprocessing.SimpleQueue()
202 | self.worker_result_queue = multiprocessing.SimpleQueue()
203 | self.batches_outstanding = 0
204 | self.worker_pids_set = False
205 | self.shutdown = False
206 | self.send_idx = 0
207 | self.rcvd_idx = 0
208 | self.reorder_dict = {}
209 |
210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
211 | self.workers = [
212 | multiprocessing.Process(
213 | target=_worker_loop,
214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
215 | base_seed + i, self.worker_init_fn, i))
216 | for i in range(self.num_workers)]
217 |
218 | if self.pin_memory or self.timeout > 0:
219 | self.data_queue = queue.Queue()
220 | if self.pin_memory:
221 | maybe_device_id = torch.cuda.current_device()
222 | else:
223 | # do not initialize cuda context if not necessary
224 | maybe_device_id = None
225 | self.worker_manager_thread = threading.Thread(
226 | target=_worker_manager_loop,
227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
228 | maybe_device_id))
229 | self.worker_manager_thread.daemon = True
230 | self.worker_manager_thread.start()
231 | else:
232 | self.data_queue = self.worker_result_queue
233 |
234 | for w in self.workers:
235 | w.daemon = True # ensure that the worker exits on process exit
236 | w.start()
237 |
238 | _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
239 | _set_SIGCHLD_handler()
240 | self.worker_pids_set = True
241 |
242 | # prime the prefetch loop
243 | for _ in range(2 * self.num_workers):
244 | self._put_indices()
245 |
246 | def __len__(self):
247 | return len(self.batch_sampler)
248 |
249 | def _get_batch(self):
250 | if self.timeout > 0:
251 | try:
252 | return self.data_queue.get(timeout=self.timeout)
253 | except queue.Empty:
254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
255 | else:
256 | return self.data_queue.get()
257 |
258 | def __next__(self):
259 | if self.num_workers == 0: # same-process loading
260 | indices = next(self.sample_iter) # may raise StopIteration
261 | batch = self.collate_fn([self.dataset[i] for i in indices])
262 | if self.pin_memory:
263 | batch = pin_memory_batch(batch)
264 | return batch
265 |
266 | # check if the next sample has already been generated
267 | if self.rcvd_idx in self.reorder_dict:
268 | batch = self.reorder_dict.pop(self.rcvd_idx)
269 | return self._process_next_batch(batch)
270 |
271 | if self.batches_outstanding == 0:
272 | self._shutdown_workers()
273 | raise StopIteration
274 |
275 | while True:
276 | assert (not self.shutdown and self.batches_outstanding > 0)
277 | idx, batch = self._get_batch()
278 | self.batches_outstanding -= 1
279 | if idx != self.rcvd_idx:
280 | # store out-of-order samples
281 | self.reorder_dict[idx] = batch
282 | continue
283 | return self._process_next_batch(batch)
284 |
285 | next = __next__ # Python 2 compatibility
286 |
287 | def __iter__(self):
288 | return self
289 |
290 | def _put_indices(self):
291 | assert self.batches_outstanding < 2 * self.num_workers
292 | indices = next(self.sample_iter, None)
293 | if indices is None:
294 | return
295 | self.index_queue.put((self.send_idx, indices))
296 | self.batches_outstanding += 1
297 | self.send_idx += 1
298 |
299 | def _process_next_batch(self, batch):
300 | self.rcvd_idx += 1
301 | self._put_indices()
302 | if isinstance(batch, ExceptionWrapper):
303 | raise batch.exc_type(batch.exc_msg)
304 | return batch
305 |
306 | def __getstate__(self):
307 | # TODO: add limited pickling support for sharing an iterator
308 | # across multiple threads for HOGWILD.
309 | # Probably the best way to do this is by moving the sample pushing
310 | # to a separate thread and then just sharing the data queue
311 | # but signalling the end is tricky without a non-blocking API
312 | raise NotImplementedError("DataLoaderIterator cannot be pickled")
313 |
314 | def _shutdown_workers(self):
315 | try:
316 | if not self.shutdown:
317 | self.shutdown = True
318 | self.done_event.set()
319 | # if worker_manager_thread is waiting to put
320 | while not self.data_queue.empty():
321 | self.data_queue.get()
322 | for _ in self.workers:
323 | self.index_queue.put(None)
324 | # done_event should be sufficient to exit worker_manager_thread,
325 | # but be safe here and put another None
326 | self.worker_result_queue.put(None)
327 | finally:
328 | # removes pids no matter what
329 | if self.worker_pids_set:
330 | _remove_worker_pids(id(self))
331 | self.worker_pids_set = False
332 |
333 | def __del__(self):
334 | if self.num_workers > 0:
335 | self._shutdown_workers()
336 |
337 |
338 | class DataLoader(object):
339 | """
340 | Data loader. Combines a dataset and a sampler, and provides
341 | single- or multi-process iterators over the dataset.
342 |
343 | Arguments:
344 | dataset (Dataset): dataset from which to load the data.
345 | batch_size (int, optional): how many samples per batch to load
346 | (default: 1).
347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
348 | at every epoch (default: False).
349 | sampler (Sampler, optional): defines the strategy to draw samples from
350 | the dataset. If specified, ``shuffle`` must be False.
351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of
352 | indices at a time. Mutually exclusive with batch_size, shuffle,
353 | sampler, and drop_last.
354 | num_workers (int, optional): how many subprocesses to use for data
355 | loading. 0 means that the data will be loaded in the main process.
356 | (default: 0)
357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch.
358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors
359 | into CUDA pinned memory before returning them.
360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
361 | if the dataset size is not divisible by the batch size. If ``False`` and
362 | the size of dataset is not divisible by the batch size, then the last batch
363 | will be smaller. (default: False)
364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch
365 | from workers. Should always be non-negative. (default: 0)
366 | worker_init_fn (callable, optional): If not None, this will be called on each
367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
368 | input, after seeding and before data loading. (default: None)
369 |
370 | .. note:: By default, each worker will have its PyTorch seed set to
371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated
372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access
373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds
374 | (e.g. NumPy) before data loading.
375 |
376 |     .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
377 | unpicklable object, e.g., a lambda function.
378 | """
379 |
380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
382 | timeout=0, worker_init_fn=None):
383 | self.dataset = dataset
384 | self.batch_size = batch_size
385 | self.num_workers = num_workers
386 | self.collate_fn = collate_fn
387 | self.pin_memory = pin_memory
388 | self.drop_last = drop_last
389 | self.timeout = timeout
390 | self.worker_init_fn = worker_init_fn
391 |
392 | if timeout < 0:
393 | raise ValueError('timeout option should be non-negative')
394 |
395 | if batch_sampler is not None:
396 | if batch_size > 1 or shuffle or sampler is not None or drop_last:
397 | raise ValueError('batch_sampler is mutually exclusive with '
398 | 'batch_size, shuffle, sampler, and drop_last')
399 |
400 | if sampler is not None and shuffle:
401 | raise ValueError('sampler is mutually exclusive with shuffle')
402 |
403 | if self.num_workers < 0:
404 | raise ValueError('num_workers cannot be negative; '
405 | 'use num_workers=0 to disable multiprocessing.')
406 |
407 | if batch_sampler is None:
408 | if sampler is None:
409 | if shuffle:
410 | sampler = RandomSampler(dataset)
411 | else:
412 | sampler = SequentialSampler(dataset)
413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last)
414 |
415 | self.sampler = sampler
416 | self.batch_sampler = batch_sampler
417 |
418 | def __iter__(self):
419 | return DataLoaderIter(self)
420 |
421 | def __len__(self):
422 | return len(self.batch_sampler)
423 |
--------------------------------------------------------------------------------
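A minimal usage sketch of the vendored DataLoader above. The import paths are assumptions about how the package resolves from the Segmentation directory, and `ToyDataset` is a hypothetical stand-in; note how `shuffle=True` silently builds a RandomSampler wrapped in a BatchSampler, exactly as in the `__init__` shown above.

    import torch
    from lib.utils.data.dataloader import DataLoader   # assumed import path
    from lib.utils.data.dataset import Dataset

    class ToyDataset(Dataset):                          # hypothetical dataset
        def __len__(self):
            return 10
        def __getitem__(self, index):
            return torch.tensor([float(index)])

    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
    for batch in loader:                                # default_collate stacks samples
        print(batch.shape)                              # [4, 1], [4, 1], then [2, 1]
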
/Segmentation/lib/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 | """An abstract class representing a Dataset.
10 |
11 | All other datasets should subclass it. All subclasses should override
12 | ``__len__``, which provides the size of the dataset, and ``__getitem__``,
13 | which supports integer indexing in the range from 0 to len(self), exclusive.
14 | """
15 |
16 | def __getitem__(self, index):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | def __add__(self, other):
23 | return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 | """Dataset wrapping data and target tensors.
28 |
29 | Each sample will be retrieved by indexing both tensors along the first
30 | dimension.
31 |
32 | Arguments:
33 | data_tensor (Tensor): contains sample data.
34 | target_tensor (Tensor): contains sample targets (labels).
35 | """
36 |
37 | def __init__(self, data_tensor, target_tensor):
38 | assert data_tensor.size(0) == target_tensor.size(0)
39 | self.data_tensor = data_tensor
40 | self.target_tensor = target_tensor
41 |
42 | def __getitem__(self, index):
43 | return self.data_tensor[index], self.target_tensor[index]
44 |
45 | def __len__(self):
46 | return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 | """
51 | Dataset to concatenate multiple datasets.
52 | Useful for assembling different existing datasets, possibly
53 | large-scale ones, as the concatenation operation is performed
54 | on the fly.
55 |
56 | Arguments:
57 | datasets (iterable): List of datasets to be concatenated
58 | """
59 |
60 | @staticmethod
61 | def cumsum(sequence):
62 | r, s = [], 0
63 | for e in sequence:
64 | l = len(e)
65 | r.append(l + s)
66 | s += l
67 | return r
68 |
69 | def __init__(self, datasets):
70 | super(ConcatDataset, self).__init__()
71 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 | self.datasets = list(datasets)
73 | self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 | def __len__(self):
76 | return self.cumulative_sizes[-1]
77 |
78 | def __getitem__(self, idx):
79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 | if dataset_idx == 0:
81 | sample_idx = idx
82 | else:
83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 | return self.datasets[dataset_idx][sample_idx]
85 |
86 | @property
87 | def cummulative_sizes(self):
88 | warnings.warn("cummulative_sizes attribute is renamed to "
89 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 | return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 | def __init__(self, dataset, indices):
95 | self.dataset = dataset
96 | self.indices = indices
97 |
98 | def __getitem__(self, idx):
99 | return self.dataset[self.indices[idx]]
100 |
101 | def __len__(self):
102 | return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 | """
107 | Randomly split a dataset into non-overlapping new datasets of the given
108 | lengths.
109 |
110 | Arguments:
111 | dataset (Dataset): Dataset to be split
112 | lengths (iterable): lengths of splits to be produced
113 | """
114 | if sum(lengths) != len(dataset):
115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 | indices = randperm(sum(lengths))
118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
--------------------------------------------------------------------------------
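A short sketch of the pieces defined above, assuming the module is importable as `lib.utils.data.dataset`: ConcatDataset maps a global index to the right constituent via `bisect`, and `random_split` hands back non-overlapping Subset views.

    import torch
    from lib.utils.data.dataset import TensorDataset, ConcatDataset, random_split

    a = TensorDataset(torch.zeros(6, 3), torch.zeros(6))
    b = TensorDataset(torch.ones(4, 3), torch.ones(4))

    both = ConcatDataset([a, b])        # cumulative_sizes == [6, 10]
    x, y = both[7]                      # global index 7 -> sample 1 of b
    assert len(both) == 10 and float(y) == 1.0

    train, val = random_split(both, [8, 2])   # two non-overlapping Subsets
    assert len(train) == 8 and len(val) == 2
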
/Segmentation/lib/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None):
26 | if num_replicas is None:
27 | num_replicas = get_world_size()
28 | if rank is None:
29 | rank = get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 | # add extra samples to make it evenly divisible
44 | indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | offset = self.num_samples * self.rank
49 | indices = indices[offset:offset + self.num_samples]
50 | assert len(indices) == self.num_samples
51 |
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
--------------------------------------------------------------------------------
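A sketch of the sharding behaviour, run in a single process by passing num_replicas and rank explicitly so that no torch.distributed process group is needed (import paths are assumed):

    import torch
    from lib.utils.data.dataset import TensorDataset
    from lib.utils.data.distributed import DistributedSampler

    ds = TensorDataset(torch.arange(10.).unsqueeze(1), torch.zeros(10))

    s0 = DistributedSampler(ds, num_replicas=2, rank=0)
    s1 = DistributedSampler(ds, num_replicas=2, rank=1)
    s0.set_epoch(3); s1.set_epoch(3)    # same epoch -> same global permutation

    shard0 = {int(i) for i in s0}       # 5 indices for rank 0
    shard1 = {int(i) for i in s1}       # the other 5 for rank 1
    assert len(shard0) == len(shard1) == 5 and shard0.isdisjoint(shard1)

Calling set_epoch with a new value each epoch is what actually reshuffles the data, since the permutation is seeded only by the epoch number.
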
/Segmentation/lib/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |
7 | Every Sampler subclass has to provide an __iter__ method, providing a way
8 | to iterate over indices of dataset elements, and a __len__ method that
9 | returns the length of the returned iterator.
10 | """
11 |
12 | def __init__(self, data_source):
13 | pass
14 |
15 | def __iter__(self):
16 | raise NotImplementedError
17 |
18 | def __len__(self):
19 | raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 | """Samples elements sequentially, always in the same order.
24 |
25 | Arguments:
26 | data_source (Dataset): dataset to sample from
27 | """
28 |
29 | def __init__(self, data_source):
30 | self.data_source = data_source
31 |
32 | def __iter__(self):
33 | return iter(range(len(self.data_source)))
34 |
35 | def __len__(self):
36 | return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 | """Samples elements randomly, without replacement.
41 |
42 | Arguments:
43 | data_source (Dataset): dataset to sample from
44 | """
45 |
46 | def __init__(self, data_source):
47 | self.data_source = data_source
48 |
49 | def __iter__(self):
50 | return iter(torch.randperm(len(self.data_source)).long())
51 |
52 | def __len__(self):
53 | return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 | """Samples elements randomly from a given list of indices, without replacement.
58 |
59 | Arguments:
60 | indices (list): a list of indices
61 | """
62 |
63 | def __init__(self, indices):
64 | self.indices = indices
65 |
66 | def __iter__(self):
67 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 | def __len__(self):
70 | return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 | Arguments:
77 | weights (list): a list of weights, not necessarily summing up to one
78 | num_samples (int): number of samples to draw
79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a
81 | sample index is drawn for a row, it cannot be drawn again for that row.
82 | """
83 |
84 | def __init__(self, weights, num_samples, replacement=True):
85 | self.weights = torch.DoubleTensor(weights)
86 | self.num_samples = num_samples
87 | self.replacement = replacement
88 |
89 | def __iter__(self):
90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 | def __len__(self):
93 | return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 | """Wraps another sampler to yield a mini-batch of indices.
98 |
99 | Args:
100 | sampler (Sampler): Base sampler.
101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if
103 | its size would be less than ``batch_size``
104 |
105 | Example:
106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 | """
111 |
112 | def __init__(self, sampler, batch_size, drop_last):
113 | self.sampler = sampler
114 | self.batch_size = batch_size
115 | self.drop_last = drop_last
116 |
117 | def __iter__(self):
118 | batch = []
119 | for idx in self.sampler:
120 | batch.append(idx)
121 | if len(batch) == self.batch_size:
122 | yield batch
123 | batch = []
124 | if len(batch) > 0 and not self.drop_last:
125 | yield batch
126 |
127 | def __len__(self):
128 | if self.drop_last:
129 | return len(self.sampler) // self.batch_size
130 | else:
131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
--------------------------------------------------------------------------------
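The samplers compose rather than touch the data themselves; a small sketch, assuming the module is importable as `lib.utils.data.sampler`:

    from lib.utils.data.sampler import SequentialSampler, BatchSampler, WeightedRandomSampler

    data = list(range(10))              # anything with __len__ works as a data_source
    batches = list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False))
    assert batches == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    # Index 2 is drawn roughly five times as often as the others.
    w = WeightedRandomSampler(weights=[1, 1, 5, 1], num_samples=8, replacement=True)
    print([int(i) for i in w])          # e.g. [2, 0, 2, 2, 3, 2, 1, 2]
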
/Segmentation/lib/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections.abc  # Sequence/Mapping live here; the plain collections aliases were removed in Python 3.10
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 | if isinstance(obj, Variable):
10 | return obj
11 |     if isinstance(obj, collections.abc.Sequence):
12 |         return [as_variable(v) for v in obj]
13 |     elif isinstance(obj, collections.abc.Mapping):
14 | return {k: as_variable(v) for k, v in obj.items()}
15 | else:
16 | return Variable(obj)
17 |
18 | def as_numpy(obj):
19 |     if isinstance(obj, collections.abc.Sequence):
20 |         return [as_numpy(v) for v in obj]
21 |     elif isinstance(obj, collections.abc.Mapping):
22 | return {k: as_numpy(v) for k, v in obj.items()}
23 | elif isinstance(obj, Variable):
24 | return obj.data.cpu().numpy()
25 | elif torch.is_tensor(obj):
26 | return obj.cpu().numpy()
27 | else:
28 | return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 | if torch.is_tensor(obj):
32 | obj = Variable(obj)
33 | if isinstance(obj, Variable):
34 | obj.no_grad = True
35 | return obj
36 |     elif isinstance(obj, collections.abc.Mapping):
37 |         return {k: mark_volatile(o) for k, o in obj.items()}
38 |     elif isinstance(obj, collections.abc.Sequence):
39 | return [mark_volatile(o) for o in obj]
40 | else:
41 | return obj
42 |
--------------------------------------------------------------------------------
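as_numpy recurses through nested containers, which is handy for moving a whole dict of network outputs to NumPy at once; a small sketch (the import path is an assumption):

    import torch
    from lib.utils.th import as_numpy

    out = {'pred': torch.rand(2, 3), 'scores': [torch.tensor(0.9), torch.tensor(0.1)]}
    np_out = as_numpy(out)              # dict and list structure is preserved
    print(type(np_out['pred']).__name__, np_out['pred'].shape)   # ndarray (2, 3)
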
/Segmentation/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ModelBuilder, SegmentationModule
2 |
--------------------------------------------------------------------------------
/Segmentation/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | This MobileNetV2 implementation is modified from the following repository:
3 | https://github.com/tonylins/pytorch-mobilenet-v2
4 | """
5 |
6 | import os
7 | import sys
8 | import torch
9 | import torch.nn as nn
10 | import math
11 | from lib.nn import SynchronizedBatchNorm2d
12 |
13 | try:
14 | from urllib import urlretrieve
15 | except ImportError:
16 | from urllib.request import urlretrieve
17 |
18 |
19 | __all__ = ['mobilenetv2']
20 |
21 |
22 | model_urls = {
23 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
24 | }
25 |
26 |
27 | def conv_bn(inp, oup, stride):
28 | return nn.Sequential(
29 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
30 | SynchronizedBatchNorm2d(oup),
31 | nn.ReLU6(inplace=True)
32 | )
33 |
34 |
35 | def conv_1x1_bn(inp, oup):
36 | return nn.Sequential(
37 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
38 | SynchronizedBatchNorm2d(oup),
39 | nn.ReLU6(inplace=True)
40 | )
41 |
42 |
43 | class InvertedResidual(nn.Module):
44 | def __init__(self, inp, oup, stride, expand_ratio):
45 | super(InvertedResidual, self).__init__()
46 | self.stride = stride
47 | assert stride in [1, 2]
48 |
49 | hidden_dim = round(inp * expand_ratio)
50 | self.use_res_connect = self.stride == 1 and inp == oup
51 |
52 | if expand_ratio == 1:
53 | self.conv = nn.Sequential(
54 | # dw
55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
56 | SynchronizedBatchNorm2d(hidden_dim),
57 | nn.ReLU6(inplace=True),
58 | # pw-linear
59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
60 | SynchronizedBatchNorm2d(oup),
61 | )
62 | else:
63 | self.conv = nn.Sequential(
64 | # pw
65 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
66 | SynchronizedBatchNorm2d(hidden_dim),
67 | nn.ReLU6(inplace=True),
68 | # dw
69 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
70 | SynchronizedBatchNorm2d(hidden_dim),
71 | nn.ReLU6(inplace=True),
72 | # pw-linear
73 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
74 | SynchronizedBatchNorm2d(oup),
75 | )
76 |
77 | def forward(self, x):
78 | if self.use_res_connect:
79 | return x + self.conv(x)
80 | else:
81 | return self.conv(x)
82 |
83 |
84 | class MobileNetV2(nn.Module):
85 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
86 | super(MobileNetV2, self).__init__()
87 | block = InvertedResidual
88 | input_channel = 32
89 | last_channel = 1280
90 |         inverted_residual_setting = [
91 | # t, c, n, s
92 | [1, 16, 1, 1],
93 | [6, 24, 2, 2],
94 | [6, 32, 3, 2],
95 | [6, 64, 4, 2],
96 | [6, 96, 3, 1],
97 | [6, 160, 3, 2],
98 | [6, 320, 1, 1],
99 | ]
100 |
101 | # building first layer
102 | assert input_size % 32 == 0
103 | input_channel = int(input_channel * width_mult)
104 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
105 | self.features = [conv_bn(3, input_channel, 2)]
106 | # building inverted residual blocks
107 |         for t, c, n, s in inverted_residual_setting:
108 | output_channel = int(c * width_mult)
109 | for i in range(n):
110 | if i == 0:
111 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
112 | else:
113 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
114 | input_channel = output_channel
115 | # building last several layers
116 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
117 | # make it nn.Sequential
118 | self.features = nn.Sequential(*self.features)
119 |
120 | # building classifier
121 | self.classifier = nn.Sequential(
122 | nn.Dropout(0.2),
123 | nn.Linear(self.last_channel, n_class),
124 | )
125 |
126 | self._initialize_weights()
127 |
128 | def forward(self, x):
129 | x = self.features(x)
130 | x = x.mean(3).mean(2)
131 | x = self.classifier(x)
132 | return x
133 |
134 | def _initialize_weights(self):
135 | for m in self.modules():
136 | if isinstance(m, nn.Conv2d):
137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
138 | m.weight.data.normal_(0, math.sqrt(2. / n))
139 | if m.bias is not None:
140 | m.bias.data.zero_()
141 | elif isinstance(m, SynchronizedBatchNorm2d):
142 | m.weight.data.fill_(1)
143 | m.bias.data.zero_()
144 | elif isinstance(m, nn.Linear):
145 | n = m.weight.size(1)
146 | m.weight.data.normal_(0, 0.01)
147 | m.bias.data.zero_()
148 |
149 |
150 | def mobilenetv2(pretrained=False, **kwargs):
151 | """Constructs a MobileNet_V2 model.
152 |
153 | Args:
154 | pretrained (bool): If True, returns a model pre-trained on ImageNet
155 | """
156 | model = MobileNetV2(n_class=1000, **kwargs)
157 | if pretrained:
158 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
159 | return model
160 |
161 |
162 | def load_url(url, model_dir='./pretrained', map_location=None):
163 | if not os.path.exists(model_dir):
164 | os.makedirs(model_dir)
165 | filename = url.split('/')[-1]
166 | cached_file = os.path.join(model_dir, filename)
167 | if not os.path.exists(cached_file):
168 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
169 | urlretrieve(url, cached_file)
170 | return torch.load(cached_file, map_location=map_location)
171 |
172 |
--------------------------------------------------------------------------------
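A quick shape check for the backbone, assuming lib.nn (which provides SynchronizedBatchNorm2d) is on the path; the assert in `__init__` requires input_size to be a multiple of 32, since the stack of stride-2 stages downsamples by 32 overall.

    import torch
    from models.mobilenet import MobileNetV2    # assumed import path

    net = MobileNetV2(n_class=1000, input_size=224, width_mult=1.).eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 224, 224))
    print(logits.shape)                         # torch.Size([1, 1000])
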
/Segmentation/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | from . import resnet, resnext, mobilenet
5 | from lib.nn import SynchronizedBatchNorm2d
6 |
7 |
8 | class SegmentationModuleBase(nn.Module):
9 | def __init__(self):
10 | super(SegmentationModuleBase, self).__init__()
11 |
12 | def pixel_acc(self, pred, label):
13 | _, preds = torch.max(pred, dim=1)
14 | valid = (label >= 0).long()
15 | acc_sum = torch.sum(valid * (preds == label).long())
16 | pixel_sum = torch.sum(valid)
17 | acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
18 | return acc
19 |
20 |
21 | class SegmentationModule(SegmentationModuleBase):
22 | def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
23 | super(SegmentationModule, self).__init__()
24 | self.encoder = net_enc
25 | self.decoder = net_dec
26 | self.crit = crit
27 | self.deep_sup_scale = deep_sup_scale
28 |
29 | def forward(self, feed_dict, segSize=None):
30 | # training
31 | if segSize is None:
32 | if self.deep_sup_scale is not None: # use deep supervision technique
33 | (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
34 | else:
35 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
36 |
37 | loss = self.crit(pred, feed_dict['seg_label'])
38 | if self.deep_sup_scale is not None:
39 | loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
40 | loss = loss + loss_deepsup * self.deep_sup_scale
41 |
42 | acc = self.pixel_acc(pred, feed_dict['seg_label'])
43 | return loss, acc
44 | # inference
45 | else:
46 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
47 | return pred
48 |
49 |
50 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
51 | "3x3 convolution with padding"
52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
53 | padding=1, bias=has_bias)
54 |
55 |
56 | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
57 | return nn.Sequential(
58 | conv3x3(in_planes, out_planes, stride),
59 | SynchronizedBatchNorm2d(out_planes),
60 | nn.ReLU(inplace=True),
61 | )
62 |
63 |
64 | class ModelBuilder():
65 | # custom weights initialization
66 | def weights_init(self, m):
67 | classname = m.__class__.__name__
68 | if classname.find('Conv') != -1:
69 | nn.init.kaiming_normal_(m.weight.data)
70 | elif classname.find('BatchNorm') != -1:
71 | m.weight.data.fill_(1.)
72 | m.bias.data.fill_(1e-4)
73 | #elif classname.find('Linear') != -1:
74 | # m.weight.data.normal_(0.0, 0.0001)
75 |
76 | def build_encoder(self, arch='resnet50dilated', fc_dim=512, weights=''):
77 |         pretrained = len(weights) == 0  # download pretrained weights only when no checkpoint is given
78 | arch = arch.lower()
79 | if arch == 'mobilenetv2dilated':
80 | orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
81 | net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
82 | elif arch == 'resnet18':
83 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
84 | net_encoder = Resnet(orig_resnet)
85 | elif arch == 'resnet18dilated':
86 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
87 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
88 | elif arch == 'resnet34':
89 | raise NotImplementedError
90 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
91 | net_encoder = Resnet(orig_resnet)
92 | elif arch == 'resnet34dilated':
93 | raise NotImplementedError
94 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
95 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
96 | elif arch == 'resnet50':
97 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
98 | net_encoder = Resnet(orig_resnet)
99 | elif arch == 'resnet50dilated':
100 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
101 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
102 | elif arch == 'resnet101':
103 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
104 | net_encoder = Resnet(orig_resnet)
105 | elif arch == 'resnet101dilated':
106 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
107 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
108 | elif arch == 'resnext101':
109 | orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained)
110 | net_encoder = Resnet(orig_resnext) # we can still use class Resnet
111 | else:
112 | raise Exception('Architecture undefined!')
113 |
114 | # net_encoder.apply(self.weights_init)
115 | if len(weights) > 0:
116 | print('Loading weights for net_encoder')
117 | net_encoder.load_state_dict(
118 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
119 | return net_encoder
120 |
121 | def build_decoder(self, arch='ppm_deepsup',
122 | fc_dim=512, num_class=150,
123 | weights='', use_softmax=False):
124 | arch = arch.lower()
125 | if arch == 'c1_deepsup':
126 | net_decoder = C1DeepSup(
127 | num_class=num_class,
128 | fc_dim=fc_dim,
129 | use_softmax=use_softmax)
130 | elif arch == 'c1':
131 | net_decoder = C1(
132 | num_class=num_class,
133 | fc_dim=fc_dim,
134 | use_softmax=use_softmax)
135 | elif arch == 'ppm':
136 | net_decoder = PPM(
137 | num_class=num_class,
138 | fc_dim=fc_dim,
139 | use_softmax=use_softmax)
140 | elif arch == 'ppm_deepsup':
141 | net_decoder = PPMDeepsup(
142 | num_class=num_class,
143 | fc_dim=fc_dim,
144 | use_softmax=use_softmax)
145 | elif arch == 'upernet_lite':
146 | net_decoder = UPerNet(
147 | num_class=num_class,
148 | fc_dim=fc_dim,
149 | use_softmax=use_softmax,
150 | fpn_dim=256)
151 | elif arch == 'upernet':
152 | net_decoder = UPerNet(
153 | num_class=num_class,
154 | fc_dim=fc_dim,
155 | use_softmax=use_softmax,
156 | fpn_dim=512)
157 | else:
158 | raise Exception('Architecture undefined!')
159 |
160 | net_decoder.apply(self.weights_init)
161 | if len(weights) > 0:
162 | print('Loading weights for net_decoder')
163 | net_decoder.load_state_dict(
164 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
165 | return net_decoder
166 |
167 |
168 | class Resnet(nn.Module):
169 | def __init__(self, orig_resnet):
170 | super(Resnet, self).__init__()
171 |
172 | # take pretrained resnet, except AvgPool and FC
173 | self.conv1 = orig_resnet.conv1
174 | self.bn1 = orig_resnet.bn1
175 | self.relu1 = orig_resnet.relu1
176 | self.conv2 = orig_resnet.conv2
177 | self.bn2 = orig_resnet.bn2
178 | self.relu2 = orig_resnet.relu2
179 | self.conv3 = orig_resnet.conv3
180 | self.bn3 = orig_resnet.bn3
181 | self.relu3 = orig_resnet.relu3
182 | self.maxpool = orig_resnet.maxpool
183 | self.layer1 = orig_resnet.layer1
184 | self.layer2 = orig_resnet.layer2
185 | self.layer3 = orig_resnet.layer3
186 | self.layer4 = orig_resnet.layer4
187 |
188 | def forward(self, x, return_feature_maps=False):
189 | conv_out = []
190 |
191 | x = self.relu1(self.bn1(self.conv1(x)))
192 | x = self.relu2(self.bn2(self.conv2(x)))
193 | x = self.relu3(self.bn3(self.conv3(x)))
194 | x = self.maxpool(x)
195 |
196 | x = self.layer1(x); conv_out.append(x);
197 | x = self.layer2(x); conv_out.append(x);
198 | x = self.layer3(x); conv_out.append(x);
199 | x = self.layer4(x); conv_out.append(x);
200 |
201 | if return_feature_maps:
202 | return conv_out
203 | return [x]
204 |
205 |
206 | class ResnetDilated(nn.Module):
207 | def __init__(self, orig_resnet, dilate_scale=8):
208 | super(ResnetDilated, self).__init__()
209 | from functools import partial
210 |
211 | if dilate_scale == 8:
212 | orig_resnet.layer3.apply(
213 | partial(self._nostride_dilate, dilate=2))
214 | orig_resnet.layer4.apply(
215 | partial(self._nostride_dilate, dilate=4))
216 | elif dilate_scale == 16:
217 | orig_resnet.layer4.apply(
218 | partial(self._nostride_dilate, dilate=2))
219 |
220 | # take pretrained resnet, except AvgPool and FC
221 | self.conv1 = orig_resnet.conv1
222 | self.bn1 = orig_resnet.bn1
223 | self.relu1 = orig_resnet.relu1
224 | self.conv2 = orig_resnet.conv2
225 | self.bn2 = orig_resnet.bn2
226 | self.relu2 = orig_resnet.relu2
227 | self.conv3 = orig_resnet.conv3
228 | self.bn3 = orig_resnet.bn3
229 | self.relu3 = orig_resnet.relu3
230 | self.maxpool = orig_resnet.maxpool
231 | self.layer1 = orig_resnet.layer1
232 | self.layer2 = orig_resnet.layer2
233 | self.layer3 = orig_resnet.layer3
234 | self.layer4 = orig_resnet.layer4
235 |
236 | def _nostride_dilate(self, m, dilate):
237 | classname = m.__class__.__name__
238 | if classname.find('Conv') != -1:
239 | # the convolution with stride
240 | if m.stride == (2, 2):
241 | m.stride = (1, 1)
242 | if m.kernel_size == (3, 3):
243 | m.dilation = (dilate//2, dilate//2)
244 | m.padding = (dilate//2, dilate//2)
245 |             # other convolutions
246 | else:
247 | if m.kernel_size == (3, 3):
248 | m.dilation = (dilate, dilate)
249 | m.padding = (dilate, dilate)
250 |
251 | def forward(self, x, return_feature_maps=False):
252 | conv_out = []
253 |
254 | x = self.relu1(self.bn1(self.conv1(x)))
255 | x = self.relu2(self.bn2(self.conv2(x)))
256 | x = self.relu3(self.bn3(self.conv3(x)))
257 | x = self.maxpool(x)
258 |
259 | x = self.layer1(x); conv_out.append(x);
260 | x = self.layer2(x); conv_out.append(x);
261 | x = self.layer3(x); conv_out.append(x);
262 | x = self.layer4(x); conv_out.append(x);
263 |
264 | if return_feature_maps:
265 | return conv_out
266 | return [x]
267 |
268 |
269 | class MobileNetV2Dilated(nn.Module):
270 | def __init__(self, orig_net, dilate_scale=8):
271 | super(MobileNetV2Dilated, self).__init__()
272 | from functools import partial
273 |
274 | # take pretrained mobilenet features
275 | self.features = orig_net.features[:-1]
276 |
277 | self.total_idx = len(self.features)
278 | self.down_idx = [2, 4, 7, 14]
279 |
280 | if dilate_scale == 8:
281 | for i in range(self.down_idx[-2], self.down_idx[-1]):
282 | self.features[i].apply(
283 | partial(self._nostride_dilate, dilate=2)
284 | )
285 | for i in range(self.down_idx[-1], self.total_idx):
286 | self.features[i].apply(
287 | partial(self._nostride_dilate, dilate=4)
288 | )
289 | elif dilate_scale == 16:
290 | for i in range(self.down_idx[-1], self.total_idx):
291 | self.features[i].apply(
292 | partial(self._nostride_dilate, dilate=2)
293 | )
294 |
295 | def _nostride_dilate(self, m, dilate):
296 | classname = m.__class__.__name__
297 | if classname.find('Conv') != -1:
298 | # the convolution with stride
299 | if m.stride == (2, 2):
300 | m.stride = (1, 1)
301 | if m.kernel_size == (3, 3):
302 | m.dilation = (dilate//2, dilate//2)
303 | m.padding = (dilate//2, dilate//2)
304 |             # other convolutions
305 | else:
306 | if m.kernel_size == (3, 3):
307 | m.dilation = (dilate, dilate)
308 | m.padding = (dilate, dilate)
309 |
310 | def forward(self, x, return_feature_maps=False):
311 | if return_feature_maps:
312 | conv_out = []
313 | for i in range(self.total_idx):
314 | x = self.features[i](x)
315 | if i in self.down_idx:
316 | conv_out.append(x)
317 | conv_out.append(x)
318 | return conv_out
319 |
320 | else:
321 | return [self.features(x)]
322 |
323 |
324 | # last conv, deep supervision
325 | class C1DeepSup(nn.Module):
326 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
327 | super(C1DeepSup, self).__init__()
328 | self.use_softmax = use_softmax
329 |
330 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
331 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
332 |
333 | # last conv
334 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
335 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
336 |
337 | def forward(self, conv_out, segSize=None):
338 | conv5 = conv_out[-1]
339 |
340 | x = self.cbr(conv5)
341 | x = self.conv_last(x)
342 |
343 | if self.use_softmax: # is True during inference
344 | x = nn.functional.interpolate(
345 | x, size=segSize, mode='bilinear', align_corners=False)
346 | x = nn.functional.softmax(x, dim=1)
347 | return x
348 |
349 | # deep sup
350 | conv4 = conv_out[-2]
351 | _ = self.cbr_deepsup(conv4)
352 | _ = self.conv_last_deepsup(_)
353 |
354 | x = nn.functional.log_softmax(x, dim=1)
355 | _ = nn.functional.log_softmax(_, dim=1)
356 |
357 | return (x, _)
358 |
359 |
360 | # last conv
361 | class C1(nn.Module):
362 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
363 | super(C1, self).__init__()
364 | self.use_softmax = use_softmax
365 |
366 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
367 |
368 | # last conv
369 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
370 |
371 | def forward(self, conv_out, segSize=None):
372 | conv5 = conv_out[-1]
373 | x = self.cbr(conv5)
374 | x = self.conv_last(x)
375 |
376 | if self.use_softmax: # is True during inference
377 | x = nn.functional.interpolate(
378 | x, size=segSize, mode='bilinear', align_corners=False)
379 | x = nn.functional.softmax(x, dim=1)
380 | else:
381 | x = nn.functional.log_softmax(x, dim=1)
382 |
383 | return x
384 |
385 |
386 | # pyramid pooling
387 | class PPM(nn.Module):
388 | def __init__(self, num_class=150, fc_dim=4096,
389 | use_softmax=False, pool_scales=(1, 2, 3, 6)):
390 | super(PPM, self).__init__()
391 | self.use_softmax = use_softmax
392 |
393 | self.ppm = []
394 | for scale in pool_scales:
395 | self.ppm.append(nn.Sequential(
396 | nn.AdaptiveAvgPool2d(scale),
397 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
398 | SynchronizedBatchNorm2d(512),
399 | nn.ReLU(inplace=True)
400 | ))
401 | self.ppm = nn.ModuleList(self.ppm)
402 |
403 | self.conv_last = nn.Sequential(
404 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
405 | kernel_size=3, padding=1, bias=False),
406 | SynchronizedBatchNorm2d(512),
407 | nn.ReLU(inplace=True),
408 | nn.Dropout2d(0.1),
409 | nn.Conv2d(512, num_class, kernel_size=1)
410 | )
411 |
412 | def forward(self, conv_out, segSize=None):
413 | conv5 = conv_out[-1]
414 |
415 | input_size = conv5.size()
416 | ppm_out = [conv5]
417 | for pool_scale in self.ppm:
418 | ppm_out.append(nn.functional.interpolate(
419 | pool_scale(conv5),
420 | (input_size[2], input_size[3]),
421 | mode='bilinear', align_corners=False))
422 | ppm_out = torch.cat(ppm_out, 1)
423 |
424 | x = self.conv_last(ppm_out)
425 |
426 | if self.use_softmax: # is True during inference
427 | x = nn.functional.interpolate(
428 | x, size=segSize, mode='bilinear', align_corners=False)
429 | x = nn.functional.softmax(x, dim=1)
430 | else:
431 | x = nn.functional.log_softmax(x, dim=1)
432 | return x
433 |
434 |
435 | # pyramid pooling, deep supervision
436 | class PPMDeepsup(nn.Module):
437 | def __init__(self, num_class=150, fc_dim=4096,
438 | use_softmax=False, pool_scales=(1, 2, 3, 6)):
439 | super(PPMDeepsup, self).__init__()
440 | self.use_softmax = use_softmax
441 |
442 | self.ppm = []
443 | for scale in pool_scales:
444 | self.ppm.append(nn.Sequential(
445 | nn.AdaptiveAvgPool2d(scale),
446 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
447 | SynchronizedBatchNorm2d(512),
448 | nn.ReLU(inplace=True)
449 | ))
450 | self.ppm = nn.ModuleList(self.ppm)
451 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
452 |
453 | self.conv_last = nn.Sequential(
454 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
455 | kernel_size=3, padding=1, bias=False),
456 | SynchronizedBatchNorm2d(512),
457 | nn.ReLU(inplace=True),
458 | nn.Dropout2d(0.1),
459 | nn.Conv2d(512, num_class, kernel_size=1)
460 | )
461 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
462 | self.dropout_deepsup = nn.Dropout2d(0.1)
463 |
464 | def forward(self, conv_out, segSize=None):
465 | conv5 = conv_out[-1]
466 |
467 | input_size = conv5.size()
468 | ppm_out = [conv5]
469 | for pool_scale in self.ppm:
470 | ppm_out.append(nn.functional.interpolate(
471 | pool_scale(conv5),
472 | (input_size[2], input_size[3]),
473 | mode='bilinear', align_corners=False))
474 | ppm_out = torch.cat(ppm_out, 1)
475 |
476 | x = self.conv_last(ppm_out)
477 |
478 | if self.use_softmax: # is True during inference
479 | x = nn.functional.interpolate(
480 | x, size=segSize, mode='bilinear', align_corners=False)
481 | x = nn.functional.softmax(x, dim=1)
482 | return x
483 |
484 | # deep sup
485 | conv4 = conv_out[-2]
486 | _ = self.cbr_deepsup(conv4)
487 | _ = self.dropout_deepsup(_)
488 | _ = self.conv_last_deepsup(_)
489 |
490 | x = nn.functional.log_softmax(x, dim=1)
491 | _ = nn.functional.log_softmax(_, dim=1)
492 |
493 | return (x, _)
494 |
495 |
496 | # upernet
497 | class UPerNet(nn.Module):
498 | def __init__(self, num_class=150, fc_dim=4096,
499 | use_softmax=False, pool_scales=(1, 2, 3, 6),
500 | fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
501 | super(UPerNet, self).__init__()
502 | self.use_softmax = use_softmax
503 |
504 | # PPM Module
505 | self.ppm_pooling = []
506 | self.ppm_conv = []
507 |
508 | for scale in pool_scales:
509 | self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
510 | self.ppm_conv.append(nn.Sequential(
511 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
512 | SynchronizedBatchNorm2d(512),
513 | nn.ReLU(inplace=True)
514 | ))
515 | self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
516 | self.ppm_conv = nn.ModuleList(self.ppm_conv)
517 | self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)
518 |
519 | # FPN Module
520 | self.fpn_in = []
521 | for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer
522 | self.fpn_in.append(nn.Sequential(
523 | nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
524 | SynchronizedBatchNorm2d(fpn_dim),
525 | nn.ReLU(inplace=True)
526 | ))
527 | self.fpn_in = nn.ModuleList(self.fpn_in)
528 |
529 | self.fpn_out = []
530 | for i in range(len(fpn_inplanes) - 1): # skip the top layer
531 | self.fpn_out.append(nn.Sequential(
532 | conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
533 | ))
534 | self.fpn_out = nn.ModuleList(self.fpn_out)
535 |
536 | self.conv_last = nn.Sequential(
537 | conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
538 | nn.Conv2d(fpn_dim, num_class, kernel_size=1)
539 | )
540 |
541 | def forward(self, conv_out, segSize=None):
542 | conv5 = conv_out[-1]
543 |
544 | input_size = conv5.size()
545 | ppm_out = [conv5]
546 | for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
547 | ppm_out.append(pool_conv(nn.functional.interpolate(
548 | pool_scale(conv5),
549 | (input_size[2], input_size[3]),
550 | mode='bilinear', align_corners=False)))
551 | ppm_out = torch.cat(ppm_out, 1)
552 | f = self.ppm_last_conv(ppm_out)
553 |
554 | fpn_feature_list = [f]
555 | for i in reversed(range(len(conv_out) - 1)):
556 | conv_x = conv_out[i]
557 | conv_x = self.fpn_in[i](conv_x) # lateral branch
558 |
559 | f = nn.functional.interpolate(
560 | f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch
561 | f = conv_x + f
562 |
563 | fpn_feature_list.append(self.fpn_out[i](f))
564 |
565 | fpn_feature_list.reverse() # [P2 - P5]
566 | output_size = fpn_feature_list[0].size()[2:]
567 | fusion_list = [fpn_feature_list[0]]
568 | for i in range(1, len(fpn_feature_list)):
569 | fusion_list.append(nn.functional.interpolate(
570 | fpn_feature_list[i],
571 | output_size,
572 | mode='bilinear', align_corners=False))
573 | fusion_out = torch.cat(fusion_list, 1)
574 | x = self.conv_last(fusion_out)
575 |
576 | if self.use_softmax: # is True during inference
577 | x = nn.functional.interpolate(
578 | x, size=segSize, mode='bilinear', align_corners=False)
579 | x = nn.functional.softmax(x, dim=1)
580 | return x
581 |
582 | x = nn.functional.log_softmax(x, dim=1)
583 |
584 | return x
585 |
--------------------------------------------------------------------------------
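How the pieces fit together at inference time, as a hedged sketch: ModelBuilder assembles an encoder/decoder pair, and passing a segSize to SegmentationModule selects its inference branch. The import path is an assumption, and with an empty `weights` string the encoder fetches ImageNet-pretrained backbone weights over the network.

    import torch
    import torch.nn as nn
    from models import ModelBuilder, SegmentationModule   # assumed import path

    builder = ModelBuilder()
    # Empty `weights` -> download ImageNet backbone weights; pass checkpoint
    # paths instead to load a trained segmentation model.
    encoder = builder.build_encoder(arch='resnet18dilated', fc_dim=512)
    decoder = builder.build_decoder(arch='c1', fc_dim=512, num_class=150,
                                    use_softmax=True)     # softmax branch = inference
    net = SegmentationModule(encoder, decoder, crit=nn.NLLLoss(ignore_index=-1)).eval()

    with torch.no_grad():
        scores = net({'img_data': torch.randn(1, 3, 256, 256)}, segSize=(256, 256))
    print(scores.shape)                                   # torch.Size([1, 150, 256, 256])
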
/Segmentation/models/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from lib.nn import SynchronizedBatchNorm2d
7 |
8 | try:
9 | from urllib import urlretrieve
10 | except ImportError:
11 | from urllib.request import urlretrieve
12 |
13 |
14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1):
25 | "3x3 convolution with padding"
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=1, bias=False)
28 |
29 |
30 | class BasicBlock(nn.Module):
31 | expansion = 1
32 |
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(inplanes, planes, stride)
36 | self.bn1 = SynchronizedBatchNorm2d(planes)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.conv2 = conv3x3(planes, planes)
39 | self.bn2 = SynchronizedBatchNorm2d(planes)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 | def forward(self, x):
44 | residual = x
45 |
46 | out = self.conv1(x)
47 | out = self.bn1(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv2(out)
51 | out = self.bn2(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 |
56 | out += residual
57 | out = self.relu(out)
58 |
59 | return out
60 |
61 |
62 | class Bottleneck(nn.Module):
63 | expansion = 4
64 |
65 | def __init__(self, inplanes, planes, stride=1, downsample=None):
66 | super(Bottleneck, self).__init__()
67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
68 | self.bn1 = SynchronizedBatchNorm2d(planes)
69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
70 | padding=1, bias=False)
71 | self.bn2 = SynchronizedBatchNorm2d(planes)
72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.downsample = downsample
76 | self.stride = stride
77 |
78 | def forward(self, x):
79 | residual = x
80 |
81 | out = self.conv1(x)
82 | out = self.bn1(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv2(out)
86 | out = self.bn2(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv3(out)
90 | out = self.bn3(out)
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x)
94 |
95 | out += residual
96 | out = self.relu(out)
97 |
98 | return out
99 |
100 |
101 | class ResNet(nn.Module):
102 |
103 | def __init__(self, block, layers, num_classes=1000):
104 | self.inplanes = 128
105 | super(ResNet, self).__init__()
106 | self.conv1 = conv3x3(3, 64, stride=2)
107 | self.bn1 = SynchronizedBatchNorm2d(64)
108 | self.relu1 = nn.ReLU(inplace=True)
109 | self.conv2 = conv3x3(64, 64)
110 | self.bn2 = SynchronizedBatchNorm2d(64)
111 | self.relu2 = nn.ReLU(inplace=True)
112 | self.conv3 = conv3x3(64, 128)
113 | self.bn3 = SynchronizedBatchNorm2d(128)
114 | self.relu3 = nn.ReLU(inplace=True)
115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
116 |
117 | self.layer1 = self._make_layer(block, 64, layers[0])
118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
121 | self.avgpool = nn.AvgPool2d(7, stride=1)
122 | self.fc = nn.Linear(512 * block.expansion, num_classes)
123 |
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
127 | m.weight.data.normal_(0, math.sqrt(2. / n))
128 | elif isinstance(m, SynchronizedBatchNorm2d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 |
132 | def _make_layer(self, block, planes, blocks, stride=1):
133 | downsample = None
134 | if stride != 1 or self.inplanes != planes * block.expansion:
135 | downsample = nn.Sequential(
136 | nn.Conv2d(self.inplanes, planes * block.expansion,
137 | kernel_size=1, stride=stride, bias=False),
138 | SynchronizedBatchNorm2d(planes * block.expansion),
139 | )
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride, downsample))
143 | self.inplanes = planes * block.expansion
144 | for i in range(1, blocks):
145 | layers.append(block(self.inplanes, planes))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def forward(self, x):
150 | x = self.relu1(self.bn1(self.conv1(x)))
151 | x = self.relu2(self.bn2(self.conv2(x)))
152 | x = self.relu3(self.bn3(self.conv3(x)))
153 | x = self.maxpool(x)
154 |
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x = self.layer4(x)
159 |
160 | x = self.avgpool(x)
161 | x = x.view(x.size(0), -1)
162 | x = self.fc(x)
163 |
164 | return x
165 |
166 | def resnet18(pretrained=False, **kwargs):
167 | """Constructs a ResNet-18 model.
168 |
169 | Args:
170 | pretrained (bool): If True, returns a model pre-trained on ImageNet
171 | """
172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
173 | if pretrained:
174 | model.load_state_dict(load_url(model_urls['resnet18']))
175 | return model
176 |
177 | '''
178 | def resnet34(pretrained=False, **kwargs):
179 | """Constructs a ResNet-34 model.
180 |
181 | Args:
182 | pretrained (bool): If True, returns a model pre-trained on ImageNet
183 | """
184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
185 | if pretrained:
186 | model.load_state_dict(load_url(model_urls['resnet34']))
187 | return model
188 | '''
189 |
190 | def resnet50(pretrained=False, **kwargs):
191 | """Constructs a ResNet-50 model.
192 |
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on ImageNet
195 | """
196 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
197 | if pretrained:
198 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
199 | return model
200 |
201 |
202 | def resnet101(pretrained=False, **kwargs):
203 | """Constructs a ResNet-101 model.
204 |
205 | Args:
206 | pretrained (bool): If True, returns a model pre-trained on ImageNet
207 | """
208 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
209 | if pretrained:
210 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
211 | return model
212 |
213 | # def resnet152(pretrained=False, **kwargs):
214 | # """Constructs a ResNet-152 model.
215 | #
216 | # Args:
217 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
218 | # """
219 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
220 | # if pretrained:
221 | # model.load_state_dict(load_url(model_urls['resnet152']))
222 | # return model
223 |
224 | def load_url(url, model_dir='./pretrained', map_location=None):
225 | if not os.path.exists(model_dir):
226 | os.makedirs(model_dir)
227 | filename = url.split('/')[-1]
228 | cached_file = os.path.join(model_dir, filename)
229 | if not os.path.exists(cached_file):
230 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
231 | urlretrieve(url, cached_file)
232 | return torch.load(cached_file, map_location=map_location)
233 |
--------------------------------------------------------------------------------
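This variant differs from torchvision's ResNet in its three-convolution "deep stem" (conv1 through conv3) in place of a single 7x7 convolution; a shape-check sketch, assuming lib.nn resolves on the path:

    import torch
    from models.resnet import resnet18      # assumed import path

    net = resnet18(pretrained=False).eval() # pretrained=False avoids the download
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 224, 224))
    print(logits.shape)                     # torch.Size([1, 1000])
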
/Segmentation/models/resnext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from lib.nn import SynchronizedBatchNorm2d
7 |
8 | try:
9 | from urllib import urlretrieve
10 | except ImportError:
11 | from urllib.request import urlretrieve
12 |
13 |
14 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101
15 |
16 |
17 | model_urls = {
18 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
19 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
20 | }
21 |
22 |
23 | def conv3x3(in_planes, out_planes, stride=1):
24 | "3x3 convolution with padding"
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26 | padding=1, bias=False)
27 |
28 |
29 | class GroupBottleneck(nn.Module):
30 | expansion = 2
31 |
32 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
33 | super(GroupBottleneck, self).__init__()
34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
35 | self.bn1 = SynchronizedBatchNorm2d(planes)
36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
37 | padding=1, groups=groups, bias=False)
38 | self.bn2 = SynchronizedBatchNorm2d(planes)
39 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
40 | self.bn3 = SynchronizedBatchNorm2d(planes * 2)
41 | self.relu = nn.ReLU(inplace=True)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | residual = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv3(out)
57 | out = self.bn3(out)
58 |
59 | if self.downsample is not None:
60 | residual = self.downsample(x)
61 |
62 | out += residual
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class ResNeXt(nn.Module):
69 |
70 | def __init__(self, block, layers, groups=32, num_classes=1000):
71 | self.inplanes = 128
72 | super(ResNeXt, self).__init__()
73 | self.conv1 = conv3x3(3, 64, stride=2)
74 | self.bn1 = SynchronizedBatchNorm2d(64)
75 | self.relu1 = nn.ReLU(inplace=True)
76 | self.conv2 = conv3x3(64, 64)
77 | self.bn2 = SynchronizedBatchNorm2d(64)
78 | self.relu2 = nn.ReLU(inplace=True)
79 | self.conv3 = conv3x3(64, 128)
80 | self.bn3 = SynchronizedBatchNorm2d(128)
81 | self.relu3 = nn.ReLU(inplace=True)
82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
83 |
84 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
85 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
86 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
87 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
88 | self.avgpool = nn.AvgPool2d(7, stride=1)
89 | self.fc = nn.Linear(1024 * block.expansion, num_classes)
90 |
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv2d):
93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
94 | m.weight.data.normal_(0, math.sqrt(2. / n))
95 | elif isinstance(m, SynchronizedBatchNorm2d):
96 | m.weight.data.fill_(1)
97 | m.bias.data.zero_()
98 |
99 | def _make_layer(self, block, planes, blocks, stride=1, groups=1):
100 | downsample = None
101 | if stride != 1 or self.inplanes != planes * block.expansion:
102 | downsample = nn.Sequential(
103 | nn.Conv2d(self.inplanes, planes * block.expansion,
104 | kernel_size=1, stride=stride, bias=False),
105 | SynchronizedBatchNorm2d(planes * block.expansion),
106 | )
107 |
108 | layers = []
109 | layers.append(block(self.inplanes, planes, stride, groups, downsample))
110 | self.inplanes = planes * block.expansion
111 | for i in range(1, blocks):
112 | layers.append(block(self.inplanes, planes, groups=groups))
113 |
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x):
117 | x = self.relu1(self.bn1(self.conv1(x)))
118 | x = self.relu2(self.bn2(self.conv2(x)))
119 | x = self.relu3(self.bn3(self.conv3(x)))
120 | x = self.maxpool(x)
121 |
122 | x = self.layer1(x)
123 | x = self.layer2(x)
124 | x = self.layer3(x)
125 | x = self.layer4(x)
126 |
127 | x = self.avgpool(x)
128 | x = x.view(x.size(0), -1)
129 | x = self.fc(x)
130 |
131 | return x
132 |
133 |
134 | '''
135 | def resnext50(pretrained=False, **kwargs):
136 |     """Constructs a ResNeXt-50 model.
137 |
138 | Args:
139 | pretrained (bool): If True, returns a model pre-trained on Places
140 | """
141 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
142 | if pretrained:
143 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
144 | return model
145 | '''
146 |
147 |
148 | def resnext101(pretrained=False, **kwargs):
149 |     """Constructs a ResNeXt-101 model.
150 |
151 | Args:
152 | pretrained (bool): If True, returns a model pre-trained on Places
153 | """
154 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
155 | if pretrained:
156 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
157 | return model
158 |
159 |
160 | # def resnext152(pretrained=False, **kwargs):
161 | # """Constructs a ResNeXt-152 model.
162 | #
163 | # Args:
164 | # pretrained (bool): If True, returns a model pre-trained on Places
165 | # """
166 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
167 | # if pretrained:
168 | # model.load_state_dict(load_url(model_urls['resnext152']))
169 | # return model
170 |
171 |
172 | def load_url(url, model_dir='./pretrained', map_location=None):
173 | if not os.path.exists(model_dir):
174 | os.makedirs(model_dir)
175 | filename = url.split('/')[-1]
176 | cached_file = os.path.join(model_dir, filename)
177 | if not os.path.exists(cached_file):
178 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
179 | urlretrieve(url, cached_file)
180 | return torch.load(cached_file, map_location=map_location)
181 |
--------------------------------------------------------------------------------
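The only structural change relative to resnet.py is the grouped 3x3 convolution in GroupBottleneck (groups=32 by default), which cuts that layer's parameter count by the group factor; a self-contained comparison:

    import torch.nn as nn

    dense   = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
    grouped = nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=32, bias=False)
    print(sum(p.numel() for p in dense.parameters()))    # 589824
    print(sum(p.numel() for p in grouped.parameters()))  # 18432 = 589824 / 32
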
/Segmentation/script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | clear
4 |
5 | # Path to images and results
6 | DATASET=../Dataset/
7 | RESULT_PATH=./SegmentationResults/
8 |
9 | # Segmentation model
10 | MODEL_PATH=models
11 | MASKTYPE=smooth
12 |
13 | # Inference
14 | python -u SemanticMasks.py \
15 | --model_path $MODEL_PATH \
16 | --dataset $DATASET \
17 | --arch_encoder resnet50dilated \
18 | --arch_decoder ppm_deepsup \
19 | --fc_dim 2048 \
20 | --result $RESULT_PATH \
21 | --mask_type $MASKTYPE \
22 | --gpu 0
23 |
--------------------------------------------------------------------------------
/Segmentation/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import functools
4 | import fnmatch
5 | import numpy as np
6 |
7 |
8 | def find_recursive(root_dir, ext='.jpg'):
9 | files = []
10 | for root, dirnames, filenames in os.walk(root_dir):
11 | for filename in fnmatch.filter(filenames, '*' + ext):
12 | files.append(os.path.join(root, filename))
13 | return files
14 |
15 |
16 | class AverageMeter(object):
17 | """Computes and stores the average and current value"""
18 | def __init__(self):
19 | self.initialized = False
20 | self.val = None
21 | self.avg = None
22 | self.sum = None
23 | self.count = None
24 |
25 | def initialize(self, val, weight):
26 | self.val = val
27 | self.avg = val
28 | self.sum = val * weight
29 | self.count = weight
30 | self.initialized = True
31 |
32 | def update(self, val, weight=1):
33 | if not self.initialized:
34 | self.initialize(val, weight)
35 | else:
36 | self.add(val, weight)
37 |
38 | def add(self, val, weight):
39 | self.val = val
40 | self.sum += val * weight
41 | self.count += weight
42 | self.avg = self.sum / self.count
43 |
44 | def value(self):
45 | return self.val
46 |
47 | def average(self):
48 | return self.avg
49 |
50 |
51 | def unique(ar, return_index=False, return_inverse=False, return_counts=False):
52 | ar = np.asanyarray(ar).flatten()
53 |
54 | optional_indices = return_index or return_inverse
55 | optional_returns = optional_indices or return_counts
56 |
57 | if ar.size == 0:
58 | if not optional_returns:
59 | ret = ar
60 | else:
61 | ret = (ar,)
62 | if return_index:
63 |                 ret += (np.empty(0, np.bool_),)  # the np.bool alias was removed in NumPy 1.24
64 |             if return_inverse:
65 |                 ret += (np.empty(0, np.bool_),)
66 | if return_counts:
67 | ret += (np.empty(0, np.intp),)
68 | return ret
69 | if optional_indices:
70 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
71 | aux = ar[perm]
72 | else:
73 | ar.sort()
74 | aux = ar
75 | flag = np.concatenate(([True], aux[1:] != aux[:-1]))
76 |
77 | if not optional_returns:
78 | ret = aux[flag]
79 | else:
80 | ret = (aux[flag],)
81 | if return_index:
82 | ret += (perm[flag],)
83 | if return_inverse:
84 | iflag = np.cumsum(flag) - 1
85 | inv_idx = np.empty(ar.shape, dtype=np.intp)
86 | inv_idx[perm] = iflag
87 | ret += (inv_idx,)
88 | if return_counts:
89 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
90 | ret += (np.diff(idx),)
91 | return ret
92 |
93 |
94 | def colorEncode(labelmap, colors, mode='BGR'):
95 | labelmap = labelmap.astype('int')
96 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
97 | dtype=np.uint8)
98 | for label in unique(labelmap):
99 | if label < 0:
100 | continue
101 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
102 | np.tile(colors[label],
103 | (labelmap.shape[0], labelmap.shape[1], 1))
104 |
105 | if mode == 'BGR':
106 | return labelmap_rgb[:, :, ::-1]
107 | else:
108 | return labelmap_rgb
109 |
110 |
111 | def accuracy(preds, label):
112 | valid = (label >= 0)
113 | acc_sum = (valid * (preds == label)).sum()
114 | valid_sum = valid.sum()
115 | acc = float(acc_sum) / (valid_sum + 1e-10)
116 | return acc, valid_sum
117 |
118 |
119 | def intersectionAndUnion(imPred, imLab, numClass):
120 | imPred = np.asarray(imPred).copy()
121 | imLab = np.asarray(imLab).copy()
122 |
123 | imPred += 1
124 | imLab += 1
125 | # Remove classes from unlabeled pixels in gt image.
126 | # We should not penalize detections in unlabeled portions of the image.
127 | imPred = imPred * (imLab > 0)
128 |
129 | # Compute area intersection:
130 | intersection = imPred * (imPred == imLab)
131 | (area_intersection, _) = np.histogram(
132 | intersection, bins=numClass, range=(1, numClass))
133 |
134 | # Compute area union:
135 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
136 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
137 | area_union = area_pred + area_lab - area_intersection
138 |
139 | return (area_intersection, area_union)
140 |
141 |
142 | class NotSupportedCliException(Exception):
143 | pass
144 |
145 |
146 | def process_range(xpu, inp):
147 | start, end = map(int, inp)
148 | if start > end:
149 | end, start = start, end
150 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1))
151 |
152 |
153 | REGEX = [
154 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]),
155 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]),
156 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'),
157 | functools.partial(process_range, 'gpu')),
158 | (re.compile(r'^(\d+)-(\d+)$'),
159 | functools.partial(process_range, 'gpu')),
160 | ]
161 |
162 |
163 | def parse_devices(input_devices):
164 |
165 |     """Parse a user's device input string into the standard format,
166 |     e.g. ['gpu0', 'gpu1', ...].
167 |
168 | """
169 | ret = []
170 | for d in input_devices.split(','):
171 | for regex, func in REGEX:
172 | m = regex.match(d.lower().strip())
173 | if m:
174 | tmp = func(m.groups())
175 | # prevent duplicate
176 | for x in tmp:
177 | if x not in ret:
178 | ret.append(x)
179 | break
180 | else:
181 | raise NotSupportedCliException(
182 | 'Can not recognize device: "{}"'.format(d))
183 | return ret
184 |
--------------------------------------------------------------------------------
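parse_devices normalises a `--gpus`-style CLI argument via the REGEX table above, deduplicating as it goes; a few illustrative calls (assuming utils.py is importable):

    from utils import parse_devices     # Segmentation/utils.py

    print(parse_devices('0-2'))         # ['gpu0', 'gpu1', 'gpu2']
    print(parse_devices('gpu0,gpu3'))   # ['gpu0', 'gpu3']
    print(parse_devices('gpu1-gpu3,2')) # ['gpu1', 'gpu2', 'gpu3'] (duplicate dropped)
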
/TutorialDemoColorFool/Image/ILSVRC2012_val_00003533.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Image/ILSVRC2012_val_00003533.JPEG
--------------------------------------------------------------------------------
/TutorialDemoColorFool/Masks/Person/ILSVRC2012_val_00003533.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Person/ILSVRC2012_val_00003533.JPEG
--------------------------------------------------------------------------------
/TutorialDemoColorFool/Masks/Sky/ILSVRC2012_val_00003533.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Sky/ILSVRC2012_val_00003533.JPEG
--------------------------------------------------------------------------------
/TutorialDemoColorFool/Masks/Vegetation/ILSVRC2012_val_00003533.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Vegetation/ILSVRC2012_val_00003533.JPEG
--------------------------------------------------------------------------------
/TutorialDemoColorFool/Masks/Water/ILSVRC2012_val_00003533.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smartcameras/ColorFool/a2ccae4db821e304fb54090c332545b71046ea22/TutorialDemoColorFool/Masks/Water/ILSVRC2012_val_00003533.JPEG
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | torchvision
4 | opencv-python
5 | tqdm
6 | future
7 | scikit-image
8 | tensorboardX
9 |
--------------------------------------------------------------------------------