├── LICENSE
├── README.md
├── model_definitions
│   ├── __init__.py
│   ├── ozan_min_norm_solvers.py
│   ├── ozan_rep_fun.py
│   ├── resnet_taskonomy.py
│   ├── xception_taskonomy_joined_decoder.py
│   ├── xception_taskonomy_new.py
│   └── xception_taskonomy_small.py
├── network_selection
│   ├── .vscode
│   │   ├── launch.json
│   │   ├── settings.json
│   │   └── tasks.json
│   ├── Makefile
│   ├── a.out
│   ├── main
│   ├── main.cpp
│   ├── make_plots.py
│   ├── plots
│   │   ├── setting_1.pdf
│   │   ├── setting_2.pdf
│   │   ├── setting_3.pdf
│   │   └── setting_4.pdf
│   ├── results.txt
│   ├── results_20.txt
│   ├── results_alt.txt
│   ├── results_alt_20.txt
│   ├── results_alt_test.txt
│   ├── results_large.txt
│   ├── results_large_20.txt
│   ├── results_large_test.txt
│   ├── results_mean.txt
│   ├── results_small_data.txt
│   ├── results_small_data_at4.txt
│   ├── results_small_data_test.txt
│   └── results_test.txt
├── read_training_history.py
├── saved_models
│   └── placeholder
├── sync_batchnorm
│   ├── __init__.py
│   ├── batchnorm.py
│   ├── batchnorm_reimpl.py
│   ├── comm.py
│   ├── replicate.py
│   └── unittest.py
├── taskonomy_loader.py
├── taskonomy_losses.py
├── train_models.txt
├── train_taskonomy.py
└── val_models.txt

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

COPYRIGHT

All contributions by Trevor Standley:
Copyright (c) 2018, Trevor Standley
All rights reserved.

All other contributions:
Copyright (c) 2018-, the respective contributors.
All rights reserved.


LICENSE

MIT License

Each contributor holds copyright over their respective contributions.
The project versioning (Git) records all such contribution source information.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Code for Which Tasks to Train Together in Multi-Task Learning

Trevor Standley, Amir R. Zamir, Dawn Chen, Leonidas Guibas, Jitendra Malik, Silvio Savarese

ICML 2020

http://taskgrouping.stanford.edu/

1. Install PyTorch and torchvision.
2. Install apex:
```
conda install -c conda-forge nvidia-apex
```
3. (Optional) Install data-loading speedups:
```
conda install -c thomasbrandon -c defaults -c conda-forge pillow-accel-avx2
conda install -c conda-forge libjpeg-turbo
```
4. Get the training data from https://github.com/StanfordVL/taskonomy/tree/master/data
The data must be arranged in the following layout:

```
inputs:
root/rgb/building/point_x_view_x.png
labels:
root/$task$/$building$/point_x_view_x.png
```

Usage example:
```
python3 train_taskonomy.py -d=/taskonomy_data/ -a=xception_taskonomy_new -j 4 -b 96 -lr=.1 --fp16 -sbn --tasks=sdnerac -r
```

Pretrained models from setting 2:

https://drive.google.com/drive/folders/1XQVpv6Yyz5CRGNxetO0LTXuTvMS_w5R5?usp=sharing

To test these models on the test set:

```
python3 train_taskonomy.py -d=/taskonomy_data/ -a=xception_taskonomy_new -j 4 -b 256 -lr=.1 --fp16 -sbn --tasks=[task letters] --resume=setting2_models/xception_taskonomy_new_[task letters].pth.tar -t -r
```

(Contact the authors for models from other settings.)

--------------------------------------------------------------------------------
/model_definitions/__init__.py:
--------------------------------------------------------------------------------

from .xception_taskonomy_new import *
from .xception_taskonomy_joined_decoder import *
from .xception_taskonomy_small import *
from .resnet_taskonomy import *

--------------------------------------------------------------------------------
/model_definitions/ozan_min_norm_solvers.py:
--------------------------------------------------------------------------------

import numpy as np
import torch
import math


class MinNormSolver:
    MAX_ITER = 250
    STOP_CRIT = 1e-5

    @staticmethod
    def _min_norm_element_from2(v1v1, v1v2, v2v2):
        """
        Analytical solution for min_{c} |c * x_1 + (1-c) * x_2|_2^2
        d is the distance (objective) optimized
        v1v1 = <x1, x1>
        v1v2 = <x1, x2>
        v2v2 = <x2, x2>
        """
        if v1v2 >= v1v1:
            # Case: Fig 1, third column
            gamma = 0.999
            cost = v1v1
            return gamma, cost
        if v1v2 >= v2v2:
            # Case: Fig 1, first column
            gamma = 0.001
            cost = v2v2
            return gamma, cost
        # Case: Fig 1, second column
        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2*v1v2))
        cost = v2v2 + gamma*(v1v2 - v2v2)
        return gamma, cost

    @staticmethod
    def _min_norm_2d(vecs, dps):
        """
        Find the minimum norm solution as a combination of two points.
        This is correct only in 2D, i.e. min_c |\sum c_i x_i|_2^2 s.t. \sum c_i = 1, 1 >= c_i >= 0 for all i, c_i + c_j = 1.0 for some i, j
        """
        dmin = 1e99
        sol = None
        for i in range(len(vecs)):
            for j in range(i+1, len(vecs)):
                if (i, j) not in dps:
                    dps[(i, j)] = 0.0
                    for k in range(len(vecs[i])):
                        dps[(i, j)] += torch.dot(vecs[i][k], vecs[j][k]).item()
                    dps[(j, i)] = dps[(i, j)]
                if (i, i) not in dps:
                    dps[(i, i)] = 0.0
                    for k in range(len(vecs[i])):
                        dps[(i, i)] += torch.dot(vecs[i][k], vecs[i][k]).item()
                if (j, j) not in dps:
                    dps[(j, j)] = 0.0
                    for k in range(len(vecs[i])):
                        dps[(j, j)] += torch.dot(vecs[j][k], vecs[j][k]).item()
                c, d = MinNormSolver._min_norm_element_from2(dps[(i, i)], dps[(i, j)], dps[(j, j)])
                if d < dmin:
                    dmin = d
                    sol = [(i, j), c, d]

        if sol is None or math.isnan(c):
            raise ValueError('A numeric instability occurred in ozan_min_norm_solvers.')
        return sol, dps

    @staticmethod
    def _projection2simplex(y):
        """
        Given y, solves argmin_z |y-z|_2 s.t. \sum z = 1, 1 >= z_i >= 0 for all i
        """
        m = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)
        tmpsum = 0.0
        tmax_f = (np.sum(y) - 1.0) / m
        for i in range(m-1):
            tmpsum += sorted_y[i]
            tmax = (tmpsum - 1) / (i + 1.0)
            if tmax > sorted_y[i+1]:
                tmax_f = tmax
                break
        return np.maximum(y - tmax_f, np.zeros(y.shape))

    @staticmethod
    def _next_point(cur_val, grad, n):
        proj_grad = grad - (np.sum(grad) / n)
        tm1 = -1.0*cur_val[proj_grad < 0]/proj_grad[proj_grad < 0]
        tm2 = (1.0 - cur_val[proj_grad > 0])/(proj_grad[proj_grad > 0])

        t = 1
        if len(tm1[tm1 > 1e-7]) > 0:
            t = np.min(tm1[tm1 > 1e-7])
        if len(tm2[tm2 > 1e-7]) > 0:
            t = min(t, np.min(tm2[tm2 > 1e-7]))

        next_point = proj_grad*t + cur_val
        next_point = MinNormSolver._projection2simplex(next_point)
        return next_point

    @staticmethod
    def find_min_norm_element(vecs):
        """
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
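        For two vectors the optimum is analytic; this is what _min_norm_element_from2
        computes (with the boundary cases snapped to c = 0.999 / 0.001):
            c* = ((x_2 - x_1) . x_2) / |x_1 - x_2|^2,   u = c* x_1 + (1 - c*) x_2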
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st. u = c x_i + (1-c) x_j, the solution lies in (0, d_{i,j}).
        Hence, we find the best 2-task solution, and then run projected gradient descent until convergence.
        """
        # Solution lying at the combination of two points
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n = len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            # This is optimal for n=2, so return the solution
            return sol_vec, init_sol[2]

        iter_count = 0

        grad_mat = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                grad_mat[i, j] = dps[(i, j)]

        while iter_count < MinNormSolver.MAX_ITER:
            grad_dir = -1.0*np.dot(grad_mat, sol_vec)
            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
            # Re-compute the inner products for line search
            v1v1 = 0.0
            v1v2 = 0.0
            v2v2 = 0.0
            for i in range(n):
                for j in range(n):
                    v1v1 += sol_vec[i]*sol_vec[j]*dps[(i, j)]
                    v1v2 += sol_vec[i]*new_point[j]*dps[(i, j)]
                    v2v2 += new_point[i]*new_point[j]*dps[(i, j)]
            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc*sol_vec + (1-nc)*new_point
            change = new_sol_vec - sol_vec
            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
            iter_count += 1  # was never incremented, so MAX_ITER could not trigger
        return sol_vec, nd  # return the last iterate if the loop hits MAX_ITER

    @staticmethod
    def find_min_norm_element_FW(vecs):
        """
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st. u = c x_i + (1-c) x_j, the solution lies in (0, d_{i,j}).
        Hence, we find the best 2-task solution, and then run Frank-Wolfe until convergence.
        """
        # Solution lying at the combination of two points
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n = len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            # This is optimal for n=2, so return the solution
            return sol_vec, init_sol[2]

        iter_count = 0

        grad_mat = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                grad_mat[i, j] = dps[(i, j)]

        while iter_count < MinNormSolver.MAX_ITER:
            t_iter = np.argmin(np.dot(grad_mat, sol_vec))

            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
            v2v2 = grad_mat[t_iter, t_iter]

            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc*sol_vec
            new_sol_vec[t_iter] += 1 - nc

            change = new_sol_vec - sol_vec
            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
            iter_count += 1  # was never incremented here either
        return sol_vec, nd


def gradient_normalizers(grads, losses, normalization_type):
    gn = {}
    if normalization_type == 'l2':
        for t in grads:
            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
    elif normalization_type == 'loss':
        for t in grads:
            gn[t] = losses[t]
    elif normalization_type == 'loss+':
        for t in grads:
            gn[t] = losses[t] * 
np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]])) 200 | elif normalization_type == 'none': 201 | for t in grads: 202 | gn[t] = 1.0 203 | else: 204 | print('ERROR: Invalid Normalization Type') 205 | return gn -------------------------------------------------------------------------------- /model_definitions/ozan_rep_fun.py: -------------------------------------------------------------------------------- 1 | import torch.autograd 2 | import sys 3 | import math 4 | from .ozan_min_norm_solvers import MinNormSolver 5 | import statistics 6 | class OzanRepFunction(torch.autograd.Function): 7 | # def __init__(self,copies,noop=False): 8 | # super(OzanRepFunction,self).__init__() 9 | # self.copies=copies 10 | # self.noop=noop 11 | n=5 12 | def __init__(self): 13 | super(OzanRepFunction, self).__init__() 14 | 15 | @staticmethod 16 | def forward(ctx, input): 17 | 18 | shape = input.shape 19 | ret = input.expand(OzanRepFunction.n,*shape) 20 | return ret.clone() # REASON FOR ERROR: forgot to .clone() here 21 | 22 | #@staticmethod 23 | # def backward(ctx, grad_output): 24 | # # print("backward",grad_output.shape) 25 | # # print() 26 | # # print() 27 | # if grad_output.shape[0]==2: 28 | # theta0,theta1=grad_output[0].view(-1).float(), grad_output[1].view(-1).float() 29 | # diff = theta0-theta1 30 | # num = diff.dot(theta0) 31 | # denom = (diff.dot(diff)+.00000001) 32 | # a = num/denom 33 | # a1=float(a) 34 | # a = a.clamp(0,1) 35 | # a = float(a) 36 | # # print(float(a),a1,float(num),float(denom)) 37 | # # print() 38 | # # print() 39 | # def get_out_for_a(a): 40 | # return grad_output[0]*(1-a)+grad_output[1]*a 41 | # def get_score_for_a(a): 42 | # out = get_out_for_a(a) 43 | # vec = out.view(-1) 44 | # score = vec.dot(vec) 45 | # return float(score) 46 | # # print(0,get_score_for_a(0), 47 | # # .1,get_score_for_a(0.1), 48 | # # .2,get_score_for_a(0.2), 49 | # # .3,get_score_for_a(0.3), 50 | # # .4,get_score_for_a(0.4), 51 | # # .5,get_score_for_a(0.5), 52 | # # .6,get_score_for_a(0.6), 53 | # # .7,get_score_for_a(0.7), 54 | # # .8,get_score_for_a(0.8), 55 | # # .9,get_score_for_a(0.9), 56 | # # 1,get_score_for_a(1)) 57 | # # print(a,get_score_for_a(a)) 58 | # # print() 59 | # # print() 60 | # out = get_out_for_a(a) 61 | # #out=out*2 62 | # elif grad_output.shape[0]==1: 63 | # grad_input=grad_output.clone() 64 | # out = grad_input.sum(dim=0) 65 | # else: 66 | # pass 67 | # return out 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | num_grads = grad_output.shape[0] 72 | batch_size = grad_output.shape[1] 73 | # print(num_grads) 74 | # print(num_grads) 75 | # print(num_grads) 76 | #print(grad_output.shape) 77 | # print(grad_output.shape) 78 | # print(grad_output.shape) 79 | # print(num_grads) 80 | # print(num_grads) 81 | if num_grads>=2: 82 | #print ('shape in = ',grad_output[0].view(batch_size,-1).float().shape) 83 | try: 84 | alphas, score = MinNormSolver.find_min_norm_element([grad_output[i].view(batch_size,-1).float() for i in range(num_grads)]) 85 | #print(alphas) 86 | except ValueError as error: 87 | alphas = [1/num_grads for i in range(num_grads)] 88 | #print('outs shape',out.shape) 89 | #print('alphas shape',alphas.shape) 90 | 91 | #out = out.view() 92 | #out = torch.zeros_like(grad_output[0]) 93 | # print(alphas) 94 | # print() 95 | # print() 96 | grad_outputs = [grad_output[i]*alphas[i]*math.sqrt(num_grads) for i in range(num_grads)] 97 | output = grad_outputs[0] 98 | for i in range(1,num_grads): 99 | output+=grad_outputs[i] 100 | return output 101 | 102 | 103 | elif 
num_grads == 1:
            grad_input = grad_output.clone()
            out = grad_input.sum(dim=0)
        else:
            pass  # unreachable: forward always expands to n >= 1 copies
        return out


ozan_rep_function = OzanRepFunction.apply


class TrevorRepFunction(torch.autograd.Function):
    n = 5

    def __init__(self):
        super(TrevorRepFunction, self).__init__()

    @staticmethod
    def forward(ctx, input):
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        mul = 1.0 / math.sqrt(TrevorRepFunction.n)
        out = grad_input * mul
        return out


trevor_rep_function = TrevorRepFunction.apply

count = 0


class GradNormRepFunction(torch.autograd.Function):
    n = 5
    inital_task_losses = None
    current_task_losses = None
    current_weights = None

    def __init__(self):
        super(GradNormRepFunction, self).__init__()

    @staticmethod
    def forward(ctx, input):
        shape = input.shape
        ret = input.expand(GradNormRepFunction.n, *shape)
        return ret.clone()

    @staticmethod
    def backward(ctx, grad_output):
        global count
        num_grads = grad_output.shape[0]
        batch_size = grad_output.shape[1]
        grad_output = grad_output.float()
        if num_grads >= 2:
            GiW = [torch.sqrt(grad_output[i].reshape(-1).dot(grad_output[i].reshape(-1))) *
                   GradNormRepFunction.current_weights[i] for i in range(num_grads)]
            GW_bar = torch.mean(torch.stack(GiW))

            try:
                Li_ratio = [c/max(i, .0000001) for c, i in zip(GradNormRepFunction.current_task_losses,
                                                               GradNormRepFunction.inital_task_losses)]
                mean_ratio = statistics.mean(Li_ratio)
                ri = [lir/max(mean_ratio, .00000001) for lir in Li_ratio]
                target_grad = [float(GW_bar * (max(r_i, .00000001)**1.5)) for r_i in ri]

                target_weight = [float(target_grad[i]/float(GiW[i])) for i in range(num_grads)]
                total_weight = sum(target_weight)
                total_weight = max(.0000001, total_weight)
                target_weight = [i*num_grads/total_weight for i in target_weight]

                for i in range(len(GradNormRepFunction.current_weights)):
                    wi = GradNormRepFunction.current_weights[i]
                    GradNormRepFunction.current_weights[i] += (.0001*wi if wi < target_weight[i] else -.0001*wi)
            except Exception:
                # NOTE: everything from the weight-update condition above through the
                # top of resnet_taskonomy.py was lost to angle-bracket stripping in
                # the dump. This except clause and the weighted gradient blend below
                # are a minimal, assumed reconstruction: each task gradient is
                # rescaled by its current weight and summed, mirroring the
                # num_grads >= 2 path of OzanRepFunction above.
                pass
            grad_outputs = [grad_output[i] * GradNormRepFunction.current_weights[i] for i in range(num_grads)]
            output = grad_outputs[0]
            for i in range(1, num_grads):
                output += grad_outputs[i]
            return output
        elif num_grads == 1:
            grad_input = grad_output.clone()
            output = grad_input.sum(dim=0)
        return output


gradnorm_rep_function = GradNormRepFunction.apply

--------------------------------------------------------------------------------
/model_definitions/resnet_taskonomy.py:
--------------------------------------------------------------------------------

# The file header below is reconstructed from torchvision's resnet.py, which this
# file tracks closely; the original header lines were lost in the dump.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ozan_rep_fun import ozan_rep_function, trevor_rep_function, OzanRepFunction, TrevorRepFunction


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 80 | self.bn2 = norm_layer(width) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNetEncoder(nn.Module): 111 | 112 | def __init__(self, block, layers,widths=[64,128,256,512], num_classes=1000, zero_init_residual=False, 113 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 114 | norm_layer=None): 115 | super(ResNetEncoder, self).__init__() 116 | if norm_layer is None: 117 | norm_layer = nn.BatchNorm2d 118 | self._norm_layer = norm_layer 119 | 120 | self.inplanes = 64 121 | self.dilation = 1 122 | if replace_stride_with_dilation is None: 123 | # each element in the tuple indicates if we should replace 124 | # the 2x2 stride with a dilated convolution instead 125 | replace_stride_with_dilation = [False, False, False] 126 | if len(replace_stride_with_dilation) != 3: 127 | raise ValueError("replace_stride_with_dilation should be None " 128 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 129 | self.groups = groups 130 | self.base_width = width_per_group 131 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = norm_layer(self.inplanes) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, widths[0], layers[0]) 137 | self.layer2 = self._make_layer(block, widths[1], layers[1], stride=2, 138 | dilate=replace_stride_with_dilation[0]) 139 | self.layer3 = self._make_layer(block, widths[2], layers[2], stride=2, 140 | dilate=replace_stride_with_dilation[1]) 141 | self.layer4 = self._make_layer(block, widths[3], layers[3], stride=2, 142 | dilate=replace_stride_with_dilation[2]) 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 148 | nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | 151 | # Zero-initialize the last BN in each residual branch, 152 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x


class Decoder(nn.Module):
    def __init__(self, output_channels=32, num_classes=None, base_match=512):
        super(Decoder, self).__init__()

        self.output_channels = output_channels
        self.num_classes = num_classes

        self.relu = nn.ReLU(inplace=True)
        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            # The original used 512 * block.expansion here, but `block` is not in
            # scope in this class; base_match is the width of the incoming
            # representation, so use it instead.
            self.fc = nn.Linear(base_match, num_classes)
        else:
            # Five transpose-conv stages, each doubling spatial resolution.
            self.upconv0 = nn.ConvTranspose2d(base_match, 256, 2, 2)
            self.bn_upconv0 = nn.BatchNorm2d(256)
            self.conv_decode0 = nn.Conv2d(256, 256, 3, padding=1)
            self.bn_decode0 = nn.BatchNorm2d(256)
            self.upconv1 = nn.ConvTranspose2d(256, 128, 2, 2)
            self.bn_upconv1 = nn.BatchNorm2d(128)
            self.conv_decode1 = nn.Conv2d(128, 128, 3, padding=1)
            self.bn_decode1 = nn.BatchNorm2d(128)
            self.upconv2 = nn.ConvTranspose2d(128, 64, 2, 2)
            self.bn_upconv2 = nn.BatchNorm2d(64)
            self.conv_decode2 = nn.Conv2d(64, 64, 3, padding=1)
            self.bn_decode2 = nn.BatchNorm2d(64)
            self.upconv3 = nn.ConvTranspose2d(64, 48, 2, 2)
            self.bn_upconv3 = nn.BatchNorm2d(48)
            self.conv_decode3 = nn.Conv2d(48, 48, 3, padding=1)
            self.bn_decode3 = nn.BatchNorm2d(48)
            self.upconv4 = nn.ConvTranspose2d(48, 32, 2, 2)
            self.bn_upconv4 = nn.BatchNorm2d(32)
            self.conv_decode4 = nn.Conv2d(32, output_channels, 3, padding=1)

    def forward(self, representation):
        if self.num_classes is None:
            x = self.upconv0(representation)
            x = self.bn_upconv0(x)
            x = self.relu(x)
            x = self.conv_decode0(x)
            x = self.bn_decode0(x)
            x = self.relu(x)

            x = self.upconv1(x)
            x = self.bn_upconv1(x)
            x = self.relu(x)
            x = self.conv_decode1(x)
            x = self.bn_decode1(x)
            x = self.relu(x)
            x = self.upconv2(x)
            x = self.bn_upconv2(x)
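            # Shape sketch (assuming a 256x256 input and the default base_match=512):
            # the encoder output is (B, 512, 8, 8), and each of the five
            # transpose-conv stages (upconv0..upconv4) doubles the spatial size,
            # so conv_decode4 finally emits (B, output_channels, 256, 256).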
x = self.relu(x)
            x = self.conv_decode2(x)

            x = self.bn_decode2(x)
            x = self.relu(x)
            x = self.upconv3(x)
            x = self.bn_upconv3(x)
            x = self.relu(x)
            x = self.conv_decode3(x)
            x = self.bn_decode3(x)
            x = self.relu(x)
            x = self.upconv4(x)
            x = self.bn_upconv4(x)
            x = self.relu(x)
            x = self.conv_decode4(x)

        else:
            # The original pooled an undefined `x` here; pool the incoming
            # representation instead.
            x = F.adaptive_avg_pool2d(representation, (1, 1))
            x = x.view(x.size(0), -1)
            x = self.fc(x)
        return x


class ResNet(nn.Module):
    def __init__(self, block, layers, tasks=None, num_classes=None, ozan=False, size=1, **kwargs):
        super(ResNet, self).__init__()
        if size == 1:
            self.encoder = ResNetEncoder(block, layers, **kwargs)
        elif size == 2:
            self.encoder = ResNetEncoder(block, layers, [96, 192, 384, 720], **kwargs)
        elif size == 3:
            self.encoder = ResNetEncoder(block, layers, [112, 224, 448, 880], **kwargs)
        elif size == 0.5:
            self.encoder = ResNetEncoder(block, layers, [48, 96, 192, 360], **kwargs)
        self.tasks = tasks
        self.ozan = ozan
        self.task_to_decoder = {}

        if tasks is not None:
            for task in tasks:
                if task == 'segment_semantic':
                    output_channels = 18
                if task == 'depth_zbuffer':
                    output_channels = 1
                if task == 'normal':
                    output_channels = 3
                if task == 'edge_occlusion':
                    output_channels = 1
                if task == 'reshading':
                    output_channels = 3
                if task == 'keypoints2d':
                    output_channels = 1
                if task == 'edge_texture':
                    output_channels = 1
                if size == 1:
                    decoder = Decoder(output_channels)
                elif size == 2:
                    decoder = Decoder(output_channels, base_match=720)
                elif size == 3:
                    decoder = Decoder(output_channels, base_match=880)
                elif size == 0.5:
                    decoder = Decoder(output_channels, base_match=360)
                self.task_to_decoder[task] = decoder
        else:
            self.task_to_decoder['classification'] = Decoder(output_channels=0, num_classes=1000)

        self.decoders = nn.ModuleList(self.task_to_decoder.values())

        # ------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 341 | elif isinstance(m, nn.BatchNorm2d): 342 | m.weight.data.fill_(1) 343 | m.bias.data.zero_() 344 | #----------------------------- 345 | def forward(self, input): 346 | rep = self.encoder(input) 347 | 348 | 349 | if self.tasks is None: 350 | return self.decoders[0](rep) 351 | 352 | #rep = self.final_conv(rep) 353 | #rep = self.final_conv_bn(rep) 354 | 355 | outputs={'rep':rep} 356 | if self.ozan: 357 | OzanRepFunction.n=len(self.decoders) 358 | rep = ozan_rep_function(rep) 359 | for i,(task,decoder) in enumerate(zip(self.task_to_decoder.keys(),self.decoders)): 360 | outputs[task]=decoder(rep[i]) 361 | else: 362 | TrevorRepFunction.n=len(self.decoders) 363 | rep = trevor_rep_function(rep) 364 | for i,(task,decoder) in enumerate(zip(self.task_to_decoder.keys(),self.decoders)): 365 | outputs[task]=decoder(rep) 366 | 367 | return outputs 368 | 369 | def _resnet(arch, block, layers, pretrained, **kwargs): 370 | model = ResNet(block=block, layers=layers, **kwargs) 371 | # if pretrained: 372 | # state_dict = load_state_dict_from_url(model_urls[arch], 373 | # progress=progress) 374 | # model.load_state_dict(state_dict) 375 | return model 376 | 377 | 378 | def resnet18_taskonomy(pretrained=False, **kwargs): 379 | """Constructs a ResNet-18 model. 380 | 381 | Args: 382 | pretrained (bool): If True, returns a model pre-trained on ImageNet 383 | progress (bool): If True, displays a progress bar of the download to stderr 384 | """ 385 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, 386 | **kwargs) 387 | 388 | def resnet18_taskonomy_tripple(pretrained=False, **kwargs): 389 | """Constructs a ResNet-18 model. 390 | 391 | Args: 392 | pretrained (bool): If True, returns a model pre-trained on ImageNet 393 | progress (bool): If True, displays a progress bar of the download to stderr 394 | """ 395 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained,size=3, 396 | **kwargs) 397 | 398 | def resnet18_taskonomy_half(pretrained=False, **kwargs): 399 | """Constructs a ResNet-18 model. 400 | 401 | Args: 402 | pretrained (bool): If True, returns a model pre-trained on ImageNet 403 | progress (bool): If True, displays a progress bar of the download to stderr 404 | """ 405 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained,size=0.5, 406 | **kwargs) 407 | 408 | 409 | def resnet34_taskonomy(pretrained=False, **kwargs): 410 | """Constructs a ResNet-34 model. 411 | 412 | Args: 413 | pretrained (bool): If True, returns a model pre-trained on ImageNet 414 | progress (bool): If True, displays a progress bar of the download to stderr 415 | """ 416 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, 417 | **kwargs) 418 | 419 | 420 | def resnet50_taskonomy(pretrained=False, **kwargs): 421 | """Constructs a ResNet-50 model. 422 | 423 | Args: 424 | pretrained (bool): If True, returns a model pre-trained on ImageNet 425 | progress (bool): If True, displays a progress bar of the download to stderr 426 | """ 427 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, 428 | **kwargs) 429 | 430 | 431 | def resnet101_taskonomy(pretrained=False, **kwargs): 432 | """Constructs a ResNet-101 model. 
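    Example (hypothetical call; task names follow the ones handled in ResNet.__init__):
        model = resnet101_taskonomy(tasks=['depth_zbuffer', 'normal'])
        out = model(torch.randn(1, 3, 256, 256))  # dict: 'rep' plus one entry per task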
433 | 434 | Args: 435 | pretrained (bool): If True, returns a model pre-trained on ImageNet 436 | progress (bool): If True, displays a progress bar of the download to stderr 437 | """ 438 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, 439 | **kwargs) 440 | 441 | 442 | def resnet152_taskonomy(pretrained=False, **kwargs): 443 | """Constructs a ResNet-152 model. 444 | 445 | Args: 446 | pretrained (bool): If True, returns a model pre-trained on ImageNet 447 | progress (bool): If True, displays a progress bar of the download to stderr 448 | """ 449 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, 450 | **kwargs) 451 | -------------------------------------------------------------------------------- /model_definitions/xception_taskonomy_joined_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates an Xception Model as defined in: 3 | 4 | Francois Chollet 5 | Xception: Deep Learning with Depthwise Separable Convolutions 6 | https://arxiv.org/pdf/1610.02357.pdf 7 | 8 | This weights ported from the Keras implementation. Achieves the following performance on the validation set: 9 | 10 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 11 | 12 | REMEMBER to set your image size to 3x299x299 for both test and validation 13 | 14 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 15 | std=[0.5, 0.5, 0.5]) 16 | 17 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 18 | """ 19 | import math 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.utils.model_zoo as model_zoo 23 | from torch.nn import init 24 | import torch 25 | from .ozan_rep_fun import ozan_rep_function,trevor_rep_function,OzanRepFunction,TrevorRepFunction 26 | 27 | __all__ = ['xception_taskonomy_joined_decoder','xception_taskonomy_joined_decoder_fifth','xception_taskonomy_joined_decoder_quad','xception_taskonomy_joined_decoder_half','xception_taskonomy_joined_decoder_80','xception_taskonomy_joined_decoder_ozan'] 28 | 29 | # model_urls = { 30 | # 'xception_taskonomy':'file:///home/tstand/Dropbox/taskonomy/xception_taskonomy-a4b32ef7.pth.tar' 31 | # } 32 | 33 | 34 | class SeparableConv2d(nn.Module): 35 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,groupsize=1): 36 | super(SeparableConv2d,self).__init__() 37 | 38 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=max(1,in_channels//groupsize),bias=bias) 39 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 40 | #self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,dilation,bias=bias) 41 | #self.pointwise=lambda x:x 42 | 43 | def forward(self,x): 44 | x = self.conv1(x) 45 | x = self.pointwise(x) 46 | return x 47 | 48 | 49 | class Block(nn.Module): 50 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): 51 | super(Block, self).__init__() 52 | 53 | if out_filters != in_filters or strides!=1: 54 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 55 | self.skipbn = nn.BatchNorm2d(out_filters) 56 | else: 57 | self.skip=None 58 | 59 | self.relu = nn.ReLU(inplace=True) 60 | rep=[] 61 | 62 | filters=in_filters 63 | if grow_first: 64 | rep.append(self.relu) 65 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 66 | rep.append(nn.BatchNorm2d(out_filters)) 67 | filters = out_filters 68 | 69 
| for i in range(reps-1): 70 | rep.append(self.relu) 71 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) 72 | rep.append(nn.BatchNorm2d(filters)) 73 | 74 | if not grow_first: 75 | rep.append(self.relu) 76 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 77 | rep.append(nn.BatchNorm2d(out_filters)) 78 | filters=out_filters 79 | 80 | if not start_with_relu: 81 | rep = rep[1:] 82 | else: 83 | rep[0] = nn.ReLU(inplace=False) 84 | 85 | if strides != 1: 86 | #rep.append(nn.AvgPool2d(3,strides,1)) 87 | rep.append(nn.Conv2d(filters,filters,2,2)) 88 | self.rep = nn.Sequential(*rep) 89 | 90 | def forward(self,inp): 91 | x = self.rep(inp) 92 | 93 | if self.skip is not None: 94 | skip = self.skip(inp) 95 | skip = self.skipbn(skip) 96 | else: 97 | skip = inp 98 | x+=skip 99 | return x 100 | 101 | class Encoder(nn.Module): 102 | def __init__(self, sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728]): 103 | super(Encoder, self).__init__() 104 | self.conv1 = nn.Conv2d(3, sizes[0], 3,2, 1, bias=False) 105 | self.bn1 = nn.BatchNorm2d(sizes[0]) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.relu2 = nn.ReLU(inplace=False) 108 | 109 | self.conv2 = nn.Conv2d(sizes[0],sizes[1],3,1,1,bias=False) 110 | self.bn2 = nn.BatchNorm2d(sizes[1]) 111 | #do relu here 112 | 113 | self.block1=Block(sizes[1],sizes[2],2,2,start_with_relu=False,grow_first=True) 114 | self.block2=Block(sizes[2],sizes[3],2,2,start_with_relu=True,grow_first=True) 115 | self.block3=Block(sizes[3],sizes[4],2,2,start_with_relu=True,grow_first=True) 116 | 117 | self.block4=Block(sizes[4],sizes[5],3,1,start_with_relu=True,grow_first=True) 118 | self.block5=Block(sizes[5],sizes[6],3,1,start_with_relu=True,grow_first=True) 119 | self.block6=Block(sizes[6],sizes[7],3,1,start_with_relu=True,grow_first=True) 120 | self.block7=Block(sizes[7],sizes[8],3,1,start_with_relu=True,grow_first=True) 121 | 122 | self.block8=Block(sizes[8],sizes[9],3,1,start_with_relu=True,grow_first=True) 123 | self.block9=Block(sizes[9],sizes[10],3,1,start_with_relu=True,grow_first=True) 124 | self.block10=Block(sizes[10],sizes[11],3,1,start_with_relu=True,grow_first=True) 125 | self.block11=Block(sizes[11],sizes[12],3,1,start_with_relu=True,grow_first=True) 126 | 127 | #self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 128 | 129 | #self.conv3 = SeparableConv2d(768,512,3,1,1) 130 | #self.bn3 = nn.BatchNorm2d(512) 131 | #self.conv3 = SeparableConv2d(1024,1536,3,1,1) 132 | #self.bn3 = nn.BatchNorm2d(1536) 133 | 134 | #do relu here 135 | #self.conv4 = SeparableConv2d(1536,2048,3,1,1) 136 | #self.bn4 = nn.BatchNorm2d(2048) 137 | def forward(self,input): 138 | 139 | x = self.conv1(input) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | 143 | x = self.conv2(x) 144 | x = self.bn2(x) 145 | x = self.relu(x) 146 | 147 | x = self.block1(x) 148 | x = self.block2(x) 149 | x = self.block3(x) 150 | x = self.block4(x) 151 | x = self.block5(x) 152 | x = self.block6(x) 153 | x = self.block7(x) 154 | x = self.block8(x) 155 | x = self.block9(x) 156 | x = self.block10(x) 157 | x = self.block11(x) 158 | #x = self.block12(x) 159 | 160 | #x = self.conv3(x) 161 | #x = self.bn3(x) 162 | #x = self.relu(x) 163 | 164 | 165 | #x = self.conv4(x) 166 | #x = self.bn4(x) 167 | 168 | representation = self.relu2(x) 169 | 170 | return representation 171 | 172 | 173 | 174 | def interpolate(inp,size): 175 | t = inp.type() 176 | inp = inp.float() 177 | out = 
nn.functional.interpolate(inp,size=size,mode='bilinear',align_corners=False) 178 | if out.type()!=t: 179 | out = out.half() 180 | return out 181 | 182 | 183 | 184 | class Decoder(nn.Module): 185 | def __init__(self, output_channels=32,num_classes=None): 186 | super(Decoder, self).__init__() 187 | 188 | self.output_channels = output_channels 189 | self.num_classes = num_classes 190 | 191 | self.relu = nn.ReLU(inplace=True) 192 | if num_classes is not None: 193 | self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 194 | 195 | self.conv3 = SeparableConv2d(1024,1536,3,1,1) 196 | self.bn3 = nn.BatchNorm2d(1536) 197 | 198 | #do relu here 199 | self.conv4 = SeparableConv2d(1536,2048,3,1,1) 200 | self.bn4 = nn.BatchNorm2d(2048) 201 | 202 | self.fc = nn.Linear(2048, num_classes) 203 | else: 204 | self.upconv1 = nn.ConvTranspose2d(512,128,2,2) 205 | self.bn_upconv1 = nn.BatchNorm2d(128) 206 | self.conv_decode1 = nn.Conv2d(128, 128, 3,padding=1) 207 | self.bn_decode1 = nn.BatchNorm2d(128) 208 | self.upconv2 = nn.ConvTranspose2d(128,64,2,2) 209 | self.bn_upconv2 = nn.BatchNorm2d(64) 210 | self.conv_decode2 = nn.Conv2d(64, 64, 3,padding=1) 211 | self.bn_decode2 = nn.BatchNorm2d(64) 212 | self.upconv3 = nn.ConvTranspose2d(64,48,2,2) 213 | self.bn_upconv3 = nn.BatchNorm2d(48) 214 | self.conv_decode3 = nn.Conv2d(48, 48, 3,padding=1) 215 | self.bn_decode3 = nn.BatchNorm2d(48) 216 | self.upconv4 = nn.ConvTranspose2d(48,32,2,2) 217 | self.bn_upconv4 = nn.BatchNorm2d(32) 218 | self.conv_decode4 = nn.Conv2d(32, output_channels, 3,padding=1) 219 | 220 | 221 | 222 | def forward(self,representation): 223 | if self.num_classes is None: 224 | x = self.upconv1(representation) 225 | x = self.bn_upconv1(x) 226 | x = self.relu(x) 227 | x = self.conv_decode1(x) 228 | x = self.bn_decode1(x) 229 | x = self.relu(x) 230 | x = self.upconv2(x) 231 | x = self.bn_upconv2(x) 232 | x = self.relu(x) 233 | x = self.conv_decode2(x) 234 | 235 | x = self.bn_decode2(x) 236 | x = self.relu(x) 237 | x = self.upconv3(x) 238 | x = self.bn_upconv3(x) 239 | x = self.relu(x) 240 | x = self.conv_decode3(x) 241 | x = self.bn_decode3(x) 242 | x = self.relu(x) 243 | x = self.upconv4(x) 244 | x = self.bn_upconv4(x) 245 | x = self.relu(x) 246 | x = self.conv_decode4(x) 247 | 248 | else: 249 | x = self.block12(representation) 250 | 251 | x = self.conv3(x) 252 | x = self.bn3(x) 253 | x = self.relu(x) 254 | 255 | x = self.conv4(x) 256 | x = self.bn4(x) 257 | x = self.relu(x) 258 | 259 | x = F.adaptive_avg_pool2d(x, (1, 1)) 260 | x = x.view(x.size(0), -1) 261 | x = self.fc(x) 262 | return x 263 | 264 | 265 | 266 | 267 | class XceptionTaskonomy(nn.Module): 268 | """ 269 | Xception optimized for the ImageNet dataset, as specified in 270 | https://arxiv.org/pdf/1610.02357.pdf 271 | """ 272 | def __init__(self,size=1, tasks=None,num_classes=None, ozan=False): 273 | """ Constructor 274 | Args: 275 | num_classes: number of classes 276 | """ 277 | super(XceptionTaskonomy, self).__init__() 278 | pre_rep_size=728 279 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 280 | if size == 1: 281 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 282 | elif size==.2: 283 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 284 | elif size==.3: 285 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 286 | elif size==.4: 287 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 288 | elif size==.5: 289 | sizes=[24,48,96,192,512,512,512,512,512,512,512,512,512] 290 | elif size==.8: 291 | 
sizes=[32,64,128,248,648,648,648,648,648,648,648,648,648] 292 | elif size==2: 293 | sizes=[32,64, 128,256, 728, 728, 728, 728, 728, 728, 728, 728, 728] 294 | elif size==4: 295 | sizes=[64,128,256,512,1456,1456,1456,1456,1456,1456,1456,1456,1456] 296 | 297 | 298 | self.encoder=Encoder(sizes=sizes) 299 | pre_rep_size=sizes[-1] 300 | 301 | self.tasks=tasks 302 | self.ozan=ozan 303 | self.task_to_decoder = {} 304 | 305 | 306 | 307 | if tasks is not None: 308 | 309 | self.final_conv = SeparableConv2d(pre_rep_size,512,3,1,1) 310 | self.final_conv_bn = nn.BatchNorm2d(512) 311 | output_channels=0 312 | self.channels_per_task = {'segment_semantic':18, 313 | 'depth_zbuffer':1, 314 | 'normal':3, 315 | 'edge_occlusion':1, 316 | 'reshading':3, 317 | 'keypoints2d':1, 318 | 'edge_texture':1, 319 | } 320 | for task in tasks: 321 | output_channels+=self.channels_per_task[task] 322 | self.decoder=Decoder(output_channels) 323 | 324 | else: 325 | self.decoder=Decoder(output_channels=0,num_classes=1000) 326 | 327 | 328 | #------- init weights -------- 329 | for m in self.modules(): 330 | if isinstance(m, nn.Conv2d): 331 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 332 | m.weight.data.normal_(0, math.sqrt(2. / n)) 333 | elif isinstance(m, nn.BatchNorm2d): 334 | m.weight.data.fill_(1) 335 | m.bias.data.zero_() 336 | #----------------------------- 337 | 338 | 339 | def forward(self, input): 340 | rep = self.encoder(input) 341 | 342 | 343 | if self.tasks is None: 344 | return self.decoder(rep) 345 | 346 | rep = self.final_conv(rep) 347 | rep = self.final_conv_bn(rep) 348 | 349 | outputs = {} 350 | raw_output=self.decoder(rep) 351 | 352 | range_start = 0 353 | #print(raw_output.shape) 354 | for task in self.tasks: 355 | outputs[task]=raw_output[:,range_start:range_start+self.channels_per_task[task],:,:] 356 | range_start+=self.channels_per_task[task] 357 | 358 | return outputs 359 | 360 | 361 | 362 | def xception_taskonomy_joined_decoder(**kwargs): 363 | """ 364 | Construct Xception. 365 | """ 366 | 367 | model = XceptionTaskonomy(**kwargs,size=1) 368 | 369 | return model 370 | 371 | def xception_taskonomy_joined_decoder_fifth(**kwargs): 372 | """ 373 | Construct Xception. 374 | """ 375 | 376 | model = XceptionTaskonomy(**kwargs,size=.2) 377 | 378 | return model 379 | 380 | def xception_taskonomy_joined_decoder_quad(**kwargs): 381 | """ 382 | Construct Xception. 383 | """ 384 | 385 | model = XceptionTaskonomy(**kwargs,size=4) 386 | 387 | return model 388 | 389 | def xception_taskonomy_joined_decoder_half(**kwargs): 390 | """ 391 | Construct Xception. 392 | """ 393 | 394 | model = XceptionTaskonomy(**kwargs,size=.5) 395 | 396 | return model 397 | 398 | def xception_taskonomy_joined_decoder_80(**kwargs): 399 | """ 400 | Construct Xception. 401 | """ 402 | 403 | model = XceptionTaskonomy(**kwargs,size=.8) 404 | 405 | return model 406 | 407 | def xception_taskonomy_joined_decoder_ozan(**kwargs): 408 | """ 409 | Construct Xception. 
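    A minimal usage sketch (shapes assumed). All tasks share one decoder; its
    output is sliced per task along the channel dimension:

        model = xception_taskonomy_joined_decoder_ozan(tasks=['segment_semantic', 'depth_zbuffer'])
        out = model(torch.randn(1, 3, 256, 256))  # out['segment_semantic'] has 18 channels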
410 | """ 411 | 412 | model = XceptionTaskonomy(ozan=True,**kwargs) 413 | 414 | return model 415 | -------------------------------------------------------------------------------- /model_definitions/xception_taskonomy_new.py: -------------------------------------------------------------------------------- 1 | """ 2 | """ 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.model_zoo as model_zoo 7 | from torch.nn import init 8 | import torch 9 | from .ozan_rep_fun import ozan_rep_function,trevor_rep_function,OzanRepFunction,TrevorRepFunction 10 | 11 | __all__ = ['xception_taskonomy_new','xception_taskonomy_new_fifth','xception_taskonomy_new_quad','xception_taskonomy_new_half','xception_taskonomy_new_80','xception_taskonomy_ozan'] 12 | 13 | # model_urls = { 14 | # 'xception_taskonomy':'file:///home/tstand/Dropbox/taskonomy/xception_taskonomy-a4b32ef7.pth.tar' 15 | # } 16 | 17 | 18 | class SeparableConv2d(nn.Module): 19 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,groupsize=1): 20 | super(SeparableConv2d,self).__init__() 21 | 22 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=max(1,in_channels//groupsize),bias=bias) 23 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 24 | #self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,dilation,bias=bias) 25 | #self.pointwise=lambda x:x 26 | 27 | def forward(self,x): 28 | x = self.conv1(x) 29 | x = self.pointwise(x) 30 | return x 31 | 32 | 33 | class Block(nn.Module): 34 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): 35 | super(Block, self).__init__() 36 | 37 | if out_filters != in_filters or strides!=1: 38 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 39 | self.skipbn = nn.BatchNorm2d(out_filters) 40 | else: 41 | self.skip=None 42 | 43 | self.relu = nn.ReLU(inplace=True) 44 | rep=[] 45 | 46 | filters=in_filters 47 | if grow_first: 48 | rep.append(self.relu) 49 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 50 | rep.append(nn.BatchNorm2d(out_filters)) 51 | filters = out_filters 52 | 53 | for i in range(reps-1): 54 | rep.append(self.relu) 55 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) 56 | rep.append(nn.BatchNorm2d(filters)) 57 | 58 | if not grow_first: 59 | rep.append(self.relu) 60 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 61 | rep.append(nn.BatchNorm2d(out_filters)) 62 | filters=out_filters 63 | 64 | if not start_with_relu: 65 | rep = rep[1:] 66 | else: 67 | rep[0] = nn.ReLU(inplace=False) 68 | 69 | if strides != 1: 70 | #rep.append(nn.AvgPool2d(3,strides,1)) 71 | rep.append(nn.Conv2d(filters,filters,2,2)) 72 | self.rep = nn.Sequential(*rep) 73 | 74 | def forward(self,inp): 75 | x = self.rep(inp) 76 | 77 | if self.skip is not None: 78 | skip = self.skip(inp) 79 | skip = self.skipbn(skip) 80 | else: 81 | skip = inp 82 | x+=skip 83 | return x 84 | 85 | class Encoder(nn.Module): 86 | def __init__(self, sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728]): 87 | super(Encoder, self).__init__() 88 | self.conv1 = nn.Conv2d(3, sizes[0], 3,2, 1, bias=False) 89 | self.bn1 = nn.BatchNorm2d(sizes[0]) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.relu2 = nn.ReLU(inplace=False) 92 | 93 | self.conv2 = nn.Conv2d(sizes[0],sizes[1],3,1,1,bias=False) 94 | self.bn2 = 
nn.BatchNorm2d(sizes[1]) 95 | #do relu here 96 | 97 | self.block1=Block(sizes[1],sizes[2],2,2,start_with_relu=False,grow_first=True) 98 | self.block2=Block(sizes[2],sizes[3],2,2,start_with_relu=True,grow_first=True) 99 | self.block3=Block(sizes[3],sizes[4],2,2,start_with_relu=True,grow_first=True) 100 | 101 | self.block4=Block(sizes[4],sizes[5],3,1,start_with_relu=True,grow_first=True) 102 | self.block5=Block(sizes[5],sizes[6],3,1,start_with_relu=True,grow_first=True) 103 | self.block6=Block(sizes[6],sizes[7],3,1,start_with_relu=True,grow_first=True) 104 | self.block7=Block(sizes[7],sizes[8],3,1,start_with_relu=True,grow_first=True) 105 | 106 | self.block8=Block(sizes[8],sizes[9],3,1,start_with_relu=True,grow_first=True) 107 | self.block9=Block(sizes[9],sizes[10],3,1,start_with_relu=True,grow_first=True) 108 | self.block10=Block(sizes[10],sizes[11],3,1,start_with_relu=True,grow_first=True) 109 | self.block11=Block(sizes[11],sizes[12],3,1,start_with_relu=True,grow_first=True) 110 | 111 | #self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 112 | 113 | #self.conv3 = SeparableConv2d(768,512,3,1,1) 114 | #self.bn3 = nn.BatchNorm2d(512) 115 | #self.conv3 = SeparableConv2d(1024,1536,3,1,1) 116 | #self.bn3 = nn.BatchNorm2d(1536) 117 | 118 | #do relu here 119 | #self.conv4 = SeparableConv2d(1536,2048,3,1,1) 120 | #self.bn4 = nn.BatchNorm2d(2048) 121 | def forward(self,input): 122 | 123 | x = self.conv1(input) 124 | x = self.bn1(x) 125 | x = self.relu(x) 126 | 127 | x = self.conv2(x) 128 | x = self.bn2(x) 129 | x = self.relu(x) 130 | 131 | x = self.block1(x) 132 | x = self.block2(x) 133 | x = self.block3(x) 134 | x = self.block4(x) 135 | x = self.block5(x) 136 | x = self.block6(x) 137 | x = self.block7(x) 138 | x = self.block8(x) 139 | x = self.block9(x) 140 | x = self.block10(x) 141 | x = self.block11(x) 142 | #x = self.block12(x) 143 | 144 | #x = self.conv3(x) 145 | #x = self.bn3(x) 146 | #x = self.relu(x) 147 | 148 | 149 | #x = self.conv4(x) 150 | #x = self.bn4(x) 151 | 152 | representation = self.relu2(x) 153 | 154 | return representation 155 | 156 | 157 | 158 | def interpolate(inp,size): 159 | t = inp.type() 160 | inp = inp.float() 161 | out = nn.functional.interpolate(inp,size=size,mode='bilinear',align_corners=False) 162 | if out.type()!=t: 163 | out = out.half() 164 | return out 165 | 166 | 167 | 168 | class Decoder(nn.Module): 169 | def __init__(self, output_channels=32,num_classes=None,half_sized_output=False,small_decoder=True): 170 | super(Decoder, self).__init__() 171 | 172 | self.output_channels = output_channels 173 | self.num_classes = num_classes 174 | self.half_sized_output=half_sized_output 175 | self.relu = nn.ReLU(inplace=True) 176 | if num_classes is not None: 177 | self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 178 | 179 | self.conv3 = SeparableConv2d(1024,1536,3,1,1) 180 | self.bn3 = nn.BatchNorm2d(1536) 181 | 182 | #do relu here 183 | self.conv4 = SeparableConv2d(1536,2048,3,1,1) 184 | self.bn4 = nn.BatchNorm2d(2048) 185 | 186 | self.fc = nn.Linear(2048, num_classes) 187 | else: 188 | if small_decoder: 189 | self.upconv1 = nn.ConvTranspose2d(512,128,2,2) 190 | self.bn_upconv1 = nn.BatchNorm2d(128) 191 | self.conv_decode1 = nn.Conv2d(128, 128, 3,padding=1) 192 | self.bn_decode1 = nn.BatchNorm2d(128) 193 | self.upconv2 = nn.ConvTranspose2d(128,64,2,2) 194 | self.bn_upconv2 = nn.BatchNorm2d(64) 195 | self.conv_decode2 = nn.Conv2d(64, 64, 3,padding=1) 196 | self.bn_decode2 = nn.BatchNorm2d(64) 197 | self.upconv3 = 
nn.ConvTranspose2d(64,48,2,2) 198 | self.bn_upconv3 = nn.BatchNorm2d(48) 199 | self.conv_decode3 = nn.Conv2d(48, 48, 3,padding=1) 200 | self.bn_decode3 = nn.BatchNorm2d(48) 201 | if half_sized_output: 202 | self.upconv4 = nn.Identity() 203 | self.bn_upconv4 = nn.Identity() 204 | self.conv_decode4 = nn.Conv2d(48, output_channels, 3,padding=1) 205 | else: 206 | self.upconv4 = nn.ConvTranspose2d(48,32,2,2) 207 | self.bn_upconv4 = nn.BatchNorm2d(32) 208 | self.conv_decode4 = nn.Conv2d(32, output_channels, 3,padding=1) 209 | else: 210 | self.upconv1 = nn.ConvTranspose2d(512,256,2,2) 211 | self.bn_upconv1 = nn.BatchNorm2d(256) 212 | self.conv_decode1 = nn.Conv2d(256, 256, 3,padding=1) 213 | self.bn_decode1 = nn.BatchNorm2d(256) 214 | self.upconv2 = nn.ConvTranspose2d(256,128,2,2) 215 | self.bn_upconv2 = nn.BatchNorm2d(128) 216 | self.conv_decode2 = nn.Conv2d(128, 128, 3,padding=1) 217 | self.bn_decode2 = nn.BatchNorm2d(128) 218 | self.upconv3 = nn.ConvTranspose2d(128,96,2,2) 219 | self.bn_upconv3 = nn.BatchNorm2d(96) 220 | self.conv_decode3 = nn.Conv2d(96, 96, 3,padding=1) 221 | self.bn_decode3 = nn.BatchNorm2d(96) 222 | if half_sized_output: 223 | self.upconv4 = nn.Identity() 224 | self.bn_upconv4 = nn.Identity() 225 | self.conv_decode4 = nn.Conv2d(96, output_channels, 3,padding=1) 226 | else: 227 | self.upconv4 = nn.ConvTranspose2d(96,64,2,2) 228 | self.bn_upconv4 = nn.BatchNorm2d(64) 229 | self.conv_decode4 = nn.Conv2d(64, output_channels, 3,padding=1) 230 | 231 | 232 | 233 | 234 | def forward(self,representation): 235 | if self.num_classes is None: 236 | x = self.upconv1(representation) 237 | x = self.bn_upconv1(x) 238 | x = self.relu(x) 239 | x = self.conv_decode1(x) 240 | x = self.bn_decode1(x) 241 | x = self.relu(x) 242 | x = self.upconv2(x) 243 | x = self.bn_upconv2(x) 244 | x = self.relu(x) 245 | x = self.conv_decode2(x) 246 | 247 | x = self.bn_decode2(x) 248 | x = self.relu(x) 249 | x = self.upconv3(x) 250 | x = self.bn_upconv3(x) 251 | x = self.relu(x) 252 | x = self.conv_decode3(x) 253 | x = self.bn_decode3(x) 254 | x = self.relu(x) 255 | if not self.half_sized_output: 256 | x = self.upconv4(x) 257 | x = self.bn_upconv4(x) 258 | x = self.relu(x) 259 | x = self.conv_decode4(x) 260 | 261 | else: 262 | x = self.block12(representation) 263 | 264 | x = self.conv3(x) 265 | x = self.bn3(x) 266 | x = self.relu(x) 267 | 268 | x = self.conv4(x) 269 | x = self.bn4(x) 270 | x = self.relu(x) 271 | 272 | x = F.adaptive_avg_pool2d(x, (1, 1)) 273 | x = x.view(x.size(0), -1) 274 | x = self.fc(x) 275 | return x 276 | 277 | 278 | 279 | 280 | class XceptionTaskonomy(nn.Module): 281 | """ 282 | Xception optimized for the ImageNet dataset, as specified in 283 | https://arxiv.org/pdf/1610.02357.pdf 284 | """ 285 | def __init__(self,size=1, tasks=None,num_classes=None, ozan=False,half_sized_output=False): 286 | """ Constructor 287 | Args: 288 | num_classes: number of classes 289 | """ 290 | super(XceptionTaskonomy, self).__init__() 291 | pre_rep_size=728 292 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 293 | if size == 1: 294 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 295 | elif size==.2: 296 | sizes=[16,32,64,256,320,320,320,320,320,320,320,320,320] 297 | elif size==.3: 298 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 299 | elif size==.4: 300 | sizes=[32,64,128,256,728,728,728,728,728,728,728,728,728] 301 | elif size==.5: 302 | sizes=[24,48,96,192,512,512,512,512,512,512,512,512,512] 303 | elif size==.8: 304 | 
sizes=[32,64,128,248,648,648,648,648,648,648,648,648,648] 305 | elif size==2: 306 | sizes=[32,64, 128,256, 728, 728, 728, 728, 728, 728, 728, 728, 728] 307 | elif size==4: 308 | sizes=[64,128,256,512,1456,1456,1456,1456,1456,1456,1456,1456,1456] 309 | 310 | 311 | self.encoder=Encoder(sizes=sizes) 312 | pre_rep_size=sizes[-1] 313 | 314 | self.tasks=tasks 315 | self.ozan=ozan 316 | self.task_to_decoder = {} 317 | 318 | 319 | 320 | if tasks is not None: 321 | 322 | self.final_conv = SeparableConv2d(pre_rep_size,512,3,1,1) 323 | self.final_conv_bn = nn.BatchNorm2d(512) 324 | for task in tasks: 325 | if task == 'segment_semantic': 326 | output_channels = 18 327 | if task == 'depth_zbuffer': 328 | output_channels = 1 329 | if task == 'normal': 330 | output_channels = 3 331 | if task == 'edge_occlusion': 332 | output_channels = 1 333 | if task == 'keypoints2d': 334 | output_channels = 1 335 | if task == 'edge_texture': 336 | output_channels = 1 337 | if task == 'reshading': 338 | output_channels = 1 339 | if task == 'rgb': 340 | output_channels = 3 341 | if task == 'principal_curvature': 342 | output_channels = 2 343 | decoder=Decoder(output_channels,half_sized_output=half_sized_output) 344 | self.task_to_decoder[task]=decoder 345 | else: 346 | self.task_to_decoder['classification']=Decoder(output_channels=0,num_classes=1000) 347 | 348 | self.decoders = nn.ModuleList(self.task_to_decoder.values()) 349 | 350 | #------- init weights -------- 351 | for m in self.modules(): 352 | if isinstance(m, nn.Conv2d): 353 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 354 | m.weight.data.normal_(0, math.sqrt(2. / n)) 355 | elif isinstance(m, nn.BatchNorm2d): 356 | m.weight.data.fill_(1) 357 | m.bias.data.zero_() 358 | #----------------------------- 359 | 360 | 361 | def forward(self, input): 362 | rep = self.encoder(input) 363 | 364 | 365 | if self.tasks is None: 366 | return self.decoders[0](rep) 367 | 368 | rep = self.final_conv(rep) 369 | rep = self.final_conv_bn(rep) 370 | 371 | outputs={'rep':rep} 372 | if self.ozan: 373 | OzanRepFunction.n=len(self.decoders) 374 | rep = ozan_rep_function(rep) 375 | for i,(task,decoder) in enumerate(zip(self.task_to_decoder.keys(),self.decoders)): 376 | outputs[task]=decoder(rep[i]) 377 | else: 378 | TrevorRepFunction.n=len(self.decoders) 379 | rep = trevor_rep_function(rep) 380 | for i,(task,decoder) in enumerate(zip(self.task_to_decoder.keys(),self.decoders)): 381 | outputs[task]=decoder(rep) 382 | 383 | return outputs 384 | 385 | 386 | 387 | def xception_taskonomy_new(**kwargs): 388 | """ 389 | Construct Xception. 390 | """ 391 | 392 | model = XceptionTaskonomy(**kwargs,size=1) 393 | 394 | return model 395 | 396 | def xception_taskonomy_new_fifth(**kwargs): 397 | """ 398 | Construct Xception. 399 | """ 400 | 401 | model = XceptionTaskonomy(**kwargs,size=.2) 402 | 403 | return model 404 | 405 | def xception_taskonomy_new_quad(**kwargs): 406 | """ 407 | Construct Xception. 408 | """ 409 | 410 | model = XceptionTaskonomy(**kwargs,size=4) 411 | 412 | return model 413 | 414 | def xception_taskonomy_new_half(**kwargs): 415 | """ 416 | Construct Xception. 417 | """ 418 | 419 | model = XceptionTaskonomy(**kwargs,size=.5) 420 | 421 | return model 422 | 423 | def xception_taskonomy_new_80(**kwargs): 424 | """ 425 | Construct Xception. 426 | """ 427 | 428 | model = XceptionTaskonomy(**kwargs,size=.8) 429 | 430 | return model 431 | 432 | def xception_taskonomy_ozan(**kwargs): 433 | """ 434 | Construct Xception. 
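    A minimal usage sketch (shapes assumed). With ozan=True the shared
    representation is replicated once per task by OzanRepFunction, whose backward
    pass blends the per-task gradients with weights from MinNormSolver:

        model = xception_taskonomy_ozan(tasks=['normal', 'reshading'])
        out = model(torch.randn(1, 3, 256, 256))  # keys: 'rep', 'normal', 'reshading'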
435 | """ 436 | 437 | model = XceptionTaskonomy(ozan=True,**kwargs) 438 | 439 | return model 440 | -------------------------------------------------------------------------------- /network_selection/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "g++ build and debug active file", 9 | "type": "cppdbg", 10 | "request": "launch", 11 | "program": "${fileDirname}/${fileBasenameNoExtension}", 12 | "args": [], 13 | "stopAtEntry": false, 14 | "cwd": "${workspaceFolder}", 15 | "environment": [], 16 | "externalConsole": false, 17 | "MIMode": "gdb", 18 | "setupCommands": [ 19 | { 20 | "description": "Enable pretty-printing for gdb", 21 | "text": "-enable-pretty-printing", 22 | "ignoreFailures": true 23 | } 24 | ], 25 | "preLaunchTask": "g++ build active file", 26 | "miDebuggerPath": "/usr/bin/gdb" 27 | } 28 | ] 29 | } -------------------------------------------------------------------------------- /network_selection/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "iosfwd": "cpp" 4 | } 5 | } -------------------------------------------------------------------------------- /network_selection/.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": [ 3 | { 4 | "type": "shell", 5 | "label": "g++ build active file", 6 | "command": "/usr/bin/g++", 7 | "args": [ 8 | "-g", 9 | "${file}", 10 | "-o", 11 | "${fileDirname}/${fileBasenameNoExtension}" 12 | ], 13 | "options": { 14 | "cwd": "/usr/bin" 15 | } 16 | } 17 | ], 18 | "version": "2.0.0" 19 | } -------------------------------------------------------------------------------- /network_selection/Makefile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/Makefile -------------------------------------------------------------------------------- /network_selection/a.out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/a.out -------------------------------------------------------------------------------- /network_selection/main: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/main -------------------------------------------------------------------------------- /network_selection/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | using namespace std; 12 | const double inf=99; 13 | 14 | vector read_line(stringstream &ssline) { 15 | 16 | vector nline; 17 | while(!ssline.eof()) { 18 | char next = ssline.peek(); 19 | if(next == '\t') 20 | { 21 | nline.push_back(inf); 22 | ssline.get(); 23 | } 24 | else 25 | { 26 | double num; 27 | if(ssline >> num) { 28 | ssline.get(); 29 | 
nline.push_back(num); 30 | } 31 | else 32 | { 33 | nline.push_back(inf); 34 | } 35 | 36 | } 37 | } 38 | return nline; 39 | } 40 | class Model 41 | { 42 | public: 43 | Model(vector<double> perf, double cost):performance(perf), cost(cost){} 44 | Model():cost(-1){} 45 | vector<double> performance; 46 | double cost; 47 | bool match(const Model &other) const { 48 | for (int i = 0; i < performance.size();i++) 49 | if((performance[i] == inf && other.performance[i]!=inf)|| 50 | (performance[i] != inf && other.performance[i]==inf)) 51 | return false; 52 | return cost==other.cost; 53 | } 54 | void remove_dim(int which_dim) { 55 | performance.erase(performance.begin()+which_dim); 56 | } 57 | double total_score(){ 58 | double sum=0; 59 | for(auto i : performance) 60 | sum+=i; 61 | return sum; 62 | } 63 | 64 | int rank() const { 65 | int r=0; 66 | for (double i : performance) 67 | if (i < inf) 68 | r++; 69 | return r; 70 | } 71 | 72 | 73 | }; 74 | 75 | ostream & operator << (ostream & out, Model mod) { 76 | double sum=0; 77 | for(auto i : mod.performance) { 78 | out << i << ", "; 79 | sum+=i; 80 | } 81 | if(mod.cost==0) 82 | { 83 | out << "=" << sum; 84 | } 85 | else 86 | { 87 | out << mod.cost; 88 | } 89 | return out; 90 | } 91 | 92 | vector<Model> get_model_performances(string filename) { 93 | ifstream file(filename); 94 | 95 | string line; 96 | 97 | vector<Model> ret; 98 | while(std::getline(file,line)) 99 | { 100 | stringstream ssline(line); 101 | char next; 102 | if (ret.size()==0) { 103 | auto nline = read_line(ssline); 104 | 105 | for(int i = 0; i < nline.size();i++) 106 | { 107 | vector<double> infs({inf,inf,inf,inf,inf}); 108 | infs[i]=nline[i]; 109 | ret.push_back(Model(infs,0.5)); 110 | } 111 | } else { 112 | auto nline = read_line(ssline); 113 | ret.push_back(Model(nline,1)); 114 | } 115 | } 116 | 117 | return ret; 118 | } 119 | 120 | vector<double> score_solution (const vector<Model> &to_score,int size=-1) { 121 | 122 | if(size==-1) size = to_score[0].performance.size(); 123 | vector<double> score(size,inf); 124 | 125 | for (auto mod : to_score) 126 | { 127 | for(int i = 0; i < mod.performance.size();i++) 128 | score[i]=min(score[i],mod.performance[i]); 129 | } 130 | return score; 131 | } 132 | 133 | bool dominates(const vector<double>& first, const vector<double> & second) { 134 | for (int i = 0; i < first.size();i++) 135 | if (first[i] > second[i]) 136 | return false; 137 | return true; 138 | } 139 | 140 | vector<Model> filter(const vector<Model> &candidates,const vector<double> & best_score, double remaining_budget) { 141 | vector<Model> to_return; 142 | for(auto candidate : candidates) { 143 | if (candidate.cost <= remaining_budget && !dominates(best_score,candidate.performance) ) 144 | to_return.push_back(candidate); 145 | } 146 | return to_return; 147 | } 148 | 149 | vector<Model> better(const vector<Model> & a, const vector<Model> & b) { 150 | if (a.size()==0) 151 | return b; 152 | auto a_score = score_solution(a); 153 | auto b_score = score_solution(b); 154 | double a_value = 0; 155 | double b_value = 0; 156 | for(auto i : a_score)a_value+=i; 157 | for(auto i : b_score)b_value+=i; 158 | if(a_value < b_value) 159 | return a; 160 | return b; 161 | } 162 | 163 | double get_sorting_score(const vector<double>& running_score,const Model & a) { 164 | double amax = -inf; 165 | for(int i = 0; i < running_score.size();i++) 166 | amax = max(amax,running_score[i]-a.performance[i]); 167 | return amax; 168 | } 169 | 170 | // double get_sorting_score(const vector<double>& running_score,const Model & a) { 171 | // double amax = 0; 172 | // for(int i = 0; i < running_score.size();i++) 173 | // amax +=
min(running_score[i],a.performance[i]); 174 | // return amax; 175 | // } 176 | 177 | vector<Model> get_best_networks(const vector<Model> &candidates_in, vector<Model> running_solution, double remaining_budget) { 178 | 179 | if (candidates_in.size()==0) 180 | return running_solution; 181 | 182 | vector<double> running_score = score_solution(running_solution,candidates_in[0].performance.size()); 183 | 184 | 185 | auto candidates = filter(candidates_in,running_score,remaining_budget); 186 | if (candidates.size()==0) 187 | return running_solution; 188 | std::sort(candidates.begin(),candidates.end(), 189 | [&running_score](const Model & a, const Model & b)-> bool{ 190 | return get_sorting_score(running_score,a) < get_sorting_score(running_score,b); 191 | } 192 | ); 193 | 194 | 195 | vector<Model> best_solution = running_solution; 196 | 197 | 198 | while(!candidates.empty()) { 199 | 200 | running_solution.push_back(candidates.back()); 201 | candidates.pop_back(); 202 | auto best_below = get_best_networks(candidates,running_solution,remaining_budget-running_solution.back().cost); 203 | 204 | running_solution.pop_back(); 205 | best_solution = better(best_solution,best_below); 206 | } 207 | return best_solution; 208 | 209 | } 210 | 211 | vector<Model> synthetic_performances(vector<Model> input) { 212 | 213 | int num_tasks=input[0].performance.size(); 214 | 215 | int random_dimension = rand()%num_tasks; 216 | vector<double> valid_values; 217 | for(int i =0; i < input.size();i++) 218 | { 219 | input[i].performance.push_back(inf); 220 | double val = input[i].performance[random_dimension]; 221 | if(val < inf) 222 | valid_values.push_back(val); 223 | } 224 | int fixed_input_size=input.size(); 225 | for(int i =5; i < fixed_input_size;i++) 226 | { 227 | input.push_back(input[i]); 228 | input.back().performance.back()=valid_values[rand()%valid_values.size()]; 229 | for(int ii =0;ii < input.back().performance.size();ii++){ 230 | input.back().performance[ii]*=1.05; 231 | } 232 | } 233 | return input;} 234 | 235 | 236 | vector<Model> translate_scores_to_test_set(const vector<Model> &solution,const vector<Model> &test_performances){ 237 | vector<Model> ret; 238 | for (Model mod : solution) { 239 | for(Model mod2 : test_performances) { 240 | if (mod.match(mod2)){ 241 | ret.push_back(mod2); 242 | //break; 243 | } 244 | } 245 | } 246 | return ret; 247 | } 248 | 249 | vector<Model> remove_task(vector<Model> performances,int task) { 250 | vector<Model> ret; 251 | 252 | for(auto i : performances) { 253 | if (i.performance[task]==inf) { 254 | i.remove_dim(task); 255 | ret.push_back(i); 256 | } 257 | } 258 | 259 | return ret; 260 | } 261 | 262 | bool subset(Model &a, Model&b) { 263 | if (a.performance.size() != b.performance.size()) 264 | return false; 265 | if (a.rank() > b.rank()) 266 | return false; 267 | for (int i=0;i<a.performance.size();i++) 268 | if (a.performance[i] != inf && b.performance[i] == inf) 269 | return false; 270 | return true; 271 | } 272 | 273 | 274 | 275 | vector<Model> higher_order_approximation(vector<Model> performances) { 276 | 277 | 278 | vector<Model> pairs_models; 279 | for (auto a : performances) 280 | if(a.rank() ==2) 281 | pairs_models.push_back(a); 282 | 283 | vector<Model> new_models; 284 | for (auto a : performances) 285 | if(a.rank() <=2) 286 | new_models.push_back(a); 287 | else { 288 | Model new_model({0,0,0,0,0},1); 289 | vector<int> count({0,0,0,0,0}); 290 | for (auto pair_model:pairs_models) 291 | if (subset(pair_model,a)) { 292 | for(int i = 0; i < pair_model.performance.size();i++) 293 | if (pair_model.performance[i]!=inf){ 294 | new_model.performance[i]+=pair_model.performance[i]; 295 | count[i]++; 296 | } 297 | } 298 | for(int i=0;i<count.size();i++){ 299 | if(count[i]==0) 300 | new_model.performance[i]=inf; 301 | else 302 | new_model.performance[i]/=count[i]; 303 | } 304 | new_models.push_back(new_model); 305 | } 306 | 307 | 308 | return new_models; 309 | } 310 | 311 | 312 | vector<Model> just_pairs(vector<Model> performances) { 313 | vector<Model> new_models; 314 | for (auto a : performances) 315 | if(a.rank() <=2 || a.rank()==5) 316 |
new_models.push_back(a); 317 | 318 | 319 | return new_models; 320 | } 321 | 322 | vector<double> get_mins(vector<Model> performances) { 323 | vector<double> mins({inf,inf,inf,inf,inf}); 324 | for (auto a : performances){ 325 | for(int i =0;i<mins.size();i++){ 326 | if(a.performance[i] < inf){ 327 | if (mins[i] > a.performance[i]) mins[i]= a.performance[i]; 328 | } 329 | 330 | } 331 | } 332 | return mins; 333 | } 334 | 335 | vector<double> get_maxes(vector<Model> performances) { 336 | vector<double> maxes({0,0,0,0,0}); 337 | for (auto a : performances){ 338 | for(int i =0;i<maxes.size();i++){ 339 | if(a.performance[i] < inf){ 340 | if (maxes[i] < a.performance[i]) maxes[i]= a.performance[i]; 341 | } 342 | 343 | } 344 | } 345 | return maxes; 346 | } 347 | 348 | 349 | vector<Model> scale_values(vector<Model> performances) { 350 | 351 | auto mins = get_mins(performances); 352 | auto maxes = get_maxes(performances); 353 | 354 | for(int i = 0; i < maxes.size();i++) { 355 | cout << maxes[i] << ' '; 356 | } 357 | cout << endl; 358 | 359 | for(int i = 0; i < maxes.size();i++) { 360 | cout << mins[i] << ' '; 361 | } 362 | cout << endl; 363 | 364 | vector<Model> scaled_models; 365 | for (auto a : performances){ 366 | for(int i =0;i<a.performance.size();i++){ 367 | if(a.performance[i] < inf) 368 | a.performance[i] = (a.performance[i]-mins[i])/(maxes[i]-mins[i]); 369 | } 370 | scaled_models.push_back(a); 371 | } 372 | return scaled_models; 373 | } 374 | 375 | vector<double> compute_and_print(vector<Model> &performances, vector<Model> & performances_test_set, bool print=false) { 376 | if(print) 377 | cout << "number_of_models= "<< performances.size()<< endl; 378 | 379 | vector<double> perfs; 380 | for(double budget = 1; budget <= performances[0].performance.size();budget+=.5){ 381 | auto solutiona = get_best_networks(performances,vector<Model>(),budget); 382 | 383 | auto solution = translate_scores_to_test_set(solutiona,performances_test_set); 384 | //auto solution=solutiona; 385 | 386 | Model sol(score_solution(solution,solution[0].performance.size()),0); 387 | if (print){ 388 | for(auto mod:solution) { 389 | cout << mod << endl; 390 | } 391 | cout << "budget=" << budget << " " << sol << endl; 392 | } 393 | perfs.push_back(sol.total_score()); 394 | } 395 | if (print) 396 | for(auto p:perfs) 397 | cout << p << endl; 398 | return perfs; 399 | } 400 | 401 | 402 | double do_one_random(vector<Model> test_perfs, double budget){ 403 | 404 | while(true) { 405 | double remaining_budget=budget; 406 | 407 | vector<Model> random_sol; 408 | vector<bool> used(test_perfs.size(), false); 409 | int used_count=0; 410 | while(true) 411 | { 412 | int index; 413 | do { 414 | index=rand()%test_perfs.size(); 415 | } while(used[index]); 416 | used[index]=true; 417 | used_count++; 418 | 419 | if(used_count == test_perfs.size()) { 420 | used_count=0; 421 | for(int i =0;i 0.50279) 441 | // cout << total_score << ' ' << random_sol.size() << endl; 442 | // } 443 | 444 | return total_score; 445 | } 446 | 447 | 448 | } 449 | 450 | } 451 | 452 | 453 | 454 | // vector<double> do_random(vector<Model> test_perfs) { 455 | 456 | // vector<double> totals; 457 | // int num_times=3000; 458 | // for(int i=0;i 473 | vector<double> do_random(vector<Model> test_perfs) { 474 | 475 | vector<double> totals({0,0,0,0,0,0,0,0,0}); 476 | int num_times=1000000; 477 | for(int i=0;i 493 | vector<double> do_worst(vector<Model> test_perfs) { 494 | 495 | vector<double> totals({0,0,0,0,0,0,0,0,0}); 496 | int num_times=10000000; 497 | for(int i=0;i -------------------------------------------------------------------------------- /network_selection/make_plots.py: -------------------------------------------------------------------------------- Multi-task Network', 19 | 'random':'Random Groupings', 20 | 'independent':'Five Independent Networks', 21 | 'esa':'ESA (ours) 5.3.1', 22 | 'hoa':'HOA (ours) 5.3.2', 23 | 'optimal':'Optimal Network<br>
Choice (ours)' } 24 | 25 | name_to_color={'sener_et_al':7, 26 | 'gradnorm':8, 27 | 'worst':0, 28 | 'all_in_one':1, 29 | 'random':2, 30 | 'independent':3, 31 | 'esa':4, 32 | 'hoa':5, 33 | 'optimal':6} 34 | 35 | fig = go.Figure() 36 | 37 | symbols=['circle','square','diamond','star','hexagram','star-triangle-up','asterisk','y-up','cross'] 38 | 39 | for i,(key,val) in enumerate(curves.items()): 40 | fig.add_trace(go.Scatter(x=budget, y=val, name=name_to_name[key],connectgaps=True ,marker_symbol=name_to_color[key],marker_size=10,line=dict(color=px.colors.qualitative.G10[name_to_color[key]]))) 41 | 42 | #line=dict(color=) 43 | 44 | # Create and style traces 45 | # if 'sener_et_al' in curves: 46 | # fig.add_trace(go.Scatter(x=budget, y=curves['sener_et_al'], name='Sener et al.',connectgaps=True ,)) 47 | # if 'gradnorm' in curves: 48 | # fig.add_trace(go.Scatter(x=budget, y=curves['gradnorm'], name='GradNorm',connectgaps=True )) 49 | # fig.add_trace(go.Scatter(x=budget, y=curves['worst'], name='Worst Network<br>Choice',connectgaps=True )) 50 | # fig.add_trace(go.Scatter(x=budget, y=curves['all_in_one'], name='Single Traditional<br>Multi-task Network',connectgaps=True )) 51 | # fig.add_trace(go.Scatter(x=budget, y=curves['random'], name='Random Groupings',connectgaps=True )) 52 | # fig.add_trace(go.Scatter(x=budget, y=curves['independent'], name='Five Independent<br>Networks',connectgaps=True )) 53 | # fig.add_trace(go.Scatter(x=budget, y=curves['esa'], name='ESA (ours) 3.3.1',connectgaps=True )) 54 | # fig.add_trace(go.Scatter(x=budget, y=curves['hoa'], name='HOA (ours) 3.3.2',connectgaps=True )) 55 | # fig.add_trace(go.Scatter(x=budget, y=curves['optimal'], name='Optimal Network<br>
Choice (ours)',connectgaps=True )) 56 | 57 | 58 | # Edit the layout 59 | fig.update_layout(title=dict(text='Performance vs Compute', font=dict(size=22,color='black')), 60 | xaxis_title=dict(text='Inference Time Cost',font=dict(size=18,color='black')), 61 | yaxis_title=dict(text='Total Loss (lower is better)',font=dict(size=18,color='black')), 62 | legend=dict(font=dict(color='black',size=16)), 63 | #colorway=px.colors.qualitative.G10, 64 | xaxis=dict( 65 | showline=True, 66 | showgrid=False, 67 | showticklabels=True, 68 | linecolor='rgb(0, 0, 0)', 69 | linewidth=1, 70 | ticks='outside', 71 | tickfont=dict( 72 | family='Arial', 73 | size=15, 74 | color='rgb(0, 0, 0)', 75 | ), 76 | ), 77 | yaxis=dict( 78 | showgrid=True, 79 | #zeroline=False, 80 | ticks='outside', 81 | showline=True, 82 | showticklabels=True, 83 | linecolor='rgb(0, 0, 0)', 84 | linewidth=1, 85 | tickfont=dict( 86 | family='Arial', 87 | size=15, 88 | color='rgb(0, 0, 0)', 89 | ), 90 | ), 91 | autosize=False, 92 | margin=dict( 93 | autoexpand=False, 94 | l=58, 95 | r=240, 96 | t=32, 97 | b=47 98 | ), 99 | width=600, 100 | height=100+27*len(curves), 101 | #showlegend=False, 102 | plot_bgcolor='white' 103 | ) 104 | 105 | fig.write_image('plots/'+name+'.pdf') 106 | #fig.show() 107 | 108 | 109 | curves_1=dict( 110 | sener_et_al = [0.5621, None, 0.5556, None, None, None, 0.5471], 111 | gradnorm = [0.5148, None, None, None, None, None, 0.5001], 112 | worst = [0.50278, 0.50278, 0.50278, 0.50278, 0.50278, 0.50278, 0.50179, 0.50179, 0.49941], 113 | all_in_one = [0.50273, None, 0.4916, 0.48873, None, None, 0.4883], 114 | random = [0.50278, 0.485347, 0.473641, 0.469079, 0.465265, 0.46271, 0.460238, 0.458358, 0.456486], 115 | independent = [0.51456, 0.50139, 0.47704, 0.46515, None, None, 0.45456, None, 0.44774], 116 | esa = [0.50273, 0.48732, 0.46727, 0.46063, 0.45722, 0.45058, 0.45058, 0.44742, 0.44742], 117 | hoa = [0.50278, 0.46132, 0.45474, 0.4505, 0.44875, 0.44489, 0.44112, 0.44552, 0.44196], 118 | optimal = [0.50273, 0.46132, 0.45224, 0.44612, 0.44235, 0.43932, 0.43555, 0.43555, 0.43481], 119 | ) 120 | 121 | curves_2=dict( 122 | worst = [0.35989, 0.36554, 0.36926, 0.36936, 0.36956, 0.36956, 0.36956, 0.36956, 0.36956], 123 | independent = [0.37276,0.35715,0.35926,0.36188,None,None,0.35384,None,0.35216], 124 | all_in_one = [0.35989,None,0.35408,None,0.35431,None,0.35295,None,None], 125 | random = [0.35989, 0.360109, 0.357285, 0.355924, 0.353176, 0.351664, 0.349508, 0.348102, 0.346303] , 126 | esa = [0.35989, 0.35989, 0.34696, 0.34696, 0.34483, 0.34483, 0.34483, 0.34483, 0.34483], 127 | hoa = [0.35989, 0.35758, 0.31733, 0.31562, 0.31177, 0.30525, 0.3019, 0.3019, 0.30187], 128 | optimal = [0.35989, 0.35478, 0.31733, 0.3145, 0.30606, 0.3049, 0.3019, 0.3019, 0.30167], 129 | ) 130 | 131 | curves_3=dict( 132 | worst = [0.42998, 0.47544, 0.47182, 0.47205, 0.4717, 0.47066, 0.46857, 0.46702, 0.46495] , 133 | all_in_one = [0.42998, None, None, None, None, None, 0.44391 ], 134 | random = [0.42998, 0.439361, 0.439917, 0.435501, 0.431542, 0.427947, 0.424582, 0.421834, 0.419124], 135 | independent = [0.41805, None, None, 0.4262, None, None, None, None, 0.40643], 136 | 137 | esa = [0.42998, 0.44778, 0.43055, 0.39507, 0.40381, 0.39404, 0.40278, 0.39404, 0.40278] , 138 | hoa = [0.42998, 0.44778, 0.40887, 0.38776, 0.38352, 0.38682, 0.38574, 0.38574, 0.38471] , 139 | optimal = [0.42998, 0.42275, 0.40857, 0.38776, 0.38352, 0.38352, 0.38249, 0.38249, 0.38249], 140 | 141 | ) 142 | 143 | curves_4=dict( 144 | worst = [0.684042, 0.689178, 0.696036, 0.698235, 
0.700446, 0.701056, 0.701056, 0.701056, 0.701056], 145 | independent= [0.698867,None,None,0.692437,None,None,None,None,0.685578], 146 | random = [0.684042, 0.683817, 0.681984, 0.681949, 0.680581, 0.6801, 0.679037, 0.678471, 0.677633] , 147 | 148 | 149 | esa = [0.684042, 0.684042, 0.677567, 0.680649, 0.677349, 0.676049, 0.676049, 0.675976, 0.675976] , 150 | all_in_one = [0.684042,None,None,None,None,None,0.672991], 151 | 152 | hoa = [0.684042, 0.678697, 0.674597, 0.671067, 0.669696, 0.671867, 0.670496, 0.670496, 0.670496] , 153 | optimal = [0.684042, 0.678697, 0.674597, 0.671067, 0.669696, 0.669696, 0.668986, 0.668986, 0.668986], 154 | ) 155 | 156 | 157 | make_plot(curves_1,'setting_1') 158 | make_plot(curves_2,'setting_2') 159 | make_plot(curves_3,'setting_3') 160 | make_plot(curves_4,'setting_4') -------------------------------------------------------------------------------- /network_selection/plots/setting_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/plots/setting_1.pdf -------------------------------------------------------------------------------- /network_selection/plots/setting_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/plots/setting_2.pdf -------------------------------------------------------------------------------- /network_selection/plots/setting_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/plots/setting_3.pdf -------------------------------------------------------------------------------- /network_selection/plots/setting_4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/network_selection/plots/setting_4.pdf -------------------------------------------------------------------------------- /network_selection/results.txt: -------------------------------------------------------------------------------- 1 | 0.08021 0.1739 0.08965 0.09264 0.03313 2 | 0.08039 3 | 0.1695 4 | 0.08591 5 | 0.0895 6 | 0.02783 7 | 0.07858 0.1833 8 | 0.074 0.0997 9 | 0.07722 0.09718 10 | 0.07897 0.04462 11 | 0.1695 0.09275 12 | 0.1706 0.09318 13 | 0.1748 0.03192 14 | 0.08968 0.09181 15 | 0.09358 0.02908 16 | 0.09185 0.03488 17 | 0.07498 0.1698 0.09575 18 | 0.07699 0.1782 0.09704 19 | 0.07893 0.1863 0.04559 20 | 0.0722 0.09919 0.0961 21 | 0.07222 0.0982 0.03689 22 | 0.0766 0.09342 0.03508 23 | 0.1654 0.09358 0.09253 24 | 0.1708 0.09396 0.03286 25 | 0.1793 0.09073 0.02937 26 | 0.09626 0.09024 0.02609 27 | 0.07762 0.1822 0.09869 0.1015 28 | 0.07576 0.1735 0.09718 0.04513 29 | 0.0795 0.1797 0.09272 0.04141 30 | 0.07369 0.09944 0.09697 0.03312 31 | 0.1708 0.09392 0.09334 0.02803 32 | 0.07854 0.1864 0.1 0.09814 0.04453 33 | -------------------------------------------------------------------------------- /network_selection/results_20.txt: -------------------------------------------------------------------------------- 1 | 0.2743 0.385 0.1498 0.24 0.1529 2 | 0.3937 3 | 0.3345 4 | 0.1721 5 | 0.2337 6 | 0.1788 7 | 0.2581 0.366 8 | 0.2126 0.1716 9 | 0.2434 0.235 10 | 0.3073 0.4247 11 | 0.3665 0.155 12 | 0.4057 
0.1634 13 | 0.3671 0.2621 14 | 0.157 0.2245 15 | 0.1625 0.3694 16 | 0.2343 0.1412 17 | 0.21 0.3373 0.1666 18 | 0.2297 0.3313 0.264 19 | 0.2504 0.3582 0.3321 20 | 0.2064 0.1723 0.2407 21 | 0.2172 0.176 0.3459 22 | 0.2381 0.2368 0.3472 23 | 0.3302 0.1639 0.2233 24 | 0.3229 0.1531 0.3651 25 | 0.3327 0.2455 0.3415 26 | 0.1547 0.2503 0.1064 27 | 0.2341 0.3375 0.1754 0.1987 28 | 0.211 0.3185 0.1707 0.1747 29 | 0.2294 0.3495 0.1459 0.3404 30 | 0.2147 0.1645 0.1796 0.3378 31 | 0.3263 0.1585 0.2328 0.3416 32 | 0.2319 0.3658 0.1727 0.2396 0.3485 -------------------------------------------------------------------------------- /network_selection/results_alt.txt: -------------------------------------------------------------------------------- 1 | 0.005267 0.07486 0.1192 0.1943 0.3024 2 | 0.004765 3 | 0.07434 4 | 0.1194 5 | 0.1906 6 | 0.3 7 | 0.004412 0.07738 8 | 0.003875 0.1225 9 | 0.006973 0.1936 10 | 0.004757 0.3065 11 | 0.07519 0.1156 12 | 0.07232 0.1854 13 | 0.07293 0.2966 14 | 0.1156 0.192 15 | 0.1164 0.2999 16 | 0.1885 0.297 17 | 0.004678 0.07717 0.1164 18 | 0.006961 0.07414 0.195 19 | 0.004981 0.07465 0.2986 20 | 0.006793 0.1161 0.1997 21 | 0.004752 0.1189 0.3038 22 | 0.006561 0.1906 0.2977 23 | 0.07386 0.1145 0.1931 24 | 0.07449 0.1154 0.2983 25 | 0.07227 0.1872 0.2966 26 | 0.1144 0.1911 0.2975 27 | 0.007305 0.07403 0.1155 0.1934 28 | 0.004973 0.07509 0.1155 0.2988 29 | 0.006663 0.07219 0.1925 0.2965 30 | 0.006843 0.115 0.1937 0.2979 31 | 0.07256 0.1143 0.1926 0.2969 32 | 0.007211 0.07327 0.1144 0.1952 0.2974 -------------------------------------------------------------------------------- /network_selection/results_alt_20.txt: -------------------------------------------------------------------------------- 1 | 0.04071 0.1162 0.1467 0.4029 0.3683 2 | 0.02443 3 | 0.1169 4 | 0.1467 5 | 0.3616 6 | 0.3546 7 | 0.02686 0.1243 8 | 0.05888 0.1499 9 | 0.0306 0.3822 10 | 0.02853 0.3721 11 | 0.128 0.4582 12 | 0.1228 0.3647 13 | 0.1445 0.4461 14 | 0.1561 0.3718 15 | 0.3969 0.3626 16 | 0.02406 0.1325 0.1436 17 | 0.0403 0.1375 0.4596 18 | 0.02199 0.1266 0.3714 19 | 0.03172 0.1536 0.3826 20 | 0.02113 0.1462 0.3601 21 | 0.0439 0.3882 0.3986 22 | 0.1523 0.153 0.3977 23 | 0.1351 0.1609 0.3488 24 | 0.1285 0.3601 0.3656 25 | 0.1845 0.4258 0.358 26 | 0.06426 0.1879 0.1513 0.4761 27 | 0.03202 0.123 0.1469 0.3503 28 | 0.03829 0.1475 0.3577 0.3871 29 | 0.04181 0.1703 0.4108 0.3707 30 | 0.133 0.185 0.3952 0.3782 31 | 0.03864 0.1317 0.1904 0.3904 0.3438 -------------------------------------------------------------------------------- /network_selection/results_alt_test.txt: -------------------------------------------------------------------------------- 1 | 0.005197 0.07464 0.1207 0.1934 0.2985 2 | 0.004698 3 | 0.07408 4 | 0.1209 5 | 0.1896 6 | 0.2963 7 | 0.004356 0.07713 8 | 0.003826 0.124 9 | 0.00688 0.1932 10 | 0.004695 0.3027 11 | 0.07483 0.117 12 | 0.07197 0.1849 13 | 0.07274 0.2928 14 | 0.117 0.1912 15 | 0.1178 0.296 16 | 0.1875 0.293 17 | 0.004616 0.07693 0.118 18 | 0.006865 0.07367 0.1942 19 | 0.004914 0.07446 0.2946 20 | 0.006696 0.1176 0.1986 21 | 0.004689 0.1204 0.3 22 | 0.006473 0.1899 0.2938 23 | 0.07349 0.1161 0.1923 24 | 0.07417 0.1168 0.2943 25 | 0.07349 0.1864 0.2927 26 | 0.1158 0.1904 0.2936 27 | 0.007208 0.07377 0.117 0.1927 28 | 0.004907 0.07484 0.1171 0.295 29 | 0.006572 0.07186 0.1911 0.2926 30 | 0.006749 0.1165 0.1931 0.2941 31 | 0.0724 0.1158 0.1921 0.2932 32 | 0.007112 0.07293 0.1158 0.1946 0.2936 -------------------------------------------------------------------------------- 
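These results_*.txt tables are the raw per-task validation losses that network_selection/main.cpp searches over. Judging from get_model_performances() and read_line() above, the first row of each file holds the losses of the five cost-0.5 single-task networks (one column per task), and every later row describes one cost-1 network, with blank tab-separated fields (read as inf) marking tasks that network does not solve. Below is a minimal Python sketch of that parser; the helper name read_results and the tab-separated-field assumption are mine:
```
INF = 99.0  # mirrors `const double inf=99;` in main.cpp
NUM_TASKS = 5

def read_results(path):
    """Parse a results*.txt file into (losses, cost) pairs, one per network."""
    models = []
    with open(path) as f:
        for row, line in enumerate(f.read().splitlines()):
            # Blank fields mean "this network does not solve that task".
            vals = [float(v) if v.strip() else INF for v in line.split('\t')]
            vals += [INF] * (NUM_TASKS - len(vals))  # pad short rows
            if row == 0:
                # First row: the five cost-0.5 single-task networks.
                for i, v in enumerate(vals):
                    losses = [INF] * NUM_TASKS
                    losses[i] = v
                    models.append((losses, 0.5))
            else:
                # Every later row: one cost-1 multi-task network.
                models.append((vals, 1.0))
    return models
```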
/network_selection/results_large.txt: -------------------------------------------------------------------------------- 1 | 0.03664 0.1478 0.08067 0.08043 0.02041 2 | 0.03576 3 | 0.1427 4 | 0.0797 5 | 0.08093 6 | 0.01692 7 | 0.03606 0.1434 8 | 0.03275 0.08292 9 | 0.03575 0.08488 10 | 0.03711 0.01601 11 | 0.1385 0.07971 12 | 0.1492 0.08343 13 | 0.1493 0.01624 14 | 0.0897 0.04267 15 | 0.08472 0.01184 16 | 0.08297 0.01275 17 | 0.03528 0.138 0.08183 18 | 0.03883 0.1488 0.08495 19 | 0.03748 0.1462 0.0188 20 | 0.03496 0.08498 0.0815 21 | 0.03535 0.08511 0.01615 22 | 0.03787 0.08445 0.01438 23 | 0.1413 0.08215 0.08067 24 | 0.1454 0.08238 0.01618 25 | 0.1518 0.08345 0.01667 26 | 0.0886 0.07956 0.01616 27 | 0.03534 0.1407 0.08259 0.08481 28 | 0.03661 0.1418 0.08246 0.01842 29 | 0.0382 0.1488 0.08472 0.01694 30 | 0.03677 0.08664 0.07945 0.01813 31 | 0.1452 0.08277 0.08481 0.01745 32 | 0.0367 0.1442 0.08296 0.08298 0.01641 -------------------------------------------------------------------------------- /network_selection/results_large_20.txt: -------------------------------------------------------------------------------- 1 | 0.1022 0.381 0.1245 0.1114 0.08506 2 | 0.1623 3 | 0.317 4 | 0.1164 5 | 0.116 6 | 0.215 7 | 0.1067 0.2633 8 | 0.09147 0.1304 9 | 0.09653 0.1832 10 | 0.09416 0.1221 11 | 0.3544 0.1328 12 | 0.3204 0.1351 13 | 0.3515 0.08215 14 | 0.1261 0.1205 15 | 0.1243 0.1256 16 | 0.1322 0.1243 17 | 0.08381 0.246 0.1326 18 | 0.09195 0.3071 0.1296 19 | 0.09837 0.3772 0.09825 20 | 0.1001 0.1268 0.1512 21 | 0.1133 0.1335 0.1359 22 | 0.1197 0.1207 0.143 23 | 0.2646 0.1245 0.1381 24 | 0.3473 0.1255 0.1021 25 | 0.2966 0.1307 0.1277 26 | 0.1264 0.118 0.07475 27 | 0.09258 0.4347 0.1253 0.1225 28 | 0.09358 0.2792 0.1295 0.09629 29 | 0.1054 0.3073 0.122 0.08293 30 | 0.09383 0.1316 0.1249 0.1336 31 | 0.2825 0.1272 0.1705 0.1131 32 | 0.1085 0.2447 0.1313 0.1325 0.1101 -------------------------------------------------------------------------------- /network_selection/results_large_test.txt: -------------------------------------------------------------------------------- 1 | 0.03435 0.1474 0.08036 0.07954 0.02023 2 | 0.03372 3 | 0.1424 4 | 0.07939 5 | 0.07988 6 | 0.01677 7 | 0.03377 0.1431 8 | 0.031 0.08267 9 | 0.03331 0.0839 10 | 0.03485 0.01583 11 | 0.1376 0.07942 12 | 0.148 0.08244 13 | 0.1495 0.01609 14 | 0.08941 0.04209 15 | 0.08439 0.01179 16 | 0.08204 0.01256 17 | 0.03326 0.1374 0.08152 18 | 0.03643 0.1483 0.08394 19 | 0.03579 0.1461 0.01857 20 | 0.03281 0.08467 0.08048 21 | 0.03303 0.08477 0.01604 22 | 0.03532 0.08351 0.01423 23 | 0.141 0.08175 0.07963 24 | 0.1441 0.08201 0.01604 25 | 0.1519 0.08246 0.0164 26 | 0.08833 0.07865 0.01613 27 | 0.03321 0.14 0.08237 0.08386 28 | 0.03419 0.1408 0.08207 0.01818 29 | 0.03645 0.1482 0.08373 0.0168 30 | 0.0347 0.08639 0.07851 0.01796 31 | 0.1443 0.08249 0.08386 0.01721 32 | 0.03515 0.1439 0.08267 0.08194 0.01623 -------------------------------------------------------------------------------- /network_selection/results_mean.txt: -------------------------------------------------------------------------------- 1 | 0.5486 0.7536 0.283 0.3626 0.2584 2 | 0.7037 3 | 0.6657 4 | 0.3013 5 | 0.3638 6 | 0.3334 7 | 0.4631 0.6891 8 | 0.3943 0.3125 9 | 0.4461 0.3839 10 | 0.506 0.5879 11 | 0.6966 0.2986 12 | 0.7747 0.2827 13 | 0.6628 0.3768 14 | 0.2944 0.426 15 | 0.3042 0.5029 16 | 0.3512 0.23202 17 | 0.3958 0.6373 0.3174 18 | 0.4043 0.6213 0.4 19 | 0.4365 0.6707 0.4768 20 | 0.3811 0.3186 0.3787 21 | 0.3933 0.3179 0.4946 22 | 0.4584 0.3773 0.4799 23 | 0.6179 0.3116 0.3427 24 | 
0.6174 0.2886 0.4987 25 | 0.6654 0.3759 0.4585 26 | 0.2936 0.3811 0.18865 27 | 0.4158 0.668 0.3237 0.339 28 | 0.4186 0.6024 0.3114 0.3044 29 | 0.4272 0.6745 0.2683 0.483 30 | 0.4065 0.3118 0.305 0.4651 31 | 0.618 0.2949 0.3781 0.4583 32 | 0.3998 0.6593 0.3151 0.3778 0.5073 -------------------------------------------------------------------------------- /network_selection/results_small_data.txt: -------------------------------------------------------------------------------- 1 | 0.05777 0.1813 0.09048 0.07671 0.02365 2 | 0.05797 3 | 0.1651 4 | 0.08696 5 | 0.0802 6 | 0.01999 7 | 0.06659 0.1776 8 | 0.05328 0.0962 9 | 0.0622 0.08512 10 | 0.06259 0.03028 11 | 0.1567 0.08795 12 | 0.196 0.07568 13 | 0.1706 0.02618 14 | 0.09272 0.07777 15 | 0.09465 0.01553 16 | 0.08076 0.01517 17 | 0.07121 0.1674 0.09013 18 | 0.07143 0.1843 0.09027 19 | 0.08371 0.1905 0.03126 20 | 0.06334 0.09325 0.08412 21 | 0.0673 0.09374 0.02513 22 | 0.07009 0.08804 0.02587 23 | 0.1784 0.09214 0.08023 24 | 0.1767 0.0925 0.02856 25 | 0.1827 0.08171 0.02173 26 | 0.09641 0.07349 0.01508 27 | 0.06032 0.1795 0.09453 0.09382 28 | 0.07608 0.1823 0.09286 0.0306 29 | 0.07096 0.1996 0.08362 0.0365 30 | 0.07139 0.09415 0.07843 0.02027 31 | 0.1898 0.09507 0.08288 0.02674 32 | 0.06556 0.1665 0.09228 0.08441 0.02573 33 | -------------------------------------------------------------------------------- /network_selection/results_small_data_at4.txt: -------------------------------------------------------------------------------- 1 | 0.127 0.3419 0.1507 0.1446 0.1389 2 | 0.1315 3 | 0.359 4 | 0.1335 5 | 0.1376 6 | 0.1121 7 | 0.1396 0.3132 8 | 0.1281 0.1568 9 | 0.1872 0.1388 10 | 0.143 0.1621 11 | 0.2961 0.149 12 | 0.3393 0.1438 13 | 0.4161 0.1229 14 | 0.1409 0.1258 15 | 0.148 0.1608 16 | 0.1184 0.1201 17 | 0.1306 0.3168 0.1709 18 | 0.1354 0.3616 0.1752 19 | 0.1701 0.3707 0.1547 20 | 0.2377 0.1743 0.1449 21 | 0.1685 0.1831 0.2913 22 | 0.18 0.1289 0.1294 23 | 0.3292 0.1547 0.1518 24 | 0.3224 0.164 0.1675 25 | 0.3821 0.1428 0.1272 26 | 0.15 0.1546 0.1308 27 | 0.1329 0.3415 0.164 0.1676 28 | 0.1733 0.4203 0.204 0.1679 29 | 0.1322 0.3616 0.1343 0.112 30 | 0.1211 0.1522 0.1944 0.127 31 | 0.3158 0.1564 0.1405 0.1325 32 | 0.1306 0.3591 0.1565 0.1624 0.148 -------------------------------------------------------------------------------- /network_selection/results_small_data_test.txt: -------------------------------------------------------------------------------- 1 | 0.05521 0.1815 0.09032 0.07581 0.02336 2 | 0.05538 3 | 0.1653 4 | 0.0867 5 | 0.07929 6 | 0.01976 7 | 0.06319 0.1781 8 | 0.05097 0.09609 9 | 0.05864 0.08415 10 | 0.0585 0.02992 11 | 0.1573 0.08773 12 | 0.1956 0.07473 13 | 0.1712 0.02587 14 | 0.09241 0.07685 15 | 0.09424 0.01536 16 | 0.07982 0.01501 17 | 0.06832 0.1672 0.08987 18 | 0.06822 0.1846 0.08922 19 | 0.07829 0.1907 0.03089 20 | 0.05954 0.093 0.0832 21 | 0.06377 0.09354 0.02484 22 | 0.06715 0.0871 0.02558 23 | 0.179 0.09193 0.08001 24 | 0.1772 0.09227 0.02825 25 | 0.1829 0.08098 0.02148 26 | 0.09611 0.0726 0.01492 27 | 0.05782 0.1799 0.09439 0.09284 28 | 0.07096 0.1818 0.09258 0.03024 29 | 0.06668 0.1996 0.08268 0.03616 30 | 0.0663 0.0941 0.07747 0.02007 31 | 0.1891 0.09499 0.08204 0.02644 32 | 0.06244 0.1665 0.09211 0.08349 0.02544 -------------------------------------------------------------------------------- /network_selection/results_test.txt: -------------------------------------------------------------------------------- 1 | 0.07742 0.1737 0.08932 0.09203 0.03268 2 | 0.07662 3 | 0.1696 4 | 0.08555 5 | 0.08847 6 | 0.0275 7 | 
0.07419 0.1831 8 | 0.07084 0.0994 9 | 0.07369 0.09601 10 | 0.07504 0.044 11 | 0.1694 0.09249 12 | 0.1713 0.08882 13 | 0.1753 0.03145 14 | 0.08934 0.09077 15 | 0.09327 0.02865 16 | 0.09077 0.0344 17 | 0.07193 0.17 0.09544 18 | 0.07311 0.1785 0.09591 19 | 0.07617 0.1865 0.04474 20 | 0.06933 0.09966 0.09302 21 | 0.06859 0.09796 0.03625 22 | 0.07323 0.09232 0.03463 23 | 0.1658 0.09318 0.09143 24 | 0.1706 0.09362 0.03239 25 | 0.1795 0.08968 0.02887 26 | 0.09596 0.08921 0.02566 27 | 0.07338 0.1826 0.09836 0.1003 28 | 0.07249 0.1739 0.09689 0.04441 29 | 0.07634 0.1801 0.09157 0.04097 30 | 0.07111 0.09941 0.09464 0.03328 31 | 0.1704 0.09356 0.09226 0.02768 32 | 0.07603 0.186 0.09976 0.09704 0.04395 -------------------------------------------------------------------------------- /read_training_history.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from collections import defaultdict 5 | from train_taskonomy import print_table 6 | 7 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 8 | parser.add_argument('--model_file', '-m', default='', type=str, metavar='PATH', 9 | help='path to latest checkpoint (default: none)') 10 | 11 | parser.add_argument('--arch', '-a', metavar='ARCH', default='', 12 | help='model architecture: ' + 13 | ' (default: resnet18)') 14 | 15 | parser.add_argument('--save_raw',default='') 16 | parser.add_argument('--show_loss_plot','-s', action='store_true', 17 | help='show loss plot') 18 | 19 | args = parser.parse_args() 20 | 21 | 22 | 23 | 24 | # def print_table(table_list, go_back=True): 25 | 26 | # if go_back: 27 | # print("\033[F",end='') 28 | # print("\033[K",end='') 29 | # for i in range(len(table_list)): 30 | # print("\033[F",end='') 31 | # print("\033[K",end='') 32 | 33 | 34 | # lens = defaultdict(int) 35 | # for i in table_list: 36 | # for ii,to_print in enumerate(i): 37 | # for title,val in to_print.items(): 38 | # lens[(title,ii)]=max(lens[(title,ii)],max(len(title),len(val))) 39 | 40 | 41 | # # printed_table_list_header = [] 42 | # for ii,to_print in enumerate(table_list[0]): 43 | # for title,val in to_print.items(): 44 | 45 | # print('{0:^{1}}'.format(title,lens[(title,ii)]),end=" ") 46 | # for i in table_list: 47 | # print() 48 | # for ii,to_print in enumerate(i): 49 | # for title,val in to_print.items(): 50 | # print('{0:^{1}}'.format(val,lens[(title,ii)]),end=" ",flush=True) 51 | # print() 52 | 53 | 54 | def create_model(): 55 | import mymodels as models 56 | try: 57 | model = models.__dict__[args.arch](num_classification_classes=1000, 58 | num_segmentation_classes=21, 59 | num_segmentation_classes2=90, 60 | normalize=False) 61 | except: 62 | model = models.__dict__[args.arch]() 63 | 64 | 65 | return model 66 | 67 | 68 | if args.model_file: 69 | if os.path.isfile(args.model_file): 70 | print("=> loading checkpoint '{}'".format(args.model_file)) 71 | checkpoint = torch.load(args.model_file) 72 | 73 | progress_table = checkpoint['progress_table'] 74 | 75 | print_table(progress_table,False) 76 | 77 | if args.show_loss_plot: 78 | loss_history = checkpoint['loss_history'] 79 | print(len(loss_history)) 80 | print() 81 | import matplotlib.pyplot as plt 82 | loss_history2 = loss_history[200:] 83 | loss_history3 = [] 84 | cur = loss_history2[0] 85 | for i in loss_history2: 86 | cur = .99*cur+i*.01 87 | loss_history3.append(cur) 88 | plt.plot(range(len(loss_history3)),loss_history3) 89 | plt.show() 
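read_training_history.py, just above, reloads a checkpoint written by train_taskonomy.py and re-prints the progress table stored in it; with -s it also plots an exponentially smoothed loss history. A typical invocation (the checkpoint filename is illustrative):
```
python3 read_training_history.py -m saved_models/xception_taskonomy_new_sdnerac.pth.tar -s
```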
-------------------------------------------------------------------------------- /saved_models/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tstandley/taskgrouping/bb1496e42ff442b7ac69e6f227060b8023325d07/saved_models/placeholder -------------------------------------------------------------------------------- /sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | import contextlib 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from torch.nn.modules.batchnorm import _BatchNorm 18 | 19 | try: 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | except ImportError: 22 | ReduceAddCoalesced = Broadcast = None 23 | 24 | try: 25 | from jactorch.parallel.comm import SyncMaster 26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback 27 | except ImportError: 28 | from .comm import SyncMaster 29 | from .replicate import DataParallelWithCallback 30 | 31 | __all__ = [ 32 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 33 | 'patch_sync_batchnorm', 'convert_model' 34 | ] 35 | 36 | 37 | def _sum_ft(tensor): 38 | """sum over the first and last dimension""" 39 | return tensor.sum(dim=0).sum(dim=-1) 40 | 41 | 42 | def _unsqueeze_ft(tensor): 43 | """add new dimensions at the front and the tail""" 44 | return tensor.unsqueeze(0).unsqueeze(-1) 45 | 46 | 47 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 48 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 49 | 50 | 51 | class _SynchronizedBatchNorm(_BatchNorm): 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 53 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 54 | 55 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 56 | 57 | self._sync_master = SyncMaster(self._data_parallel_master) 58 | 59 | self._is_parallel = False 60 | self._parallel_id = None 61 | self._slave_pipe = None 62 | 63 | def forward(self, input): 64 | # If this is not a parallel computation, or the module is in evaluation mode, use PyTorch's implementation.
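# (The fallback is safe: in eval mode the frozen running statistics are used, and on a single replica there are no cross-device statistics to synchronize.)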
65 | if not (self._is_parallel and self.training): 66 | return F.batch_norm( 67 | input, self.running_mean, self.running_var, self.weight, self.bias, 68 | self.training, self.momentum, self.eps) 69 | 70 | # Resize the input to (B, C, -1). 71 | input_shape = input.size() 72 | input = input.view(input.size(0), self.num_features, -1) 73 | 74 | # Compute the sum and square-sum. 75 | sum_size = input.size(0) * input.size(2) 76 | input_sum = _sum_ft(input) 77 | input_ssum = _sum_ft(input ** 2) 78 | 79 | # Reduce-and-broadcast the statistics. 80 | if self._parallel_id == 0: 81 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 82 | else: 83 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 84 | 85 | # Compute the output. 86 | if self.affine: 87 | # MJY:: Fuse the multiplication for speed. 88 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 89 | else: 90 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 91 | 92 | # Reshape it. 93 | return output.view(input_shape) 94 | 95 | def __data_parallel_replicate__(self, ctx, copy_id): 96 | self._is_parallel = True 97 | self._parallel_id = copy_id 98 | 99 | # parallel_id == 0 means master device. 100 | if self._parallel_id == 0: 101 | ctx.sync_master = self._sync_master 102 | else: 103 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 104 | 105 | def _data_parallel_master(self, intermediates): 106 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 107 | 108 | # Always using same "device order" makes the ReduceAdd operation faster. 109 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 110 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 111 | 112 | to_reduce = [i[1][:2] for i in intermediates] 113 | to_reduce = [j for i in to_reduce for j in i] # flatten 114 | target_gpus = [i[1].sum.get_device() for i in intermediates] 115 | 116 | sum_size = sum([i[1].sum_size for i in intermediates]) 117 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 118 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 119 | 120 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 121 | 122 | outputs = [] 123 | for i, rec in enumerate(intermediates): 124 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 125 | 126 | return outputs 127 | 128 | def _compute_mean_std(self, sum_, ssum, size): 129 | """Compute the mean and standard-deviation with sum and square-sum. This method 130 | also maintains the moving average on the master device.""" 131 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
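# Two variances are computed below: the unbiased estimate (sumvar / (size - 1)) updates the running_var buffer, while the biased one (sumvar / size) is used to normalize the current batch, matching standard BatchNorm semantics.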
132 | mean = sum_ / size 133 | sumvar = ssum - sum_ * mean 134 | unbias_var = sumvar / (size - 1) 135 | bias_var = sumvar / size 136 | 137 | if hasattr(torch, 'no_grad'): 138 | with torch.no_grad(): 139 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 140 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 141 | else: 142 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 143 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 144 | 145 | return mean, bias_var.clamp(self.eps) ** -0.5 146 | 147 | 148 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 149 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 150 | mini-batch. 151 | 152 | .. math:: 153 | 154 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 155 | 156 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 157 | standard-deviation are reduced across all devices during training. 158 | 159 | For example, when one uses `nn.DataParallel` to wrap the network during 160 | training, PyTorch's implementation normalizes the tensor on each device using 161 | the statistics only on that device, which accelerates the computation and 162 | is also easy to implement, but the statistics might be inaccurate. 163 | Instead, in this synchronized version, the statistics will be computed 164 | over all training samples distributed on multiple devices. 165 | 166 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 167 | as the built-in PyTorch implementation. 168 | 169 | The mean and standard-deviation are calculated per-dimension over 170 | the mini-batches and gamma and beta are learnable parameter vectors 171 | of size C (where C is the input size). 172 | 173 | During training, this layer keeps a running estimate of its computed mean 174 | and variance. The running sum is kept with a default momentum of 0.1. 175 | 176 | During evaluation, this running mean/variance is used for normalization. 177 | 178 | Because the BatchNorm is done over the `C` dimension, computing statistics 179 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm. 180 | 181 | Args: 182 | num_features: num_features from an expected input of size 183 | `batch_size x num_features [x width]` 184 | eps: a value added to the denominator for numerical stability. 185 | Default: 1e-5 186 | momentum: the value used for the running_mean and running_var 187 | computation. Default: 0.1 188 | affine: a boolean value that when set to ``True``, gives the layer learnable 189 | affine parameters.
Default: ``True`` 190 | 191 | Shape:: 192 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 193 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 194 | 195 | Examples: 196 | >>> # With Learnable Parameters 197 | >>> m = SynchronizedBatchNorm1d(100) 198 | >>> # Without Learnable Parameters 199 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 200 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 201 | >>> output = m(input) 202 | """ 203 | 204 | def _check_input_dim(self, input): 205 | if input.dim() != 2 and input.dim() != 3: 206 | raise ValueError('expected 2D or 3D input (got {}D input)' 207 | .format(input.dim())) 208 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 209 | 210 | 211 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 212 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 213 | of 3d inputs 214 | 215 | .. math:: 216 | 217 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 218 | 219 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 220 | standard-deviation are reduced across all devices during training. 221 | 222 | For example, when one uses `nn.DataParallel` to wrap the network during 223 | training, PyTorch's implementation normalizes the tensor on each device using 224 | the statistics only on that device, which accelerates the computation and 225 | is also easy to implement, but the statistics might be inaccurate. 226 | Instead, in this synchronized version, the statistics will be computed 227 | over all training samples distributed on multiple devices. 228 | 229 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 230 | as the built-in PyTorch implementation. 231 | 232 | The mean and standard-deviation are calculated per-dimension over 233 | the mini-batches and gamma and beta are learnable parameter vectors 234 | of size C (where C is the input size). 235 | 236 | During training, this layer keeps a running estimate of its computed mean 237 | and variance. The running sum is kept with a default momentum of 0.1. 238 | 239 | During evaluation, this running mean/variance is used for normalization. 240 | 241 | Because the BatchNorm is done over the `C` dimension, computing statistics 242 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 243 | 244 | Args: 245 | num_features: num_features from an expected input of 246 | size batch_size x num_features x height x width 247 | eps: a value added to the denominator for numerical stability. 248 | Default: 1e-5 249 | momentum: the value used for the running_mean and running_var 250 | computation. Default: 0.1 251 | affine: a boolean value that when set to ``True``, gives the layer learnable 252 | affine parameters.
Default: ``True`` 253 | 254 | Shape:: 255 | - Input: :math:`(N, C, H, W)` 256 | - Output: :math:`(N, C, H, W)` (same shape as input) 257 | 258 | Examples: 259 | >>> # With Learnable Parameters 260 | >>> m = SynchronizedBatchNorm2d(100) 261 | >>> # Without Learnable Parameters 262 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 263 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 264 | >>> output = m(input) 265 | """ 266 | 267 | def _check_input_dim(self, input): 268 | if input.dim() != 4: 269 | raise ValueError('expected 4D input (got {}D input)' 270 | .format(input.dim())) 271 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 272 | 273 | 274 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 275 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 276 | of 4d inputs 277 | 278 | .. math:: 279 | 280 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 281 | 282 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 283 | standard-deviation are reduced across all devices during training. 284 | 285 | For example, when one uses `nn.DataParallel` to wrap the network during 286 | training, PyTorch's implementation normalizes the tensor on each device using 287 | the statistics only on that device, which accelerates the computation and 288 | is also easy to implement, but the statistics might be inaccurate. 289 | Instead, in this synchronized version, the statistics will be computed 290 | over all training samples distributed on multiple devices. 291 | 292 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 293 | as the built-in PyTorch implementation. 294 | 295 | The mean and standard-deviation are calculated per-dimension over 296 | the mini-batches and gamma and beta are learnable parameter vectors 297 | of size C (where C is the input size). 298 | 299 | During training, this layer keeps a running estimate of its computed mean 300 | and variance. The running sum is kept with a default momentum of 0.1. 301 | 302 | During evaluation, this running mean/variance is used for normalization. 303 | 304 | Because the BatchNorm is done over the `C` dimension, computing statistics 305 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 306 | or Spatio-temporal BatchNorm. 307 | 308 | Args: 309 | num_features: num_features from an expected input of 310 | size batch_size x num_features x depth x height x width 311 | eps: a value added to the denominator for numerical stability. 312 | Default: 1e-5 313 | momentum: the value used for the running_mean and running_var 314 | computation. Default: 0.1 315 | affine: a boolean value that when set to ``True``, gives the layer learnable 316 | affine parameters.
Default: ``True`` 317 | 318 | Shape:: 319 | - Input: :math:`(N, C, D, H, W)` 320 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 321 | 322 | Examples: 323 | >>> # With Learnable Parameters 324 | >>> m = SynchronizedBatchNorm3d(100) 325 | >>> # Without Learnable Parameters 326 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 327 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 328 | >>> output = m(input) 329 | """ 330 | 331 | def _check_input_dim(self, input): 332 | if input.dim() != 5: 333 | raise ValueError('expected 5D input (got {}D input)' 334 | .format(input.dim())) 335 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 336 | 337 | 338 | @contextlib.contextmanager 339 | def patch_sync_batchnorm(): 340 | import torch.nn as nn 341 | 342 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 343 | 344 | nn.BatchNorm1d = SynchronizedBatchNorm1d 345 | nn.BatchNorm2d = SynchronizedBatchNorm2d 346 | nn.BatchNorm3d = SynchronizedBatchNorm3d 347 | 348 | yield 349 | 350 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup 351 | 352 | 353 | def convert_model(module): 354 | """Traverse the input module and its child recursively 355 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d 356 | to SynchronizedBatchNorm*N*d 357 | 358 | Args: 359 | module: the input module needs to be convert to SyncBN model 360 | 361 | Examples: 362 | >>> import torch.nn as nn 363 | >>> import torchvision 364 | >>> # m is a standard pytorch model 365 | >>> m = torchvision.models.resnet18(True) 366 | >>> m = nn.DataParallel(m) 367 | >>> # after convert, m is using SyncBN 368 | >>> m = convert_model(m) 369 | """ 370 | if isinstance(module, torch.nn.DataParallel): 371 | mod = module.module 372 | mod = convert_model(mod) 373 | mod = DataParallelWithCallback(mod) 374 | return mod 375 | 376 | mod = module 377 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 378 | torch.nn.modules.batchnorm.BatchNorm2d, 379 | torch.nn.modules.batchnorm.BatchNorm3d], 380 | [SynchronizedBatchNorm1d, 381 | SynchronizedBatchNorm2d, 382 | SynchronizedBatchNorm3d]): 383 | if isinstance(module, pth_module): 384 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 385 | mod.running_mean = module.running_mean 386 | mod.running_var = module.running_var 387 | if module.affine: 388 | mod.weight.data = module.weight.data.clone().detach() 389 | mod.bias.data = module.bias.data.clone().detach() 390 | 391 | for name, child in module.named_children(): 392 | mod.add_module(name, convert_model(child)) 393 | 394 | return mod 395 | -------------------------------------------------------------------------------- /sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object.
58 | 
59 |     - During replication, data parallel triggers a callback on each module, so every slave device should
60 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
62 |       are collected and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message
64 |       to be passed back to each slave device.
65 |     """
66 | 
67 |     def __init__(self, master_callback):
68 |         """
69 | 
70 |         Args:
71 |             master_callback: a callback to be invoked after having collected messages from slave devices.
72 |         """
73 |         self._master_callback = master_callback
74 |         self._queue = queue.Queue()
75 |         self._registry = collections.OrderedDict()
76 |         self._activated = False
77 | 
78 |     def __getstate__(self):
79 |         return {'master_callback': self._master_callback}
80 | 
81 |     def __setstate__(self, state):
82 |         self.__init__(state['master_callback'])
83 | 
84 |     def register_slave(self, identifier):
85 |         """
86 |         Register a slave device.
87 | 
88 |         Args:
89 |             identifier: an identifier, usually the device id.
90 | 
91 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 | 
93 |         """
94 |         if self._activated:
95 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 |             self._activated = False
97 |             self._registry.clear()
98 |         future = FutureResult()
99 |         self._registry[identifier] = _MasterRegistry(future)
100 |         return SlavePipe(identifier, self._queue, future)
101 | 
102 |     def run_master(self, master_msg):
103 |         """
104 |         Main entry for the master device in each forward pass.
105 |         Messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 | 
109 |         Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 | 
113 |         Returns: the message to be sent back to the master device.
114 | 
115 |         """
116 |         self._activated = True
117 | 
118 |         intermediates = [(0, master_msg)]
119 |         for i in range(self.nr_slaves):
120 |             intermediates.append(self._queue.get())
121 | 
122 |         results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 | 
125 |         for i, res in results:
126 |             if i == 0:
127 |                 continue
128 |             self._registry[i].result.put(res)
129 | 
130 |         for i in range(self.nr_slaves):
131 |             assert self._queue.get() is True
132 | 
133 |         return results[0][1]
134 | 
135 |     @property
136 |     def nr_slaves(self):
137 |         return len(self._registry)
138 | 
--------------------------------------------------------------------------------
/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
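#
# How the pieces below fit together, in brief: DataParallelWithCallback first
# performs the usual per-GPU replication, then invokes
# `__data_parallel_replicate__(ctx, copy_id)` on every sub-module that defines
# it, master copy (copy_id == 0) first. A minimal sketch, where `MyModule` is a
# hypothetical module used only for illustration:
#
#   import torch.nn as nn
#
#   class MyModule(nn.Module):
#       def __data_parallel_replicate__(self, ctx, copy_id):
#           # `ctx` is shared by the copies of this sub-module on all devices,
#           # so copy 0 can publish state (e.g. a SyncMaster) for the others.
#           if copy_id == 0:
#               ctx.shared_state = ...
#
#   # net = DataParallelWithCallback(MyModule(), device_ids=[0, 1])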
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 |     pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | 
31 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 | 
33 |     Note that, as all copies of a module are isomorphic, we assign each sub-module a context
34 |     (shared among the multiple copies of this module on different devices).
35 |     Through this context, different copies can share some information.
36 | 
37 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
38 |     of any slave copies.
39 |     """
40 |     master_copy = modules[0]
41 |     nr_modules = len(list(master_copy.modules()))
42 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
43 | 
44 |     for i, module in enumerate(modules):
45 |         for j, m in enumerate(module.modules()):
46 |             if hasattr(m, '__data_parallel_replicate__'):
47 |                 m.__data_parallel_replicate__(ctxs[j], i)
48 | 
49 | 
50 | class DataParallelWithCallback(DataParallel):
51 |     """
52 |     Data Parallel with a replication callback.
53 | 
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after each copy is created by
55 |     the original `replicate` function.
56 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 | 
58 |     Examples:
59 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 |         # sync_bn.__data_parallel_replicate__ will be invoked.
62 |     """
63 | 
64 |     def replicate(self, module, device_ids):
65 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 |         execute_replication_callbacks(modules)
67 |         return modules
68 | 
69 | 
70 | def patch_replication_callback(data_parallel):
71 |     """
72 |     Monkey-patch an existing `DataParallel` object to add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 | 
75 |     Examples:
76 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 |         > patch_replication_callback(sync_bn)
79 |         # this is equivalent to
80 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 |     """
83 | 
84 |     assert isinstance(data_parallel, DataParallel)
85 | 
86 |     old_replicate = data_parallel.replicate
87 | 
88 |     @functools.wraps(old_replicate)
89 |     def new_replicate(module, device_ids):
90 |         modules = old_replicate(module, device_ids)
91 |         execute_replication_callbacks(modules)
92 |         return modules
93 | 
94 |     data_parallel.replicate = new_replicate
95 | 
--------------------------------------------------------------------------------
/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
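#
# Usage sketch (illustrative; `MyTest` is a made-up test case): subclass
# TorchTestCase, defined below, and compare tensors with assertTensorClose:
#
#   class MyTest(TorchTestCase):
#       def test_identity(self):
#           x = torch.randn(4)
#           self.assertTensorClose(x, x.clone())
#
#   # unittest.main() would then run MyTest like any other unittest.TestCase.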
10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /taskonomy_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image, ImageOps 4 | import os 5 | import os.path 6 | import zipfile as zf 7 | import io 8 | import logging 9 | import random 10 | import copy 11 | import numpy as np 12 | import time 13 | import torch 14 | 15 | import multiprocessing 16 | import warnings 17 | import torchvision.transforms as transforms 18 | 19 | from multiprocessing import Manager 20 | 21 | 22 | class TaskonomyLoader(data.Dataset): 23 | def __init__(self, 24 | root, 25 | label_set=['depth_zbuffer','normal','segment_semantic','edge_occlusion','reshading','keypoints2d','edge_texture'], 26 | model_whitelist=None, 27 | model_limit=None, 28 | output_size=None, 29 | convert_to_tensor=True, 30 | return_filename=False, 31 | half_sized_output=False, 32 | augment=False): 33 | manager=Manager() 34 | self.root = root 35 | self.model_limit=model_limit 36 | self.records=[] 37 | if model_whitelist is None: 38 | self.model_whitelist=None 39 | else: 40 | self.model_whitelist = set() 41 | with open(model_whitelist) as f: 42 | for line in f: 43 | self.model_whitelist.add(line.strip()) 44 | 45 | for i,(where, subdirs, files) in enumerate(os.walk(os.path.join(root,'rgb'))): 46 | if subdirs!=[]: continue 47 | model = where.split('/')[-1] 48 | if self.model_whitelist is None or model in self.model_whitelist: 49 | full_paths = [os.path.join(where,f) for f in files] 50 | if isinstance(model_limit, tuple): 51 | full_paths.sort() 52 | full_paths = full_paths[model_limit[0]:model_limit[1]] 53 | elif model_limit is not None: 54 | full_paths.sort() 55 | full_paths = full_paths[:model_limit] 56 | self.records+=full_paths 57 | 58 | #self.records = manager.list(self.records) 59 | self.label_set = label_set 60 | self.output_size = output_size 61 | self.half_sized_output=half_sized_output 62 | self.convert_to_tensor = convert_to_tensor 63 | self.return_filename=return_filename 64 | self.to_tensor = transforms.ToTensor() 65 | self.augment = augment 66 | 67 | if augment == "aggressive": 68 | print('Data augmentation is on (aggressive).') 69 | elif augment: 70 | print('Data augmentation is on (flip).') 71 | else: 72 | print('no data augmentation') 73 | self.last = {} 74 | 75 | def process_image(self,im,input=False): 76 | output_size=self.output_size 77 | if self.half_sized_output and not input: 78 | if output_size is None: 79 | output_size=(128,128) 80 | else: 81 | output_size=output_size[0]//2,output_size[1]//2 82 | if output_size is not None and output_size!=im.size: 83 | im = im.resize(output_size,Image.BILINEAR) 84 | 85 | bands = im.getbands() 86 | if self.convert_to_tensor: 87 | if bands[0]=='L': 88 | im = np.array(im) 89 | im.setflags(write=1) 90 | im = torch.from_numpy(im).unsqueeze(0) 91 | else: 92 | with warnings.catch_warnings(): 93 | warnings.simplefilter("ignore") 94 | im = self.to_tensor(im) 95 | 96 | return im 97 | 98 | def __getitem__(self, index): 99 | 
""" 100 | Args: 101 | index (int): Index 102 | 103 | Returns: 104 | tuple: (image, target) where target is an uint8 matrix of integers with the same width and height. 105 | If there is an error loading an image or its labels, simply return the previous example. 106 | """ 107 | with torch.no_grad(): 108 | file_name=self.records[index] 109 | save_filename = file_name 110 | 111 | flip_lr = (random.randint(0,1) > .5 and self.augment) 112 | 113 | flip_ud = (random.randint(0,1) > .5 and (self.augment=="aggressive")) 114 | 115 | 116 | 117 | 118 | pil_im = Image.open(file_name) 119 | 120 | if flip_lr: 121 | pil_im = ImageOps.mirror(pil_im) 122 | if flip_ud: 123 | pil_im = ImageOps.flip(pil_im) 124 | 125 | im = self.process_image(pil_im,input=True) 126 | 127 | error=False 128 | 129 | ys = {} 130 | mask = None 131 | to_load = self.label_set 132 | if len(set(['edge_occlusion','normal','reshading','principal_curvature']).intersection(self.label_set))!=0: 133 | if os.path.isfile(file_name.replace('rgb','mask')): 134 | to_load.append('mask') 135 | elif 'depth_zbuffer' not in to_load: 136 | to_load.append('depth_zbuffer') 137 | 138 | for i in to_load: 139 | if i=='mask' and mask is not None: 140 | continue 141 | 142 | yfilename = file_name.replace('rgb',i) 143 | try: 144 | yim = Image.open(yfilename) 145 | except: 146 | yim = self.last[i].copy() 147 | error = True 148 | if (i in self.last and yim.getbands() != self.last[i].getbands()) or error: 149 | yim = self.last[i].copy() 150 | try: 151 | self.last[i]=yim.copy() 152 | except: 153 | pass 154 | if flip_lr: 155 | try: 156 | yim = ImageOps.mirror(yim) 157 | except: 158 | pass 159 | if flip_ud: 160 | try: 161 | yim = ImageOps.flip(yim) 162 | except: 163 | pass 164 | try: 165 | yim = self.process_image(yim) 166 | except: 167 | yim = self.last[i].copy() 168 | yim = self.process_image(yim) 169 | 170 | 171 | if i == 'depth_zbuffer': 172 | yim = yim.float() 173 | mask = yim < (2**13) 174 | yim-=1500.0 175 | yim/= 1000.0 176 | elif i == 'edge_occlusion': 177 | yim = yim.float() 178 | yim-=56.0248 179 | yim/=239.1265 180 | elif i == 'keypoints2d': 181 | yim = yim.float() 182 | yim-=50.0 183 | yim/=100.0 184 | elif i == 'edge_texture': 185 | yim = yim.float() 186 | yim-=718.0 187 | yim/=1070.0 188 | elif i == 'normal': 189 | yim = yim.float() 190 | yim -=.5 191 | yim *=2.0 192 | if flip_lr: 193 | yim[0]*=-1.0 194 | if flip_ud: 195 | yim[1]*=-1.0 196 | elif i == 'reshading': 197 | yim=yim.mean(dim=0,keepdim=True) 198 | yim-=.4962 199 | yim/=0.2846 200 | #print('reshading',yim.shape,yim.max(),yim.min()) 201 | elif i == 'principal_curvature': 202 | yim=yim[:2] 203 | yim-=torch.tensor([0.5175, 0.4987]).view(2,1,1) 204 | yim/=torch.tensor([0.1373, 0.0359]).view(2,1,1) 205 | #print('principal_curvature',yim.shape,yim.max(),yim.min()) 206 | elif i == 'mask': 207 | mask=yim.bool() 208 | yim=mask 209 | 210 | ys[i] = yim 211 | 212 | 213 | 214 | if mask is not None: 215 | ys['mask']=mask 216 | 217 | # print(self.label_set) 218 | # print('rgb' in self.label_set) 219 | if not 'rgb' in self.label_set: 220 | ys['rgb']=im 221 | 222 | if self.return_filename: 223 | return im, ys, file_name 224 | else: 225 | return im, ys 226 | 227 | 228 | def __len__(self): 229 | return (len(self.records)) 230 | 231 | def show(im, ys): 232 | from matplotlib import pyplot as plt 233 | plt.figure(figsize=(30,30)) 234 | plt.subplot(4,3,1).set_title('RGB') 235 | im = im.permute([1,2,0]) 236 | plt.imshow(im) 237 | #print(im) 238 | #print(ys) 239 | for i, y in enumerate(ys): 240 | yim=ys[y] 241 | 
plt.subplot(4,3,2+i).set_title(y) 242 | if y=='normal': 243 | yim+=1 244 | yim/=2 245 | if yim.shape[0]==2: 246 | yim = torch.cat([yim,torch.zeros((1,yim.shape[1],yim.shape[2]))],dim=0) 247 | yim = yim.permute([1,2,0]) 248 | yim = yim.squeeze() 249 | plt.imshow(np.array(yim)) 250 | 251 | 252 | plt.show() 253 | 254 | def test(): 255 | loader = TaskonomyLoader( 256 | '/home/tstand/Desktop/lite_taskonomy/', 257 | label_set=['normal','reshading','principal_curvature','edge_occlusion','depth_zbuffer'], 258 | augment='aggressive') 259 | 260 | totals= {} 261 | totals2 = {} 262 | count = {} 263 | indices= list(range(len(loader))) 264 | random.shuffle(indices) 265 | for data_count, index in enumerate(indices): 266 | im, ys=loader[index] 267 | show(im,ys) 268 | mask = ys['mask'] 269 | #mask = ~mask 270 | print(index) 271 | for i, y in enumerate(ys): 272 | yim=ys[y] 273 | yim = yim.float() 274 | if y not in totals: 275 | totals[y]=0 276 | totals2[y]=0 277 | count[y]=0 278 | 279 | totals[y]+=(yim*mask).sum(dim=[1,2]) 280 | totals2[y]+=((yim**2)*mask).sum(dim=[1,2]) 281 | count[y]+=(torch.ones_like(yim)*mask).sum(dim=[1,2]) 282 | 283 | #print(y,yim.shape) 284 | std = torch.sqrt((totals2[y]-(totals[y]**2)/count[y])/count[y]) 285 | print(data_count,'/',len(loader),y,'mean:',totals[y]/count[y],'std:',std) 286 | 287 | def output_mask(index,loader): 288 | from matplotlib import pyplot as plt 289 | filename=loader.records[index] 290 | filename=filename.replace('rgb','mask') 291 | filename=filename.replace('/intel_nvme/taskonomy_data/','/run/shm/') 292 | if os.path.isfile(filename): 293 | return 294 | 295 | 296 | print(filename) 297 | 298 | 299 | x,ys = loader[index] 300 | 301 | mask =ys['mask'] 302 | mask=mask.squeeze() 303 | mask_im=Image.fromarray(mask.numpy()) 304 | mask_im = mask_im.convert(mode='1') 305 | # plt.subplot(2,1,1) 306 | # plt.imshow(mask) 307 | # plt.subplot(2,1,2) 308 | # plt.imshow(mask_im) 309 | # plt.show() 310 | path, _ = os.path.split(filename) 311 | os.makedirs(path,exist_ok=True) 312 | mask_im.save(filename,bits=1,optimize=True) 313 | 314 | 315 | 316 | 317 | 318 | 319 | def get_masks(): 320 | import multiprocessing 321 | 322 | 323 | loader = TaskonomyLoader( 324 | '/intel_nvme/taskonomy_data/', 325 | label_set=['depth_zbuffer'], 326 | augment=False) 327 | 328 | indices= list(range(len(loader))) 329 | 330 | random.shuffle(indices) 331 | 332 | 333 | for count,index in enumerate(indices): 334 | print(count,len(indices)) 335 | output_mask(index,loader) 336 | 337 | 338 | 339 | 340 | 341 | if __name__ == "__main__": 342 | test() 343 | #get_masks() 344 | -------------------------------------------------------------------------------- /taskonomy_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | 4 | sl=0 5 | nl=0 6 | nl2=0 7 | nl3=0 8 | dl=0 9 | el=0 10 | rl=0 11 | kl=0 12 | tl=0 13 | al=0 14 | cl=0 15 | popular_offsets=collections.defaultdict(int) 16 | batch_number=0 17 | 18 | def segment_semantic_loss(output,target,mask): 19 | global sl 20 | sl = torch.nn.functional.cross_entropy(output.float(),target.long().squeeze(dim=1),ignore_index=0,reduction='mean') 21 | return sl 22 | 23 | def normal_loss(output,target,mask): 24 | global nl 25 | nl= rotate_loss(output,target,mask,normal_loss_base) 26 | return nl 27 | 28 | def normal_loss_simple(output,target,mask): 29 | global nl 30 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 31 | out *=mask.float() 32 | nl = out.mean() 33 | return nl 34 | 35 
| def rotate_loss(output,target,mask,loss_name): 36 | global popular_offsets 37 | target=target[:,:,1:-1,1:-1].float() 38 | mask = mask[:,:,1:-1,1:-1].float() 39 | output=output.float() 40 | val1 = loss = loss_name(output[:,:,1:-1,1:-1],target,mask) 41 | 42 | val2 = loss_name(output[:,:,0:-2,1:-1],target,mask) 43 | loss = torch.min(loss,val2) 44 | val3 = loss_name(output[:,:,1:-1,0:-2],target,mask) 45 | loss = torch.min(loss,val3) 46 | val4 = loss_name(output[:,:,2:,1:-1],target,mask) 47 | loss = torch.min(loss,val4) 48 | val5 = loss_name(output[:,:,1:-1,2:],target,mask) 49 | loss = torch.min(loss,val5) 50 | val6 = loss_name(output[:,:,0:-2,0:-2],target,mask) 51 | loss = torch.min(loss,val6) 52 | val7 = loss_name(output[:,:,2:,2:],target,mask) 53 | loss = torch.min(loss,val7) 54 | val8 = loss_name(output[:,:,0:-2,2:],target,mask) 55 | loss = torch.min(loss,val8) 56 | val9 = loss_name(output[:,:,2:,0:-2],target,mask) 57 | loss = torch.min(loss,val9) 58 | 59 | #lst = [val1,val2,val3,val4,val5,val6,val7,val8,val9] 60 | 61 | #print(loss.size()) 62 | loss=loss.mean() 63 | #print(loss) 64 | return loss 65 | 66 | 67 | def normal_loss_base(output,target,mask): 68 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 69 | out *=mask 70 | out = out.mean(dim=(1,2,3)) 71 | return out 72 | 73 | def normal2_loss(output,target,mask): 74 | global nl3 75 | diff = output.float() - target.float() 76 | out = torch.abs(diff) 77 | out = out*mask.float() 78 | nl3 = out.mean() 79 | return nl3 80 | 81 | def depth_loss_simple(output,target,mask): 82 | global dl 83 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 84 | out *=mask.float() 85 | dl = out.mean() 86 | return dl 87 | 88 | def depth_loss(output,target,mask): 89 | global dl 90 | dl = rotate_loss(output,target,mask,depth_loss_base) 91 | return dl 92 | 93 | def depth_loss_base(output,target,mask): 94 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 95 | out *=mask.float() 96 | out = out.mean(dim=(1,2,3)) 97 | return out 98 | 99 | def edge_loss_simple(output,target,mask): 100 | global el 101 | 102 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 103 | out *=mask 104 | el = out.mean() 105 | return el 106 | 107 | def reshade_loss(output,target,mask): 108 | global rl 109 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 110 | out *=mask 111 | rl = out.mean() 112 | return rl 113 | 114 | def keypoints2d_loss(output,target,mask): 115 | global kl 116 | kl = torch.nn.functional.l1_loss(output,target) 117 | return kl 118 | 119 | def edge2d_loss(output,target,mask): 120 | global tl 121 | tl = torch.nn.functional.l1_loss(output,target) 122 | return tl 123 | 124 | def auto_loss(output,target,mask): 125 | global al 126 | al = torch.nn.functional.l1_loss(output,target) 127 | return al 128 | 129 | def pc_loss(output,target,mask): 130 | global cl 131 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 132 | out *=mask 133 | cl = out.mean() 134 | return cl 135 | 136 | def edge_loss(output,target,mask): 137 | global el 138 | out = torch.nn.functional.l1_loss(output,target,reduction='none') 139 | out *=mask 140 | el = out.mean() 141 | return el 142 | 143 | 144 | def get_taskonomy_loss(losses): 145 | def taskonomy_loss(output,target): 146 | if 'mask' in target: 147 | mask = target['mask'] 148 | else: 149 | mask=None 150 | 151 | sum_loss=None 152 | num=0 153 | for n,t in target.items(): 154 | if n in losses: 155 | o = output[n].float() 156 | this_loss = losses[n](o,t,mask) 157 | 
num+=1 158 | if sum_loss: 159 | sum_loss = sum_loss+ this_loss 160 | else: 161 | sum_loss = this_loss 162 | 163 | return sum_loss#/num # should not take average when using xception_taskonomy_new 164 | return taskonomy_loss 165 | 166 | 167 | def get_losses_and_tasks(args): 168 | task_str = args.tasks 169 | losses = {} 170 | criteria = {} 171 | taskonomy_tasks = [] 172 | 173 | if 's' in task_str: 174 | losses['segment_semantic'] = segment_semantic_loss 175 | criteria['ss_l']=lambda x,y : sl 176 | taskonomy_tasks.append('segment_semantic') 177 | if 'd' in task_str: 178 | if not args.rotate_loss: 179 | losses['depth_zbuffer'] = depth_loss_simple 180 | else: 181 | print('got rotate loss') 182 | losses['depth_zbuffer'] = depth_loss 183 | criteria['depth_l']=lambda x,y : dl 184 | taskonomy_tasks.append('depth_zbuffer') 185 | 186 | if 'n' in task_str: 187 | if not args.rotate_loss: 188 | losses['normal']=normal_loss_simple 189 | else: 190 | print('got rotate loss') 191 | losses['normal']=normal_loss 192 | criteria['norm_l']=lambda x,y : nl 193 | #criteria['norm_l2']=lambda x,y : nl2 194 | taskonomy_tasks.append('normal') 195 | if 'N' in task_str: 196 | losses['normal2']=normal2_loss 197 | criteria['norm2']=lambda x,y : nl3 198 | taskonomy_tasks.append('normal2') 199 | if 'k' in task_str: 200 | losses['keypoints2d']=keypoints2d_loss 201 | criteria['key_l']=lambda x,y : kl 202 | taskonomy_tasks.append('keypoints2d') 203 | if 'e' in task_str: 204 | if not args.rotate_loss: 205 | losses['edge_occlusion'] = edge_loss_simple 206 | else: 207 | print('got rotate loss') 208 | losses['edge_occlusion'] = edge_loss 209 | #losses['edge_occlusion']=edge_loss 210 | criteria['edge_l']=lambda x,y : el 211 | taskonomy_tasks.append('edge_occlusion') 212 | if 'r' in task_str: 213 | losses['reshading']=reshade_loss 214 | criteria['shade_l']=lambda x,y : rl 215 | taskonomy_tasks.append('reshading') 216 | if 't' in task_str: 217 | losses['edge_texture']=edge2d_loss 218 | criteria['edge2d_l']=lambda x,y : tl 219 | taskonomy_tasks.append('edge_texture') 220 | if 'a' in task_str: 221 | losses['rgb']=auto_loss 222 | criteria['rgb_l']=lambda x,y : al 223 | taskonomy_tasks.append('rgb') 224 | if 'c' in task_str: 225 | losses['principal_curvature']=pc_loss 226 | criteria['pc_l']=lambda x,y : cl 227 | taskonomy_tasks.append('principal_curvature') 228 | 229 | #"nacre" 230 | 231 | if args.task_weights: 232 | weights=[float(x) for x in args.task_weights.split(',')] 233 | losses2={} 234 | criteria2={} 235 | 236 | 237 | for l,w,c in zip(losses.items(),weights,criteria.items()): 238 | losses[l[0]]=lambda x,y,z,l=l[1],w=w:l(x,y,z)*w 239 | criteria[c[0]]=lambda x,y,c=c[1],w=w:c(x,y)*w 240 | 241 | taskonomy_loss = get_taskonomy_loss(losses) 242 | return taskonomy_loss,losses, criteria, taskonomy_tasks 243 | -------------------------------------------------------------------------------- /train_models.txt: -------------------------------------------------------------------------------- 1 | hallettsville 2 | kingfisher 3 | seeley 4 | martinville 5 | macksville 6 | pamelia 7 | browntown 8 | grace 9 | whiteriver 10 | mullica 11 | haaswood 12 | samuels 13 | eastville 14 | calavo 15 | hurley 16 | cullison 17 | lluveras 18 | brevort 19 | biltmore 20 | rabbit 21 | hiteman 22 | ribera 23 | byers 24 | waukeenah 25 | jenners 26 | quantico 27 | arona 28 | corozal 29 | cottonport 30 | murchison 31 | bellemeade 32 | cashel 33 | connellsville 34 | mcewen 35 | gloria 36 | experiment 37 | waldenburg 38 | umpqua 39 | goodview 40 | archer 41 | silva 42 | 
sisters 43 | roeville 44 | mentasta 45 | coeburn 46 | musicks 47 | aloha 48 | elton 49 | ballou 50 | gluck 51 | goodfield 52 | mogote 53 | adairsville 54 | clarkridge 55 | holcut 56 | mammoth 57 | gluek 58 | landing 59 | fishersville 60 | haymarket 61 | ohoopee 62 | crookston 63 | highspire 64 | montreal 65 | merlin 66 | cauthron 67 | rogue 68 | castor 69 | redbank 70 | sugarville 71 | ihlen 72 | rutherford 73 | sanctuary 74 | airport 75 | broadwell 76 | silerton 77 | eagan 78 | ackermanville 79 | maiden 80 | lindberg 81 | lindenwood 82 | eudora 83 | mckeesport 84 | creede 85 | pinesdale 86 | matoaca 87 | goodyear 88 | pettigrew 89 | funkstown 90 | stokes 91 | marland 92 | applewold 93 | uvalda 94 | athens 95 | lovilia 96 | mobridge 97 | ossipee 98 | torrington 99 | lathrup 100 | mesic 101 | castroville 102 | seatonville 103 | hatfield 104 | roxboro 105 | dunmor 106 | tilghmanton 107 | tyler 108 | codell 109 | model 110 | deemston 111 | hometown 112 | edgemere 113 | dansville 114 | helton 115 | foyil 116 | tysons 117 | pittsburg 118 | pleasant 119 | cayuse 120 | grassy 121 | wyldwood 122 | rosenberg 123 | frierson 124 | nicut 125 | mashulaville 126 | peacock 127 | carpendale 128 | scandinavia 129 | leilani 130 | broseley 131 | country 132 | divide 133 | wilkinsburg 134 | lessley 135 | muleshoe 136 | kettle 137 | frontenac 138 | kronborg 139 | kihei 140 | tomales 141 | pocasset 142 | baneberry 143 | siren 144 | beechwood 145 | timberon 146 | adrian 147 | fonda 148 | darden 149 | merom 150 | stockman 151 | milford 152 | culbertson 153 | collierville 154 | bellwood 155 | frankfort 156 | fitchburg 157 | random 158 | reyno 159 | auburn 160 | oriole 161 | poyen 162 | swisshome 163 | hordville 164 | wattsville 165 | mogadore 166 | shumway 167 | cisne 168 | elmira 169 | maunawili 170 | wappingers 171 | merchantville 172 | pablo 173 | swormville 174 | poipu 175 | kildare 176 | neshkoro 177 | benicia 178 | soldier 179 | cobalt 180 | grigston 181 | hercules 182 | eagerville 183 | #gough 184 | kevin 185 | grangeville 186 | akiak 187 | arkansaw 188 | beach 189 | mazomanie 190 | greigsville 191 | newcomb 192 | kendall 193 | yadkinville 194 | #wiconisco 195 | sumas 196 | helix 197 | markleeville 198 | maguayo 199 | tippecanoe 200 | pocopson 201 | islandton 202 | glenmoor 203 | oyens 204 | glassboro 205 | chireno 206 | maryhill 207 | readsboro 208 | macland 209 | waucousta 210 | grantsville 211 | kobuk 212 | seward 213 | albertville 214 | barahona 215 | mayesville 216 | bonfield 217 | maben 218 | milaca 219 | sunshine 220 | graceville 221 | melstone 222 | gasburg 223 | hammon 224 | superior 225 | stockwell 226 | hainesburg 227 | irvine 228 | shingler 229 | pearce 230 | spencerville 231 | okabena 232 | orangeburg 233 | touhy 234 | macarthur 235 | inkom 236 | calmar 237 | bautista 238 | lynxville 239 | darrtown 240 | reserve 241 | coffeen 242 | cousins 243 | cooperstown 244 | hendrix 245 | halfway 246 | gratz 247 | gastonia 248 | sodaville 249 | ooltewah 250 | sawpit 251 | pasatiempo 252 | orason 253 | ballantine 254 | kopperl 255 | chesterbrook 256 | fredericksburg 257 | losantville 258 | vails 259 | trail 260 | liddieville 261 | roane 262 | waimea 263 | noxapater 264 | neibert 265 | sontag 266 | weleetka 267 | churchton 268 | lineville 269 | brewton 270 | ellaville 271 | kirksville 272 | rockport 273 | duarte 274 | almena 275 | wakeman 276 | lynchburg 277 | dauberville 278 | natural 279 | marksville 280 | whitethorn 281 | goodwine 282 | macedon 283 | delton 284 | kinney 285 | michiana 286 | blenheim 287 | 
bremerton 288 | hartline 289 | colebrook 290 | winfield 291 | arbutus 292 | rough 293 | onaga 294 | sasakwa 295 | branford 296 | stilwell 297 | munsons 298 | ancor 299 | mosinee 300 | victorville 301 | caruthers 302 | fleming 303 | barboursville 304 | ewell 305 | wyatt 306 | annona 307 | woonsocket 308 | voorhees 309 | bowlus 310 | ladue 311 | bolton 312 | shauck 313 | checotah 314 | shelbiana 315 | edson 316 | vacherie 317 | bonnie 318 | yankeetown 319 | wells 320 | ewansville 321 | purple 322 | windhorst 323 | espanola 324 | anthoston 325 | gravelly 326 | hambleton 327 | westfield 328 | paige 329 | mifflinburg 330 | destin 331 | blackstone 332 | azusa 333 | tallmadge 334 | chilhowie 335 | globe 336 | avonia 337 | jacobus 338 | crandon 339 | wesley 340 | aldrich 341 | winooski 342 | mifflintown 343 | retsof 344 | tariffville 345 | chiloquin 346 | ovalo 347 | warrenville 348 | callicoon 349 | bountiful 350 | spotswood 351 | wilseyville 352 | sagerton 353 | monson 354 | manassas 355 | winthrop 356 | ashport 357 | hillsdale 358 | braxton 359 | cosmos 360 | belpre 361 | willow 362 | dalcour 363 | westerville 364 | euharlee 365 | maricopa 366 | cutlerville 367 | sarcoxie 368 | cohoes 369 | woodbine 370 | ludlowville 371 | kemblesville 372 | northgate 373 | dryville 374 | clairton 375 | mcdade 376 | artois 377 | kinde 378 | mcclure 379 | portal 380 | gladstone 381 | harrellsville 382 | nemacolin 383 | circleville 384 | andover 385 | moberly 386 | texasville 387 | smoketown 388 | peden 389 | norvelt 390 | wilbraham 391 | bowmore 392 | booth 393 | annawan 394 | alfred 395 | keiser 396 | hildebran 397 | clive 398 | hortense 399 | morris 400 | silas 401 | dedham 402 | schoolcraft 403 | bonesteel 404 | aldine 405 | american 406 | scioto 407 | cokeville 408 | kathryn 409 | hacienda 410 | herricks 411 | shelbyville 412 | benevolence 413 | parole 414 | assinippi 415 | idanha 416 | portola 417 | forkland 418 | bohemia 419 | mccloud 420 | imbery 421 | brentsville 422 | apache 423 | rosser 424 | potterville 425 | peconic 426 | cochranton 427 | anaheim 428 | uncertain 429 | nuevo 430 | cason 431 | angiola 432 | yscloskey 433 | corder 434 | almota 435 | lenoir 436 | newfields 437 | chrisney 438 | ranchester 439 | hitchland 440 | bertram 441 | carneiro 442 | haxtun 443 | laupahoehoe 444 | hominy 445 | ogilvie 446 | rancocas 447 | spread 448 | potosi 449 | ruckersville 450 | thrall 451 | gaylord 452 | mosquito 453 | burien 454 | monticello 455 | noonday 456 | waipahu 457 | germfask 458 | mentmore 459 | sussex 460 | german 461 | howie 462 | sundown 463 | hanson 464 | plumerville 465 | badger 466 | nimmons 467 | galatia 468 | terrell 469 | kremlin 470 | alstown 471 | laytonsville 472 | lindsborg 473 | seiling 474 | allensville 475 | wilkesboro 476 | sweatman 477 | cebolla 478 | marstons 479 | tradewinds 480 | micanopy 481 | barranquitas 482 | shellsburg -------------------------------------------------------------------------------- /train_taskonomy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import platform 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | import torch.nn as nn 10 | import torch.backends.cudnn as cudnn 11 | import torchvision.datasets as datasets 12 | 13 | 14 | from taskonomy_losses import * 15 | from taskonomy_loader import TaskonomyLoader 16 | 17 | 18 | from apex.parallel import DistributedDataParallel as DDP 19 | from apex.fp16_utils import * 20 | from apex import 
amp, optimizers
21 | import copy
22 | import numpy as np
23 | import signal
24 | import sys
25 | import math
26 | from collections import defaultdict
27 | import scipy.stats
28 | 
29 | #from ptflops import get_model_complexity_info
30 | 
31 | import model_definitions as models
32 | 
33 | model_names = sorted(name for name in models.__dict__
34 |                      if name.islower() and not name.startswith("__")
35 |                      and callable(models.__dict__[name]))
36 | 
37 | 
38 | parser = argparse.ArgumentParser(description='PyTorch Taskonomy Training')
39 | parser.add_argument('--data_dir', '-d', dest='data_dir', required=True,
40 |                     help='path to training set')
41 | parser.add_argument('--arch', '-a', metavar='ARCH', required=True,
42 |                     choices=model_names,
43 |                     help='model architecture: ' +
44 |                     ' | '.join(model_names) +
45 |                     ' (required)')
46 | parser.add_argument('-b', '--batch-size', default=64, type=int,
47 |                     help='mini-batch size (default: 64)')
48 | parser.add_argument('--tasks', '-ts', default='sdnkt', dest='tasks',
49 |                     help='which tasks to train on')
50 | parser.add_argument('--model_dir', default='saved_models', dest='model_dir',
51 |                     help='where to save models')
52 | parser.add_argument('--image-size', default=256, type=int,
53 |                     help='size of image side (images are square)')
54 | parser.add_argument('-j', '--workers', default=4, type=int,
55 |                     help='number of data loading workers (default: 4)')
56 | parser.add_argument('-pf', '--print_frequency', default=1, type=int,
57 |                     help='how often to print output')
58 | parser.add_argument('--epochs', default=100, type=int,
59 |                     help='maximum number of epochs to run')
60 | parser.add_argument('-mlr', '--minimum_learning_rate', default=3e-5, type=float,
61 |                     metavar='LR', help='End training when learning rate falls below this value.')
62 | 
63 | parser.add_argument('-lr', '--learning-rate', dest='lr', default=0.1, type=float,
64 |                     metavar='LR', help='initial learning rate')
65 | parser.add_argument('-ltw0', '--loss_tracking_window_initial', default=500000, type=int,
66 |                     help='initial loss tracking window (default: 500000)')
67 | parser.add_argument('-mltw', '--maximum_loss_tracking_window', default=2000000, type=int,
68 |                     help='maximum loss tracking window (default: 2000000)')
69 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
70 |                     help='momentum')
71 | parser.add_argument('--weight-decay', '-wd', '--wd', default=1e-4, type=float,
72 |                     metavar='W', help='weight decay (default: 1e-4)')
73 | parser.add_argument('--resume', '--restart', default='', type=str, metavar='PATH',
74 |                     help='path to latest checkpoint (default: none)')
75 | # parser.add_argument('--start-epoch', default=0, type=int,
76 | #                     help='manual epoch number (useful on restarts)')
77 | parser.add_argument('-n', '--experiment_name', default='', type=str,
78 |                     help='name to prepend to experiment saves.')
79 | parser.add_argument('-v', '--validate', dest='validate', action='store_true',
80 |                     help='evaluate model on validation set')
81 | parser.add_argument('-t', '--test', dest='test', action='store_true',
82 |                     help='evaluate model on test set')
83 | 
84 | parser.add_argument('-r', '--rotate_loss', dest='rotate_loss', action='store_true',
85 |                     help='use the rotating (shift-tolerant) loss')
86 | parser.add_argument('--pretrained', dest='pretrained', default='',
87 |                     help='use pre-trained model')
88 | parser.add_argument('-vb', '--virtual-batch-multiplier', default=1, type=int,
89 |                     metavar='N', help='number of forward/backward passes per parameter update')
90 | parser.add_argument('--fp16',
action='store_true',
91 |                     help='run model in fp16 mode.')
92 | parser.add_argument('-sbn', '--sync_batch_norm', action='store_true',
93 |                     help='sync batch norm parameters across gpus.')
94 | parser.add_argument('-hs', '--half_sized_output', action='store_true',
95 |                     help='output 128x128 rather than 256x256.')
96 | parser.add_argument('-na', '--no_augment', action='store_true',
97 |                     help='disable data augmentation.')
98 | parser.add_argument('-ml', '--model-limit', default=None, type=int,
99 |                     help='Limit the number of training instances from a single 3d building model.')
100 | parser.add_argument('-tw', '--task-weights', default=None, type=str,
101 |                     help='a comma separated list of numbers, one for each task, to multiply the loss by.')
102 | 
103 | cudnn.benchmark = False
104 | 
105 | def main(args):
106 |     print(args)
107 |     print('starting on', platform.node())
108 |     if 'CUDA_VISIBLE_DEVICES' in os.environ:
109 |         print('cuda gpus:', os.environ['CUDA_VISIBLE_DEVICES'])
110 | 
111 |     main_stream = torch.cuda.Stream()
112 | 
113 |     if args.fp16:
114 |         assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
115 |         print('Got fp16!')
116 | 
117 |     taskonomy_loss, losses, criteria, taskonomy_tasks = get_losses_and_tasks(args)
118 | 
119 |     print("including the following tasks:", list(losses.keys()))
120 | 
121 |     criteria2 = {'Loss': taskonomy_loss}
122 |     for key, value in criteria.items():
123 |         criteria2[key] = value
124 |     criteria = criteria2
125 | 
126 |     print('data_dir =', args.data_dir, len(args.data_dir))
127 | 
128 |     if args.no_augment:
129 |         augment = False
130 |     else:
131 |         augment = True
132 |     train_dataset = TaskonomyLoader(
133 |         args.data_dir,
134 |         label_set=taskonomy_tasks,
135 |         model_whitelist='train_models.txt',
136 |         model_limit=args.model_limit,
137 |         output_size=(args.image_size, args.image_size),
138 |         half_sized_output=args.half_sized_output,
139 |         augment=augment)
140 | 
141 |     print('Found', len(train_dataset), 'training instances.')
142 | 
143 |     print("=> creating model '{}'".format(args.arch))
144 |     model = models.__dict__[args.arch](tasks=losses.keys(), half_sized_output=args.half_sized_output)
145 | 
146 |     def get_n_params(model):
147 |         pp = 0
148 |         for p in list(model.parameters()):
149 |             #print(p.size())
150 |             nn = 1
151 |             for s in list(p.size()):
152 | 
153 |                 nn = nn * s
154 |             pp += nn
155 |         return pp
156 | 
157 |     print("Model has", get_n_params(model), "parameters")
158 |     try:
159 |         print("Encoder has", get_n_params(model.encoder), "parameters")
160 |         #flops, params=get_model_complexity_info(model.encoder,(3,256,256), as_strings=False, print_per_layer_stat=False)
161 |         #print("Encoder has", flops, "Flops and", params, "parameters,")
162 |     except Exception:
163 |         print("Each encoder has", get_n_params(model.encoders[0]), "parameters")
164 |     for decoder in model.task_to_decoder.values():
165 |         print("Decoder has", get_n_params(decoder), "parameters")
166 | 
167 |     model = model.cuda()
168 | 
169 | 
170 |     optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
171 | 
172 |     #tested with adamW. Poor results observed
173 |     #optimizer = adamW.AdamW(model.parameters(),lr= args.lr,weight_decay=args.weight_decay,eps=1e-3)
174 | 
175 | 
176 |     # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
177 |     # for convenient interoperation with argparse.
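    # For reference, the amp recipe used below and in train_batch() follows the
    # standard apex pattern: initialize once, then wrap each backward pass,
    #
    #   model, optimizer = amp.initialize(model, optimizer, opt_level='O1',
    #                                     loss_scale='dynamic')
    #   ...
    #   with amp.scale_loss(loss, optimizer) as scaled_loss:
    #       scaled_loss.backward()
    #
    # 'O1' leaves model weights in fp32 and casts whitelisted ops to fp16;
    # dynamic loss scaling guards against fp16 gradient underflow.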
178 | if args.fp16: 179 | model, optimizer = amp.initialize(model, optimizer, 180 | opt_level='O1', 181 | loss_scale="dynamic", 182 | verbosity=0 183 | ) 184 | print('Got fp16!') 185 | 186 | #args.lr = args.lr*float(args.batch_size*args.virtual_batch_multiplier)/256. 187 | 188 | # optionally resume from a checkpoint 189 | checkpoint=None 190 | if args.resume: 191 | if os.path.isfile(args.resume): 192 | print("=> loading checkpoint '{}'".format(args.resume)) 193 | checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda()) 194 | model.load_state_dict(checkpoint['state_dict']) 195 | print("=> loaded checkpoint '{}' (epoch {})" 196 | .format(args.resume, checkpoint['epoch'])) 197 | else: 198 | print("=> no checkpoint found at '{}'".format(args.resume)) 199 | 200 | 201 | 202 | 203 | 204 | if args.pretrained != '': 205 | print('loading pretrained weights for '+args.arch+' ('+args.pretrained+')') 206 | model.encoder.load_state_dict(torch.load(args.pretrained)) 207 | 208 | 209 | if torch.cuda.device_count() >1: 210 | model = torch.nn.DataParallel(model).cuda() 211 | if args.sync_batch_norm: 212 | from sync_batchnorm import patch_replication_callback 213 | patch_replication_callback(model) 214 | 215 | print('Virtual batch size =', args.batch_size*args.virtual_batch_multiplier) 216 | 217 | if args.resume: 218 | if os.path.isfile(args.resume) and 'optimizer' in checkpoint: 219 | optimizer.load_state_dict(checkpoint['optimizer']) 220 | 221 | train_loader = torch.utils.data.DataLoader( 222 | train_dataset, batch_size=args.batch_size, shuffle=True, 223 | num_workers=args.workers, pin_memory=True, sampler=None) 224 | 225 | val_loader = get_eval_loader(args.data_dir, taskonomy_tasks, args) 226 | 227 | trainer=Trainer(train_loader,val_loader,model,optimizer,criteria,args,checkpoint) 228 | if args.validate: 229 | trainer.progress_table=[] 230 | trainer.validate([{}]) 231 | print() 232 | return 233 | 234 | 235 | if args.test: 236 | trainer.progress_table=[] 237 | # replace val loader with a loader that loads test data 238 | trainer.val_loader=get_eval_loader(args.data_dir, taskonomy_tasks, args,model_limit=(1000,2000)) 239 | trainer.validate([{}]) 240 | return 241 | 242 | trainer.train() 243 | 244 | 245 | def get_eval_loader(datadir, label_set, args,model_limit=1000): 246 | print(datadir) 247 | 248 | val_dataset = TaskonomyLoader(datadir, 249 | label_set=label_set, 250 | model_whitelist='val_models.txt', 251 | model_limit=model_limit, 252 | output_size = (args.image_size,args.image_size), 253 | half_sized_output=args.half_sized_output, 254 | augment=False) 255 | print('Found',len(val_dataset),'validation instances.') 256 | 257 | val_loader = torch.utils.data.DataLoader( 258 | val_dataset, 259 | batch_size=max(args.batch_size//2,1), shuffle=False, 260 | num_workers=args.workers, pin_memory=True,sampler=None) 261 | return val_loader 262 | 263 | program_start_time = time.time() 264 | 265 | def on_keyboared_interrupt(x,y): 266 | #print() 267 | sys.exit(1) 268 | signal.signal(signal.SIGINT, on_keyboared_interrupt) 269 | 270 | def get_average_learning_rate(optimizer): 271 | try: 272 | return optimizer.learning_rate 273 | except: 274 | s = 0 275 | for param_group in optimizer.param_groups: 276 | s+=param_group['lr'] 277 | return s/len(optimizer.param_groups) 278 | 279 | class data_prefetcher(): 280 | def __init__(self, loader): 281 | self.inital_loader = loader 282 | self.loader = iter(loader) 283 | self.stream = torch.cuda.Stream() 284 | self.preload() 285 | 286 | def preload(self): 
287 | try: 288 | self.next_input, self.next_target = next(self.loader) 289 | except StopIteration: 290 | # self.next_input = None 291 | # self.next_target = None 292 | self.loader = iter(self.inital_loader) 293 | self.preload() 294 | return 295 | with torch.cuda.stream(self.stream): 296 | self.next_input = self.next_input.cuda(non_blocking=True) 297 | #self.next_target = self.next_target.cuda(async=True) 298 | self.next_target = {key: val.cuda(non_blocking=True) for (key,val) in self.next_target.items()} 299 | 300 | def next(self): 301 | torch.cuda.current_stream().wait_stream(self.stream) 302 | input = self.next_input 303 | target = self.next_target 304 | self.preload() 305 | return input, target 306 | 307 | class color: 308 | PURPLE = '\033[95m' 309 | CYAN = '\033[96m' 310 | DARKCYAN = '\033[36m' 311 | BLUE = '\033[94m' 312 | GREEN = '\033[92m' 313 | YELLOW = '\033[93m' 314 | RED = '\033[91m' 315 | BOLD = '\033[1m' 316 | UNDERLINE = '\033[4m' 317 | END = '\033[0m' 318 | 319 | 320 | def print_table(table_list, go_back=True): 321 | if len(table_list)==0: 322 | print() 323 | print() 324 | return 325 | if go_back: 326 | print("\033[F",end='') 327 | print("\033[K",end='') 328 | for i in range(len(table_list)): 329 | print("\033[F",end='') 330 | print("\033[K",end='') 331 | 332 | 333 | lens = defaultdict(int) 334 | for i in table_list: 335 | for ii,to_print in enumerate(i): 336 | for title,val in to_print.items(): 337 | lens[(title,ii)]=max(lens[(title,ii)],max(len(title),len(val))) 338 | 339 | 340 | # printed_table_list_header = [] 341 | for ii,to_print in enumerate(table_list[0]): 342 | for title,val in to_print.items(): 343 | 344 | print('{0:^{1}}'.format(title,lens[(title,ii)]),end=" ") 345 | for i in table_list: 346 | print() 347 | for ii,to_print in enumerate(i): 348 | for title,val in to_print.items(): 349 | print('{0:^{1}}'.format(val,lens[(title,ii)]),end=" ",flush=True) 350 | print() 351 | 352 | 353 | class AverageMeter(object): 354 | """Computes and stores the average and current value""" 355 | def __init__(self): 356 | self.reset() 357 | 358 | def reset(self): 359 | self.val = 0 360 | self.avg = 0 361 | self.std= 0 362 | self.sum = 0 363 | self.sumsq = 0 364 | self.count = 0 365 | self.lst = [] 366 | 367 | def update(self, val, n=1): 368 | self.val = float(val) 369 | self.sum += float(val) * n 370 | #self.sumsq += float(val)**2 371 | self.count += n 372 | self.avg = self.sum / self.count 373 | self.lst.append(self.val) 374 | self.std=np.std(self.lst) 375 | 376 | 377 | class Trainer: 378 | def __init__(self,train_loader,val_loader,model,optimizer,criteria,args,checkpoint=None): 379 | self.train_loader=train_loader 380 | self.val_loader=val_loader 381 | self.train_prefetcher=data_prefetcher(self.train_loader) 382 | self.model=model 383 | self.optimizer=optimizer 384 | self.criteria=criteria 385 | self.args = args 386 | self.fp16=args.fp16 387 | self.code_archive=self.get_code_archive() 388 | if checkpoint: 389 | if 'progress_table' in checkpoint: 390 | self.progress_table = checkpoint['progress_table'] 391 | else: 392 | self.progress_table=[] 393 | if 'epoch' in checkpoint: 394 | self.start_epoch = checkpoint['epoch']+1 395 | else: 396 | self.start_epoch = 0 397 | if 'best_loss' in checkpoint: 398 | self.best_loss = checkpoint['best_loss'] 399 | else: 400 | self.best_loss = 9e9 401 | if 'stats' in checkpoint: 402 | self.stats = checkpoint['stats'] 403 | else: 404 | self.stats=[] 405 | if 'loss_history' in checkpoint: 406 | self.loss_history = checkpoint['loss_history'] 407 | else: 
408 | self.loss_history=[] 409 | else: 410 | self.progress_table=[] 411 | self.best_loss = 9e9 412 | self.stats = [] 413 | self.start_epoch = 0 414 | self.loss_history=[] 415 | 416 | self.lr0 = get_average_learning_rate(optimizer) 417 | 418 | print_table(self.progress_table,False) 419 | self.ticks=0 420 | self.last_tick=0 421 | self.loss_tracking_window = args.loss_tracking_window_initial 422 | 423 | def get_code_archive(self): 424 | file_contents={} 425 | for i in os.listdir('.'): 426 | if i[-3:]=='.py': 427 | with open(i,'r') as file: 428 | file_contents[i]=file.read() 429 | return file_contents 430 | 431 | def train(self): 432 | for self.epoch in range(self.start_epoch,self.args.epochs): 433 | current_learning_rate = get_average_learning_rate(self.optimizer) 434 | if current_learning_rate < self.args.minimum_learning_rate: 435 | break 436 | # train for one epoch 437 | train_string, train_stats = self.train_epoch() 438 | 439 | # evaluate on validation set 440 | progress_string=train_string 441 | loss, progress_string, val_stats = self.validate(progress_string) 442 | print() 443 | 444 | self.progress_table.append(progress_string) 445 | 446 | self.stats.append((train_stats,val_stats)) 447 | self.checkpoint(loss) 448 | 449 | def checkpoint(self, loss): 450 | is_best = loss < self.best_loss 451 | self.best_loss = min(loss, self.best_loss) 452 | save_filename = self.args.experiment_name+'_'+self.args.arch+'_'+('p' if self.args.pretrained != '' else 'np')+'_'+self.args.tasks+'_checkpoint.pth.tar' 453 | 454 | try: 455 | to_save = self.model 456 | if torch.cuda.device_count() >1: 457 | to_save=to_save.module 458 | gpus='all' 459 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 460 | gpus=os.environ['CUDA_VISIBLE_DEVICES'] 461 | self.save_checkpoint({ 462 | 'epoch': self.epoch, 463 | 'info':{'machine':platform.node(), 'GPUS':gpus}, 464 | 'args': self.args, 465 | 'arch': self.args.arch, 466 | 'state_dict': to_save.state_dict(), 467 | 'best_loss': self.best_loss, 468 | 'optimizer' : self.optimizer.state_dict(), 469 | 'progress_table' : self.progress_table, 470 | 'stats': self.stats, 471 | 'loss_history': self.loss_history, 472 | 'code_archive':self.code_archive 473 | }, False, self.args.model_dir, save_filename) 474 | 475 | if is_best: 476 | self.save_checkpoint(None, True,self.args.model_dir, save_filename) 477 | except: 478 | print('save checkpoint failed...') 479 | 480 | 481 | 482 | def save_checkpoint(self,state, is_best,directory='', filename='checkpoint.pth.tar'): 483 | path = os.path.join(directory,filename) 484 | if is_best: 485 | best_path = os.path.join(directory,'best_'+filename) 486 | shutil.copyfile(path, best_path) 487 | else: 488 | torch.save(state, path) 489 | 490 | def learning_rate_schedule(self): 491 | ttest_p=0 492 | z_diff=0 493 | 494 | #don't reduce learning rate until the second epoch has ended 495 | if self.epoch < 2: 496 | return 0,0 497 | 498 | wind=self.loss_tracking_window//(self.args.batch_size*args.virtual_batch_multiplier) 499 | if len(self.loss_history)-self.last_tick > wind: 500 | a = self.loss_history[-wind:-wind*5//8] 501 | b = self.loss_history[-wind*3//8:] 502 | #remove outliers 503 | a = sorted(a) 504 | b = sorted(b) 505 | a = a[int(len(a)*.05):int(len(a)*.95)] 506 | b = b[int(len(b)*.05):int(len(b)*.95)] 507 | length_=min(len(a),len(b)) 508 | a=a[:length_] 509 | b=b[:length_] 510 | z_diff,ttest_p = scipy.stats.ttest_rel(a,b,nan_policy='omit') 511 | 512 | if z_diff < 0 or ttest_p > .99: 513 | self.ticks+=1 514 | self.last_tick=len(self.loss_history) 515 | 
self.adjust_learning_rate() 516 | self.loss_tracking_window = min(self.args.maximum_loss_tracking_window,self.loss_tracking_window*2) 517 | return ttest_p, z_diff 518 | 519 | def train_epoch(self): 520 | global program_start_time 521 | average_meters = defaultdict(AverageMeter) 522 | display_values = [] 523 | for name,func in self.criteria.items(): 524 | display_values.append(name) 525 | 526 | # switch to train mode 527 | self.model.train() 528 | 529 | end = time.time() 530 | epoch_start_time = time.time() 531 | epoch_start_time2=time.time() 532 | 533 | batch_num = 0 534 | num_data_points=len(self.train_loader)//self.args.virtual_batch_multiplier 535 | if num_data_points > 10000: 536 | num_data_points = num_data_points//5 537 | 538 | starting_learning_rate=get_average_learning_rate(self.optimizer) 539 | while True: 540 | if batch_num ==0: 541 | end=time.time() 542 | epoch_start_time2=time.time() 543 | if num_data_points==batch_num: 544 | break 545 | self.percent = batch_num/num_data_points 546 | loss_dict=None 547 | loss=0 548 | 549 | # accumulate gradients over multiple runs of input 550 | for _ in range(self.args.virtual_batch_multiplier): 551 | data_start = time.time() 552 | input, target = self.train_prefetcher.next() 553 | average_meters['data_time'].update(time.time() - data_start) 554 | loss_dict2,loss2 = self.train_batch(input,target) 555 | loss+=loss2 556 | if loss_dict is None: 557 | loss_dict=loss_dict2 558 | else: 559 | for key,value in loss_dict2.items(): 560 | loss_dict[key]+=value 561 | 562 | # divide by the number of accumulations 563 | loss/=self.args.virtual_batch_multiplier 564 | for key,value in loss_dict.items(): 565 | loss_dict[key]=value/self.args.virtual_batch_multiplier 566 | 567 | # do the weight updates and set gradients back to zero 568 | self.update() 569 | 570 | self.loss_history.append(float(loss)) 571 | ttest_p, z_diff = self.learning_rate_schedule() 572 | 573 | 574 | for name,value in loss_dict.items(): 575 | try: 576 | average_meters[name].update(value.data) 577 | except: 578 | average_meters[name].update(value) 579 | 580 | 581 | 582 | elapsed_time_for_epoch = (time.time()-epoch_start_time2) 583 | eta = (elapsed_time_for_epoch/(batch_num+.2))*(num_data_points-batch_num) 584 | if eta >= 24*3600: 585 | eta = 24*3600-1 586 | 587 | 588 | batch_num+=1 589 | current_learning_rate= get_average_learning_rate(self.optimizer) 590 | if True: 591 | 592 | to_print = {} 593 | to_print['ep']= ('{0}:').format(self.epoch) 594 | to_print['#/{0}'.format(num_data_points)]= ('{0}').format(batch_num) 595 | to_print['lr']= ('{0:0.3g}-{1:0.3g}').format(starting_learning_rate,current_learning_rate) 596 | to_print['eta']= ('{0}').format(time.strftime("%H:%M:%S", time.gmtime(int(eta)))) 597 | 598 | to_print['d%']=('{0:0.2g}').format(100*average_meters['data_time'].sum/elapsed_time_for_epoch) 599 | for name in display_values: 600 | meter = average_meters[name] 601 | to_print[name]= ('{meter.avg:.4g}').format(meter=meter) 602 | if batch_num < num_data_points-1: 603 | to_print['ETA']= ('{0}').format(time.strftime("%H:%M:%S", time.gmtime(int(eta+elapsed_time_for_epoch)))) 604 | to_print['ttest']= ('{0:0.3g},{1:0.3g}').format(z_diff,ttest_p) 605 | if batch_num % self.args.print_frequency == 0: 606 | print_table(self.progress_table+[[to_print]]) 607 | 608 | 609 | 610 | epoch_time = time.time()-epoch_start_time 611 | stats={'batches':num_data_points, 612 | 'learning_rate':current_learning_rate, 613 | 'Epoch time':epoch_time, 614 | } 615 | for name in display_values: 616 | meter = 
average_meters[name] 617 | stats[name] = meter.avg 618 | 619 | data_time = average_meters['data_time'].sum 620 | 621 | to_print['eta']= ('{0}').format(time.strftime("%H:%M:%S", time.gmtime(int(epoch_time)))) 622 | 623 | return [to_print], stats 624 | 625 | 626 | 627 | def train_batch(self, input, target): 628 | 629 | loss_dict = {} 630 | 631 | input = input.float() 632 | output = self.model(input) 633 | first_loss=None 634 | for c_name,criterion_fun in self.criteria.items(): 635 | if first_loss is None:first_loss=c_name 636 | loss_dict[c_name]=criterion_fun(output, target) 637 | 638 | loss = loss_dict[first_loss].clone() 639 | loss = loss / self.args.virtual_batch_multiplier 640 | 641 | if self.args.fp16: 642 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 643 | scaled_loss.backward() 644 | else: 645 | loss.backward() 646 | 647 | return loss_dict, loss 648 | 649 | 650 | def update(self): 651 | self.optimizer.step() 652 | self.optimizer.zero_grad() 653 | 654 | 655 | def validate(self, train_table): 656 | average_meters = defaultdict(AverageMeter) 657 | self.model.eval() 658 | epoch_start_time = time.time() 659 | batch_num=0 660 | num_data_points=len(self.val_loader) 661 | 662 | prefetcher = data_prefetcher(self.val_loader) 663 | torch.cuda.empty_cache() 664 | with torch.no_grad(): 665 | for i in range(len(self.val_loader)): 666 | input, target = prefetcher.next() 667 | 668 | 669 | if batch_num ==0: 670 | epoch_start_time2=time.time() 671 | 672 | output = self.model(input) 673 | 674 | 675 | loss_dict = {} 676 | 677 | for c_name,criterion_fun in self.criteria.items(): 678 | loss_dict[c_name]=criterion_fun(output, target) 679 | 680 | batch_num=i+1 681 | 682 | for name,value in loss_dict.items(): 683 | try: 684 | average_meters[name].update(value.data) 685 | except: 686 | average_meters[name].update(value) 687 | eta = ((time.time()-epoch_start_time2)/(batch_num+.2))*(len(self.val_loader)-batch_num) 688 | 689 | to_print = {} 690 | to_print['#/{0}'.format(num_data_points)]= ('{0}').format(batch_num) 691 | to_print['eta']= ('{0}').format(time.strftime("%H:%M:%S", time.gmtime(int(eta)))) 692 | for name in self.criteria.keys(): 693 | meter = average_meters[name] 694 | to_print[name]= ('{meter.avg:.4g}').format(meter=meter) 695 | progress=train_table+[to_print] 696 | if batch_num % self.args.print_frequency == 0: 697 | print_table(self.progress_table+[progress]) 698 | 699 | epoch_time = time.time()-epoch_start_time 700 | 701 | stats={'batches':len(self.val_loader), 702 | 'Epoch time':epoch_time, 703 | } 704 | ultimate_loss = None 705 | for name in self.criteria.keys(): 706 | meter = average_meters[name] 707 | stats[name]=meter.avg 708 | ultimate_loss = stats['Loss'] 709 | to_print['eta']= ('{0}').format(time.strftime("%H:%M:%S", time.gmtime(int(epoch_time)))) 710 | torch.cuda.empty_cache() 711 | return float(ultimate_loss), progress , stats 712 | 713 | def adjust_learning_rate(self): 714 | self.lr = self.lr0 * (0.50 ** (self.ticks)) 715 | self.set_learning_rate(self.lr) 716 | 717 | def set_learning_rate(self,lr): 718 | for param_group in self.optimizer.param_groups: 719 | param_group['lr'] = lr 720 | 721 | if __name__ == '__main__': 722 | #mp.set_start_method('forkserver') 723 | args = parser.parse_args() 724 | main(args) 725 | -------------------------------------------------------------------------------- /val_models.txt: -------------------------------------------------------------------------------- 1 | wainscott 2 | hobson 3 | convoy 4 | bettendorf 5 | tokeland 6 | klickitat 7 | 
kirwin 8 | channel 9 | emmaus 10 | lajas 11 | plessis 12 | lakeville 13 | kerrtown 14 | maugansville 15 | springerville 16 | stanleyville 17 | southfield 18 | wando 19 | jennie 20 | starks 21 | mcnary 22 | aulander 23 | denmark 24 | leavittsburg 25 | mahtomedi 26 | goffs 27 | maida 28 | tolstoy 29 | pomaria 30 | kankakee 31 | hornsby 32 | kangley 33 | bethlehem 34 | carpio 35 | everton 36 | copemish 37 | freedom 38 | tansboro 39 | lucan 40 | cantwell 41 | sands 42 | springhill 43 | donaldson 44 | kingdom 45 | deatsville 46 | capistrano 47 | leonardo 48 | gilbert 49 | cabin 50 | placida 51 | millbury 52 | cornville --------------------------------------------------------------------------------
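A note on the shift-tolerant ("rotating") loss enabled by the -r flag: rotate_loss in taskonomy_losses.py evaluates the masked per-example loss at the centered alignment and at all eight one-pixel shifts of the prediction, and keeps the minimum, so a prediction that is off by a single pixel is barely penalized. A minimal self-contained sketch of the same idea, assuming a plain dense L1 objective (the function name and shapes below are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def shift_tolerant_l1(output, target, mask):
    # Crop a one-pixel border off the label and mask; every candidate
    # alignment of the prediction is compared against this same crop.
    target = target[:, :, 1:-1, 1:-1].float()
    mask = mask[:, :, 1:-1, 1:-1].float()
    h, w = target.shape[2], target.shape[3]

    def base(o):
        # Masked L1, reduced to one loss value per example.
        return (F.l1_loss(o, target, reduction='none') * mask).mean(dim=(1, 2, 3))

    best = base(output[:, :, 1:-1, 1:-1])  # centered alignment
    for dr in (0, 1, 2):
        for dc in (0, 1, 2):
            if dr == 1 and dc == 1:
                continue  # the centered alignment was already evaluated
            best = torch.min(best, base(output[:, :, dr:dr + h, dc:dc + w]))
    return best.mean()

# Example: a prediction that is translated by one pixel incurs almost no loss.
# loss = shift_tolerant_l1(torch.randn(2, 1, 16, 16),
#                          torch.randn(2, 1, 16, 16),
#                          torch.ones(2, 1, 16, 16))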