├── Fig_2 ├── Fig_2a_to_f.m ├── Fig_2g_h.m ├── Fig_2i_j.ipynb ├── Fig_2k_l.m ├── Fig_2m_n.ipynb └── task │ ├── TSN.py │ └── __init__.py ├── Fig_3 ├── Fig_3a_to_d_f.m └── Fig_3e.m ├── Fig_4 └── Fig_4d_to_g.m ├── Fig_5 └── Fig_5_e_f.m ├── README.md └── Supplements ├── Fig_S1_KNN.m ├── Fig_S2_SNR_alpha.m ├── Fig_S3_correlated_biased.ipynb ├── Fig_S4 ├── Fig_MLE_Speed_VS_GroundT.m ├── N_100.mat └── N_10000.mat ├── Fig_S5 ├── CIFAR-10.ipynb ├── MNIST.ipynb ├── Tiny_imagenet.ipynb ├── models │ ├── __init__.py │ ├── densenet.py │ ├── dla.py │ ├── dla_simple.py │ ├── dpn.py │ ├── efficientnet.py │ ├── googlenet.py │ ├── lenet.py │ ├── mobilenet.py │ ├── mobilenetv2.py │ ├── pnasnet.py │ ├── preact_resnet.py │ ├── regnet.py │ ├── resnet.py │ ├── resnext.py │ ├── senet.py │ ├── shufflenet.py │ ├── shufflenetv2.py │ └── vgg.py └── pytorchtools.py ├── Fig_S6i_j.m └── Fig_S6l.m /Fig_2/Fig_2a_to_f.m: -------------------------------------------------------------------------------- 1 | % Main code for Student-Teacher-Notebook framework 2 | % Weinan Sun 10-10-2021 3 | 4 | tic 5 | close all 6 | clear all 7 | 8 | r_n = 20; % number of repeats 9 | nepoch = 2000; 10 | learnrate = 0.015; 11 | N_x_t = 100; % teacher input dimension 12 | N_y_t = 1; % teacher output dimension 13 | P=100; % number of training examples 14 | P_test = 1000; % number of testing examples 15 | 16 | % According to SNR, set variances for teacher's weights (variance_w) and output noise (variance_e) that sum to 1 17 | SNR = inf; 18 | 19 | if SNR == inf 20 | variance_w = 1; 21 | variance_e = 0; 22 | else 23 | variance_w = SNR/(SNR + 1); 24 | variance_e = 1/(SNR + 1); 25 | end 26 | 27 | % Student and teacher share the same dimensions 28 | N_x_s = N_x_t; 29 | N_y_s = N_y_t; 30 | 31 | % Notebook parameters 32 | % see Buhmann, Divko, and Schulten, 1989 for details regarding gamma and U terms 33 | 34 | M = 2000; % num of units in notebook 35 | a = 0.05; % notebook sparseness 36 | gamma = 0.6; % inhibtion parameter 37 | U = -0.15; % threshold for unit activation 38 | ncycle = 9; % number of recurrent cycles 39 | 40 | 41 | % Matrices for storing train error, test error, reactivation error (driven by notebook) 42 | % Without early stopping 43 | train_error_all = zeros(r_n,nepoch); % student train error 44 | test_error_all = zeros(r_n,nepoch); % student test error 45 | N_train_error_all = zeros(r_n,nepoch); % notebook train error 46 | N_test_error_all = zeros(r_n,nepoch); % notebook test error 47 | 48 | % With early stopping 49 | train_error_early_stop_all = zeros(r_n,nepoch); 50 | test_error_early_stop_all = zeros(r_n,nepoch); 51 | 52 | %Run simulation for r_n times 53 | for r = 1:r_n 54 | disp(r) 55 | rng(r); %set random seed for reproducibility 56 | 57 | %Errors 58 | error_train_vector = zeros(nepoch,1); 59 | error_test_vector = zeros(nepoch,1); 60 | error_react_vector = zeros(nepoch,1); 61 | 62 | %% Teacher Network 63 | W_t = normrnd(0,variance_w^0.5,[N_x_t,N_y_t]); % set teacher's weights with variance_w 64 | noise_train = normrnd(0,variance_e^0.5,[P,N_y_t]); % set the variance for label noise 65 | % Training data 66 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); % inputs 67 | y_t_output = x_t_input*W_t + noise_train; % outputs 68 | 69 | % Testing data 70 | noise_test = normrnd(0,variance_e^0.5,[P_test,N_y_t]); 71 | x_t_input_test = normrnd(0,(1/N_x_t)^0.5,[P_test,N_x_t]); 72 | y_t_output_test = x_t_input_test*W_t + noise_test; 73 | 74 | %% Notebook Network 75 | % Generate P random binary indices (0 or 1) with sparseness a 76 | N_patterns = 
zeros(P,M); 77 | for n=1:P 78 | N_patterns(n,randperm(M,M*a))=1; 79 | end 80 | 81 | %Hebbian learning for notebook recurrent weights 82 | W_N = (N_patterns - a)'*(N_patterns - a)/(M*a*(1-a)); 83 | W_N = W_N - gamma/(a*M);% add global inhibiton term, see Buhmann, Divko, and Schulten, 1989 84 | W_N = W_N.*~eye(size(W_N)); % no self connection 85 | 86 | % Hebbian learning for Notebook-Student weights (bidirectional) 87 | 88 | % Notebook to student weights, for reactivating student 89 | W_N_S_Lin = (N_patterns-a)'*x_t_input/(M*a*(1-a)); 90 | W_N_S_Lout = (N_patterns-a)'*y_t_output/(M*a*(1-a)); 91 | % Student to notebook weights, for providing partial cues 92 | W_S_N_Lin = x_t_input'*(N_patterns-a); 93 | W_S_N_Lout = y_t_output'*(N_patterns-a); 94 | 95 | %% Student Network 96 | W_s = normrnd(0,0^0.5,[N_x_s,N_y_s]); % set students weights, with zero weight initialization 97 | 98 | %% Generate offline training data from notebook reactivations 99 | N_patterns_reactivated = zeros(P,M,nepoch,'logical'); % array for storing retrieved notebook patterns, pre-calculating all epochs for speed considerations 100 | 101 | parfor m = 1:nepoch 102 | %for m = 1:nepoch % regular for-loop 103 | 104 | %% Notebook pattern completion through recurrent dynamis 105 | % Code below simulates hippocampal offline spontanenous 106 | % reactivations by seeding the initial notebook state with a random 107 | % binary pattern, then notebook goes through a two-step retrieval 108 | % process: (1) Retrieving a pattern using dynamic threshold to 109 | % ensure a pattern with sparseness a is retrieved (otherwise a silent 110 | % attractor will dominate retrieval). (2) Using the 111 | % retrieved pattern from (1) to seed a second round of pattern 112 | % completion using a fixed-threshold method (along with a global 113 | % inhibition term during encoding), so the retrieved patterns are 114 | % not forced to have a fixed sparseness, in addition, there is a 115 | % "silent state" attractor when the seeding pattern lies far away 116 | % from any of the encoded patterns. 117 | 118 | % Start recurrent cycles with dynamic threshold 119 | Activity_dyn_t = zeros(P, M); 120 | 121 | % First round of pattern completion through recurrent activtion cycles given 122 | % random initial input. 
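% In each cycle below, the random seed pattern is clamped only on the first
% cycle; afterwards the notebook runs freely on its recurrent currents. Per
% pattern, the currents M_input*W_N are min-max rescaled to [0,1], and the
% activation threshold is set to the floor(M*a)-th largest current, so roughly
% a fraction a of units is active after every cycle (a k-winners-take-all step).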
123 | for cycle = 1:ncycle 124 | if cycle <=1 125 | clamp = 1; 126 | else 127 | clamp = 0; 128 | end 129 | rand_patt = (rand(P,M)<=a); % random seeding activity 130 | % Seeding notebook with random patterns 131 | M_input = Activity_dyn_t + (rand_patt*clamp); 132 | % Seeding notebook with original patterns 133 | % M_input = Activity_dyn_t + (N_patterns*clamp); 134 | M_current = M_input*W_N; 135 | % scale currents between 0 and 1 136 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2)); 137 | M_current = (M_current - min(M_current,[],2)) .* scale; 138 | % find threshold based on desired sparseness 139 | sorted_M_current = sort(M_current,2,'descend'); 140 | t_ind = floor(size(Activity_dyn_t,2) * a); 141 | t_ind(t_ind<1) = 1; 142 | t = sorted_M_current(:,t_ind); % threshold for unit activations 143 | Activity_dyn_t = (M_current >=t); 144 | end 145 | 146 | % Second round of pattern completion, with fix threshold 147 | Activity_fix_t = zeros(P, M); 148 | for cycle = 1:ncycle 149 | if cycle <=1 150 | clamp = 1; 151 | else 152 | clamp = 0; 153 | end 154 | M_input = Activity_fix_t + Activity_dyn_t*clamp; 155 | M_current = M_input*W_N; 156 | Activity_fix_t = (M_current >= U); % U is the fixed threshold 157 | end 158 | N_patterns_reactivated(:,:,m)=Activity_fix_t; 159 | end 160 | 161 | % Seeding notebook with original notebook patterns for calculating 162 | % training error mediated by notebook (seeding notebook with student 163 | % input via Student's input to Notebook weights, once pattern completion 164 | % finishes, use the retrieved pattern to activate Student's output unit 165 | % via Notebook to Student's output weights. 166 | 167 | Activity_notebook_train = zeros(P, M); 168 | for cycle = 1:ncycle 169 | if cycle <=1 170 | clamp = 1; 171 | else 172 | clamp = 0; 173 | end 174 | seed_patt = x_t_input*W_S_N_Lin; 175 | M_input = Activity_notebook_train + (seed_patt*clamp); 176 | M_current = M_input*W_N; 177 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2)); 178 | M_current = (M_current - min(M_current,[],2)) .* scale; 179 | sorted_M_current = sort(M_current,2,'descend'); 180 | t_ind = floor(size(Activity_notebook_train,2) * a); 181 | t_ind(t_ind<1) = 1; 182 | t = sorted_M_current(:,t_ind); 183 | Activity_notebook_train = (M_current >=t); 184 | end 185 | N_S_output_train = Activity_notebook_train*W_N_S_Lout; 186 | % Notebook training error 187 | delta_N_train = y_t_output - N_S_output_train; 188 | error_N_train = sum(delta_N_train.^2)/P; 189 | % Since notebook errors stay constant throughout training, 190 | % populating each epoch with the same error value 191 | error_N_train_vector = ones(nepoch,1)*error_N_train; 192 | N_train_error_all(r,:) = error_N_train_vector; 193 | 194 | % Notebook generalization error 195 | Activity_notebook_test = zeros(P_test, M); 196 | for cycle = 1:ncycle 197 | if cycle <=1 198 | clamp = 1; 199 | else 200 | clamp = 0; 201 | end 202 | seed_patt = x_t_input_test*W_S_N_Lin; 203 | M_input = Activity_notebook_test + (seed_patt*clamp); 204 | M_current = M_input*W_N; 205 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2)); 206 | M_current = (M_current - min(M_current,[],2)) .* scale; 207 | sorted_M_current = sort(M_current,2,'descend'); 208 | t_ind = floor(size(Activity_notebook_test,2) * a); 209 | t_ind(t_ind<1) = 1; 210 | t = sorted_M_current(:,t_ind); 211 | Activity_notebook_test = (M_current >=t); 212 | end 213 | N_S_output_test = Activity_notebook_test*W_N_S_Lout; 214 | % Notebook test error 215 | delta_N_test = y_t_output_test - 
N_S_output_test; 216 | error_N_test = sum(delta_N_test.^2)/P_test; 217 | % populating each epoch with the same error value 218 | error_N_test_vector = ones(nepoch,1)*error_N_test; 219 | N_test_error_all(r,:) = error_N_test_vector; 220 | 221 | 222 | % N_patterns_reactivated_test = zeros(P_test,M,'logical'); 223 | %% Student training through offline notebook reactivations at each epoch 224 | for m = 1:nepoch 225 | 226 | N_S_input = N_patterns_reactivated(:,:,m)*W_N_S_Lin; % notebook reactivated student input activity 227 | N_S_output = N_patterns_reactivated(:,:,m)*W_N_S_Lout; % notebook reactivated student output activity 228 | N_S_prediction = N_S_input*W_s; % student output prediction calculated by notebook reactivated input and student weights 229 | S_prediction = x_t_input*W_s; % student output prediction calculated by true training inputs and student weights 230 | S_prediction_test = x_t_input_test*W_s; % student output prediction calculated by true testing inputs and student weights 231 | 232 | % Train error 233 | delta_train = y_t_output - S_prediction; 234 | error_train = sum(delta_train.^2)/P; 235 | error_train_vector(m) = error_train; 236 | 237 | % Generalization error 238 | delta_test = y_t_output_test - S_prediction_test; 239 | error_test = sum(delta_test.^2)/P_test; 240 | error_test_vector(m) = error_test; 241 | 242 | % Gradient descent 243 | w_delta = N_S_input'*N_S_output - N_S_input'*N_S_input*W_s; 244 | W_s = W_s + learnrate*w_delta; 245 | end 246 | 247 | train_error_all(r,:) = error_train_vector; 248 | test_error_all(r,:) = error_test_vector; 249 | 250 | % Early stopping 251 | [min_v, min_p] = min(error_test_vector); 252 | train__error_early_stop = error_train_vector; 253 | train__error_early_stop (min_p+1:end) = error_train_vector (min_p); 254 | test_error_early_stop = error_test_vector; 255 | test_error_early_stop (min_p+1:end) = error_test_vector (min_p); 256 | train_error_early_stop_all(r,:) = train__error_early_stop; 257 | test_error_early_stop_all(r,:) = test_error_early_stop; 258 | end 259 | 260 | toc 261 | 262 | 263 | %% Plotting 264 | color_scheme = [137 152 193; 245 143 136]/255; 265 | line_w = 2; 266 | font_s = 12; 267 | 268 | % Without early stopping 269 | figure(1) 270 | hold on 271 | plot(1:nepoch,mean(train_error_all),'color',color_scheme(1,:),'LineWidth',line_w) 272 | plot(1:nepoch,mean(test_error_all),'color',color_scheme(2,:),'LineWidth',line_w) 273 | plot(1:nepoch,mean(N_train_error_all),'b--','LineWidth',2) 274 | plot(1:nepoch,mean(N_test_error_all),'r--','LineWidth',2) 275 | 276 | set(gca, 'FontSize', font_s) 277 | set(gca, 'FontSize', font_s) 278 | xlabel('Epoch','Color','k') 279 | ylabel('Error','Color','k') 280 | xlim([0 nepoch]) 281 | ylim([0 2]) 282 | set(gca,'linewidth',1) 283 | set(gcf,'position',[100,100,350,300]) 284 | 285 | % Save plot 286 | % saveas(gcf,strcat('Errors_No_ES_','SNR_',num2str(SNR),'_',date,'.pdf')); 287 | 288 | % With early stopping 289 | figure(2) 290 | hold on 291 | plot(1:nepoch,mean(train_error_early_stop_all),'color',color_scheme(1,:),'LineWidth',line_w) 292 | plot(1:nepoch,mean(test_error_early_stop_all),'color',color_scheme(2,:),'LineWidth',line_w) 293 | plot(1:nepoch,mean(N_train_error_all),'b--','LineWidth',line_w) 294 | plot(1:nepoch,mean(N_test_error_all),'r--','LineWidth',line_w) 295 | 296 | set(gca, 'FontSize', font_s) 297 | set(gca, 'FontSize', font_s) 298 | xlabel('Epoch','Color','k') 299 | ylabel('Error','Color','k') 300 | xlim([0 nepoch]) 301 | ylim([0 2]) 302 | set(gca,'linewidth',1) 303 | 
set(gcf,'position',[600,100,350,300]) 304 | 305 | % Save plot 306 | % saveas(gcf,strcat('Errors_ES_','SNR_',num2str(SNR),'_',date,'.pdf')); 307 | -------------------------------------------------------------------------------- /Fig_2/Fig_2g_h.m: -------------------------------------------------------------------------------- 1 | % code for Student-Teacher-Notebook framework,analytical curves 2 | % Weinan Sun 10-10-2021 3 | close all 4 | clear all 5 | 6 | nepoch = 2000; 7 | learnrate = 0.005; 8 | N_x_t = 100; 9 | N_y_t = 1; 10 | P=100; 11 | M = 5000; %num of units in notebook 12 | 13 | Et_big_early_stop = []; 14 | Eg_big_early_stop = []; 15 | Et_big = []; 16 | Eg_big = []; 17 | 18 | %SNR values sampled from log2 space 19 | SNR_log_interval = -4:0.5:4; 20 | SNR_vec =2.^SNR_log_interval; 21 | 22 | 23 | Notebook_Train = (P-1)/(M-1); % Analytial solution for notebook training error, see supplementary material for derivations. 24 | 25 | for count = 1:size(SNR_vec,2) 26 | disp(count) 27 | SNR = SNR_vec(count); 28 | Eg = []; 29 | Et = []; 30 | lr = learnrate; 31 | 32 | %use normalized variances 33 | if SNR == inf 34 | variance_w = 1; 35 | variance_e = 0; 36 | else 37 | variance_w = SNR/(SNR + 1); 38 | variance_e = 1/(SNR + 1); 39 | end 40 | 41 | alpha = P/N_x_t; % number of examples divided by input dimension 42 | 43 | % Analytical curves for training and testing errors, see supplementary 44 | % material for derivations 45 | 46 | for t = 1:1:nepoch 47 | 48 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 49 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 50 | 51 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 52 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 53 | 54 | end 55 | 56 | % Early stopping curves 57 | [mm, pp] = min(Eg); 58 | Eg_early_stop = Eg; 59 | Eg_early_stop(pp+1:end) = Eg(pp); 60 | 61 | Et_early_stop = Et; 62 | Et_early_stop(pp+1:end) = Et(pp); 63 | 64 | % Pick better training error value and convert to memory generalization 65 | % scores 66 | better_train_no_early_stop = min(Et,ones(1,nepoch)*Notebook_Train); 67 | control_curve = (Et(1) - better_train_no_early_stop)/Et(1); 68 | lesion_curve = (Et(1) - Et)/Et(1); 69 | 70 | 71 | better_train_yes_early_stop = min(Et_early_stop,ones(1,nepoch)*Notebook_Train); 72 | control_curve_early_stop = (Et_early_stop(1) - better_train_yes_early_stop)/Et_early_stop(1); 73 | lesion_curve_early_stop = (Et_early_stop(1) - Et_early_stop)/Et_early_stop(1); 74 | 75 | 76 | control_Eg_curve = (Eg(1) - Eg)/Eg(1); 77 | control_Eg_curve_early_stop = (Eg_early_stop(1) - Eg_early_stop)/Eg_early_stop(1); 78 | 79 | 80 | 81 | figure(1) 82 | 83 | hold on; 84 | plot(1:1:nepoch,Et,'-','color',[0 0 1 count/20],'LineWidth',2) 85 | plot(1:1:nepoch,Eg,'-','color',[1 0 0 count/20],'LineWidth',2) 86 | 87 | set(gca, 'FontSize', 12) 88 | set(gca, 'FontSize', 12) 89 | xlabel('Epoch') 90 | ylabel('Error') 91 | set(gca,'linewidth',1.5) 92 | ylim([0 2.5]) 93 | 94 | figure(2) 95 | 96 | hold on; 97 | plot(1:1:nepoch,Et_early_stop,'-','color',[0 0 1 count/20],'LineWidth',2) 98 | plot(1:1:nepoch,Eg_early_stop,'-','color',[1 0 0 count/20],'LineWidth',2) 99 | 100 | set(gca, 'FontSize',12) 101 | set(gca, 'FontSize', 
12) 102 | xlabel('Epoch') 103 | ylabel('Error') 104 | set(gca,'linewidth',1.5) 105 | ylim([0 2.5]) 106 | 107 | 108 | end 109 | 110 | %% save figures 111 | % figure(1) 112 | % set(gcf,'position',[100,100,350,290]) 113 | % saveas(gcf,strcat('Fig_2g','.pdf')); 114 | % figure(2) 115 | % set(gcf,'position',[100,100,350,290]) 116 | % saveas(gcf,strcat('Fig_2h','.pdf')); 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /Fig_2/Fig_2k_l.m: -------------------------------------------------------------------------------- 1 | %% Clear workspace and initialize random number generator 2 | 3 | clear all; 4 | RandStream.setGlobalStream(RandStream('mt19937ar','seed',sum(100*clock))); 5 | 6 | %% Construct an ensemble of teacher-generated student data. 7 | 8 | num_teachers = 1; 9 | N = 1000; 10 | P = 1000; 11 | std_X = sqrt(1/N); 12 | 13 | SNR_vec = -4:0.02:4; 14 | SNRs = 2.^SNR_vec; 15 | 16 | num_SNRs = length(SNRs); 17 | 18 | wbars = NaN(N,num_SNRs,num_teachers); 19 | Xs = NaN(P,N,num_SNRs,num_teachers); 20 | etas = NaN(P,num_SNRs,num_teachers); 21 | ys = NaN(P,num_SNRs,num_teachers); 22 | for s = 1:num_SNRs; 23 | S = SNRs(s); 24 | std_noise = sqrt(1/(1+S)); 25 | std_weights = sqrt(S/(1+S)); 26 | for t = 1:num_teachers; 27 | wbars(:,s,t) = std_weights*randn(N,1); 28 | 29 | Xs(:,:,s,t) = std_X*randn(P,N); 30 | etas(:,s,t) = std_noise*randn(P,1); 31 | ys(:,s,t) = Xs(:,:,s,t)*wbars(:,s,t)+etas(:,s,t); 32 | end; 33 | end; 34 | 35 | var_ys = squeeze(var(ys,0,1)); 36 | var_ws = squeeze(var(wbars,0,1)); 37 | var_noises = squeeze(var(etas,0,1)); 38 | empirical_SNRs = (var_ys-var_noises)./var_noises; 39 | 40 | mean_empirical_SNRs = mean(empirical_SNRs,2); 41 | std_empirical_SNRs = std(empirical_SNRs,0,2); 42 | 43 | %% Plot basic stuff 44 | 45 | %close all; 46 | 47 | % figure(1) 48 | % hold on; 49 | % errorbar(SNRs,mean_empirical_SNRs,std_empirical_SNRs,'LineWidth',3) 50 | % plot(SNRs, empirical_SNRs) 51 | % hold off; 52 | %set(gca,'XScale','log') 53 | %set(gca,'YScale','log') 54 | 55 | %% Evaluate the log-likelihood function for each dataset 56 | % % DON'T USE THIS VERSION FOR LARGE N OR P. 
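% % (For each of the num_step_SNR grid points, this direct version builds and
% % inverts a full P-by-P covariance matrix and takes its determinant, i.e.
% % roughly O(P^3) work per candidate SNR. The SVD-based version further down
% % factorizes the data matrix once per dataset, after which each candidate
% % SNR only requires a pass over the eigenvalues.)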
57 | % 58 | % tic; 59 | % min_SNR = 0.02; 60 | % max_SNR = 15; 61 | % step_SNR = 0.02; 62 | % SNR_vector = min_SNR:step_SNR:max_SNR; 63 | % num_step_SNR = length(SNR_vector); 64 | % 65 | % log_likelihood_function = NaN(num_step_SNR,num_SNRs,num_teachers); 66 | % max_log_likelihood = NaN(num_SNRs,num_teachers); 67 | % estimated_SNRs = NaN(num_SNRs,num_teachers); 68 | % for s = 1:num_SNRs; 69 | % strcat('On SNR:',num2str(s)) 70 | % toc 71 | % for t = 1:num_teachers; 72 | % for i = 1:num_step_SNR; 73 | % %inv_C_matrix = (1+SNR_vector(i))*(eye(P)-Xs(:,:,s,t)*inv(eye(N)/SNR_vector(i)+Xs(:,:,s,t)'*Xs(:,:,s,t))*Xs(:,:,s,t)'); 74 | % inv_C_matrix = (1+SNR_vector(i))*inv(eye(P)+SNR_vector(i)*Xs(:,:,s,t)*Xs(:,:,s,t)'); 75 | % log_likelihood_function(i,s,t) = 1/2*log(det(inv_C_matrix)) - 1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t); 76 | % %log_likelihood_function(i,s,t) = -1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t); 77 | % %log_likelihood_function(i,s,t) = log(det(inv_C_matrix)); 78 | % end; 79 | % [max_log_likelihood(s,t) tmp_idx] = max(log_likelihood_function(:,s,t)); 80 | % estimated_SNRs(s,t) = SNR_vector(tmp_idx); 81 | % end; 82 | % end; 83 | 84 | 85 | %% Plot the log_likelihood functions 86 | 87 | %close all; 88 | % 89 | % figure(2) 90 | % for s = 1:num_SNRs; 91 | % for t = 1:2; 92 | % subplot(num_SNRs,2,(s-1)*2+t) 93 | % hold on; 94 | % plot(SNR_vector,log_likelihood_function(:,s,t)); 95 | % plot(empirical_SNRs(s,t),max_log_likelihood(s,t),'ob') 96 | % plot(estimated_SNRs(s,t),max_log_likelihood(s,t),'or','LineWidth',2) 97 | % end; 98 | % end; 99 | % 100 | % figure(10) 101 | % plot(empirical_SNRs(:),estimated_SNRs(:),'ok') 102 | 103 | %% Do an SVD decomposition on the data matrix. 104 | 105 | 106 | Us = NaN(P,P,num_SNRs,num_teachers); 107 | Ss = NaN(P,N,num_SNRs,num_teachers); 108 | Vs = NaN(N,N,num_SNRs,num_teachers); 109 | Lambdas = NaN(P,P,num_SNRs,num_teachers); 110 | inv_Lambdas = NaN(P,P,num_SNRs,num_teachers); 111 | parfor s = 1:num_SNRs; 112 | strcat('On SNR:',num2str(s)) 113 | 114 | for t = 1:num_teachers; 115 | [Us(:,:,s,t) Ss(:,:,s,t) Vs(:,:,s,t)] = svd(Xs(:,:,s,t)); 116 | Lambdas(:,:,s,t) = Ss(:,:,s,t)*Ss(:,:,s,t)'; 117 | inv_Lambdas(:,:,s,t) = pinv(Lambdas(:,:,s,t)); 118 | end; 119 | end; 120 | 121 | %% Evaluate the log-likelihood function for each dataset using the SVD decomposition 122 | % More numerically efficient version. 
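% Under the Gaussian teacher model used here, y = X*wbar + eta is marginally
% Gaussian with covariance
%   C = std_weights^2*X*X' + std_noise^2*I = (SNR*X*X' + I)/(1+SNR),
% so inv(C) = (1+SNR)*inv(I + SNR*X*X').  With the SVD X = U*S*V', the matrix
% X*X' is diagonalized by U with eigenvalues Lambdas(j,j), and for the rotated
% data y_modes = U'*y the log-likelihood (up to an SNR-independent constant)
% separates over modes:
%   log L(SNR) = sum_j [ 0.5*log((1+SNR)/(1+SNR*Lambdas(j,j)))
%                        - 0.5*y_modes(j)^2*(1+SNR)/(1+SNR*Lambdas(j,j)) ],
% which is what A1 and A2 accumulate in the loop below; the maximizing grid
% value is reported as the estimated SNR.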
123 | 124 | tic; 125 | num_components = min(N,P); 126 | 127 | 128 | SNR_vector = 2.^(-6:0.02:6); 129 | num_step_SNR = length(SNR_vector); 130 | 131 | log_likelihood_function = NaN(num_step_SNR,num_SNRs,num_teachers); 132 | max_log_likelihood = NaN(num_SNRs,num_teachers); 133 | estimated_SNRs = NaN(num_SNRs,num_teachers); 134 | for s = 1:num_SNRs; 135 | strcat('On SNR:',num2str(s)) 136 | toc 137 | for t = 1:num_teachers; 138 | y_modes = Us(:,:,s,t)'*ys(:,s,t); 139 | for i = 1:num_step_SNR; 140 | inv_C_matrix_modes = (1+SNR_vector(i))*eye(P); 141 | A1 = (P-num_components)/2*log(1+SNR_vector(i)); 142 | A2 = 0; 143 | for j = 1:num_components; 144 | inv_C_matrix_modes(j,j) = (1+SNR_vector(i))/(1+SNR_vector(i)*Lambdas(j,j,s,t)); 145 | A1 = A1 + 1/2*log(inv_C_matrix_modes(j,j)); 146 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j); 147 | end; 148 | if gt(P,num_components) 149 | for j = 1:P; 150 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j); 151 | end; 152 | end; 153 | %inv_C_matrix_modes = (1+SNR_vector(i))*inv(eye(P)+SNR_vector(i)*Lambdas(:,:,s,t)); 154 | %log_likelihood_function(i,s,t) = 1/2*log(det(inv_C_matrix_modes)) - 1/2*y_modes'*inv_C_matrix_modes*y_modes; 155 | log_likelihood_function(i,s,t) = A1 + A2; 156 | %log_likelihood_function(i,s,t) = -1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t); 157 | %log_likelihood_function(i,s,t) = log(det(inv_C_matrix)); 158 | end; 159 | [max_log_likelihood(s,t) tmp_idx] = max(log_likelihood_function(:,s,t)); 160 | estimated_SNRs(s,t) = SNR_vector(tmp_idx); 161 | end; 162 | end; 163 | 164 | %% Plot the log_likelihood functions 165 | 166 | %close all; 167 | 168 | % figure(3) 169 | % for s = 1:num_SNRs; 170 | % for t = 1:2; 171 | % subplot(num_SNRs,2,(s-1)*2+t) 172 | % hold on; 173 | % 174 | % plot(SNR_vector,log_likelihood_function(:,s,t)); 175 | % 176 | % plot(SNR_vector(s),max_log_likelihood(s,t),'ob') 177 | % plot(estimated_SNRs(s,t),max_log_likelihood(s,t),'or','LineWidth',2) 178 | % end; 179 | % end; 180 | 181 | figure(400) 182 | plot(log2(SNRs),log2(estimated_SNRs(:,1)),'ok') 183 | 184 | 185 | set(gcf,'position',[500 500 300 300]) 186 | xlim([-6,6]) 187 | ylim([-6,6]) 188 | 189 | figure(500) 190 | hold on 191 | plot(SNR_vector,log_likelihood_function(:,end)) 192 | plot(SNRs(end),max_log_likelihood(end),'ob') 193 | plot(estimated_SNRs(end),max_log_likelihood(end),'or','LineWidth',2) 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /Fig_2/task/TSN.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | from torch import nn, optim 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class TSN(): 9 | 10 | def __init__(self, input_dim, output_dim, P, P_test, SNR, lr = 0.05, rand_seed = 101): 11 | self.N = input_dim 12 | self.output_dim = output_dim 13 | self.P = P 14 | self.P_test = P_test 15 | self.variance_w = snr_to_var(SNR)[0] 16 | self.variance_e = snr_to_var(SNR)[1] 17 | self.lr = lr 18 | torch.manual_seed(rand_seed) 19 | noise_train = torch.normal(0, self.variance_e**0.5, size = [self.P,1]) 20 | noise_test = torch.normal(0, self.variance_e**0.5, size = [self.P_test,1]) 21 | W_t = torch.normal(0, self.variance_w**0.5, size=(self.N, 1)) 22 | self.train_x = torch.normal(0, (1/self.N)**0.5, size=(self.P, self.N)) 23 | self.train_y = torch.matmul(self.train_x, W_t) + noise_train 24 | self.test_x = torch.normal(0, (1/self.N)**0.5, size=(self.P_test, self.N)) 25 | 
self.test_y = torch.matmul(self.test_x, W_t) + noise_test 26 | 27 | def training_loop(self, nepoch = 1000, reg_strength = 0): 28 | 29 | model = nn.Sequential(nn.Linear(self.N,self.output_dim,bias=False)) 30 | optimizer = optim.SGD(model.parameters(), lr=self.lr*(self.N/2), weight_decay = reg_strength) # to match the learnrate in the matlab implementation. 31 | optimizer.zero_grad() 32 | 33 | with torch.no_grad(): 34 | list(model.parameters())[0].zero_() 35 | 36 | Et = np.zeros((nepoch,1)) 37 | Eg = np.zeros((nepoch,1)) 38 | print_iter = 200 39 | 40 | for i in range(nepoch): 41 | optimizer.zero_grad() 42 | train_error = criterion(self.train_y, model(self.train_x)) 43 | train_error.backward() 44 | optimizer.step() 45 | Et[i] = train_error.detach().numpy() 46 | with torch.no_grad(): 47 | test_error = criterion(self.test_y, model(self.test_x)) 48 | Eg[i] = test_error.numpy() 49 | # if i%print_iter == 0: 50 | # print('Et '+ str(Et[i].item()),'Eg '+ str(Eg[i].item()),str(i*100/nepoch) + '% Finished') 51 | 52 | return Et, Eg 53 | 54 | class TSN_validation(): 55 | 56 | def __init__(self, input_dim, output_dim, P, P_test, SNR, lr = 0.05, rand_seed = 101): 57 | self.N = input_dim 58 | self.output_dim = output_dim 59 | self.P = P 60 | self.P_test = P_test 61 | self.variance_w = snr_to_var(SNR)[0] 62 | self.variance_e = snr_to_var(SNR)[1] 63 | self.lr = 0.05 64 | torch.manual_seed(rand_seed) 65 | noise_train = torch.normal(0, self.variance_e**0.5, size = [self.P,1]) 66 | noise_test = torch.normal(0, self.variance_e**0.5, size = [self.P_test,1]) 67 | W_t = torch.normal(0, self.variance_w**0.5, size=(self.N, 1)) 68 | self.train_x = torch.normal(0, (1/self.N)**0.5, size=(self.P, self.N)) 69 | self.train_y = torch.matmul(self.train_x, W_t) + noise_train 70 | self.test_x = torch.normal(0, (1/self.N)**0.5, size=(self.P_test, self.N)) 71 | self.test_y = torch.matmul(self.test_x, W_t) + noise_test 72 | 73 | def training_loop(self, nepoch = 2000, reg_strength = 1e-2): 74 | 75 | model = nn.Sequential(nn.Linear(self.N,self.output_dim,bias=False)) 76 | optimizer = optim.SGD(model.parameters(), lr=self.lr*(self.N/2), weight_decay= reg_strength) # to match the learnrate in the matlab implementation. 
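        # Note: this validation class hard-codes self.lr = 0.05 in __init__ and
        # ignores the `lr` argument, unlike TSN above.  The weights are zeroed just
        # below to mirror the zero-weight student initialization in the MATLAB
        # scripts (W_s = normrnd(0, 0^0.5, ...)).  The lr*(N/2) factor above is
        # presumably needed because `criterion` is a mean squared error (gradient
        # carries a 2/P factor) while the MATLAB update uses the raw sum
        # X'*(y - X*w); this reading assumes P = N as in the Fig. 2 setting.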
77 | optimizer.zero_grad() 78 | 79 | with torch.no_grad(): 80 | list(model.parameters())[0].zero_() 81 | 82 | Et = np.zeros((nepoch,1)) 83 | Eg = np.zeros((nepoch,1)) 84 | print_iter = 200 85 | 86 | for i in range(nepoch): 87 | optimizer.zero_grad() 88 | train_error = criterion(self.train_y, model(self.train_x)) 89 | train_error.backward() 90 | optimizer.step() 91 | Et[i] = train_error.detach().numpy() 92 | with torch.no_grad(): 93 | test_error = criterion(self.test_y, model(self.test_x)) 94 | Eg[i] = test_error.numpy() 95 | # if i%print_iter == 0: 96 | # print('Et '+ str(Et[i].item()),'Eg '+ str(Eg[i].item()),str(i*100/nepoch) + '% Finished') 97 | 98 | return Et, Eg 99 | 100 | def snr_to_var(SNR): 101 | if SNR == np.inf: 102 | variance_w = 1 103 | variance_e = 0 104 | else: 105 | variance_w = SNR/(SNR + 1) 106 | variance_e = 1/(SNR + 1) 107 | return variance_w, variance_e 108 | 109 | def criterion(y, y_hat): 110 | return (y - y_hat).pow(2).mean() 111 | -------------------------------------------------------------------------------- /Fig_2/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuroai/Go-CLS_v2/a5f90b73c61088f0a154414af05ea1e54122c06d/Fig_2/task/__init__.py -------------------------------------------------------------------------------- /Fig_3/Fig_3a_to_d_f.m: -------------------------------------------------------------------------------- 1 | % code for Student-Teacher-Notebook framework, analytical curves 2 | % Weinan Sun 10-10-2021 3 | close all 4 | clear all 5 | 6 | nepoch = 2000; 7 | learnrate = 0.005; 8 | N_x_t = 100; 9 | N_y_t = 1; 10 | P=100; 11 | M = 5000; %num of units in notebook 12 | 13 | Et_big_early_stop = []; 14 | Eg_big_early_stop = []; 15 | Et_big = []; 16 | Eg_big = []; 17 | 18 | Et_lesion_remote = []; 19 | Eg_lesion_remote = []; 20 | 21 | %SNR values sampled from log2 space 22 | SNR_log_interval = -4:0.5:4; 23 | SNR_vec =2.^SNR_log_interval; 24 | 25 | 26 | Notebook_Train = (P-1)/(M-1); % Analytial solution for notebook training error, see supplementary material for derivations. 
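% In the SNR loop below, Et and Eg are obtained by integrating over the
% Marchenko-Pastur eigenvalue spectrum of the training-input correlation
% matrix, supported on [(sqrt(alpha)-1)^2, (sqrt(alpha)+1)^2].  Each eigenmode
% lam relaxes exponentially at rate lam*lr: the training-error integrand
% weights (lam*variance_w + variance_e)*exp(-2*lam*lr*t), while the
% generalization-error integrand combines the decaying exp(-2*lam*lr*t) with a
% term (1-exp(-lam*lr*t))^2/(lam*SNR) that grows with training.  The extra
% (alpha<1)*(1-alpha) terms handle the underdetermined regime P < N_x_t.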
27 | 28 | for count = 1:size(SNR_vec,2) 29 | disp(count) 30 | SNR = SNR_vec(count); 31 | Eg = []; 32 | Et = []; 33 | lr = learnrate; 34 | 35 | %use normalized variances 36 | if SNR == inf 37 | variance_w = 1; 38 | variance_e = 0; 39 | else 40 | variance_w = SNR/(SNR + 1); 41 | variance_e = 1/(SNR + 1); 42 | end 43 | 44 | alpha = P/N_x_t; % number of examples divided by input dimension 45 | 46 | % Analytical curves for training and testing errors, see supplementary 47 | % material for derivations 48 | 49 | for t = 1:1:nepoch 50 | 51 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 52 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 53 | 54 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 55 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 56 | 57 | end 58 | 59 | % Early stopping curves 60 | [mm, pp] = min(Eg); 61 | Eg_early_stop = Eg; 62 | Eg_early_stop(pp+1:end) = Eg(pp); 63 | 64 | Et_early_stop = Et; 65 | Et_early_stop(pp+1:end) = Et(pp); 66 | 67 | % Pick better training error value and convert to memory generalization 68 | % scores 69 | better_train_no_early_stop = min(Et,ones(1,nepoch)*Notebook_Train); 70 | control_curve = (Et(1) - better_train_no_early_stop)/Et(1); 71 | lesion_curve = (Et(1) - Et)/Et(1); 72 | 73 | 74 | better_train_yes_early_stop = min(Et_early_stop,ones(1,nepoch)*Notebook_Train); 75 | control_curve_early_stop = (Et_early_stop(1) - better_train_yes_early_stop)/Et_early_stop(1); 76 | lesion_curve_early_stop = (Et_early_stop(1) - Et_early_stop)/Et_early_stop(1); 77 | 78 | 79 | control_Eg_curve = (Eg(1) - Eg)/Eg(1); 80 | control_Eg_curve_early_stop = (Eg_early_stop(1) - Eg_early_stop)/Eg_early_stop(1); 81 | 82 | 83 | Et_lesion_remote = [Et_lesion_remote lesion_curve_early_stop(end)]; 84 | Eg_lesion_remote = [Eg_lesion_remote control_Eg_curve_early_stop(end)]; 85 | 86 | figure(3) 87 | 88 | hold on; 89 | plot(1:1:nepoch,control_curve,'k-','LineWidth',3) 90 | plot(1:1:nepoch,lesion_curve,'c-','LineWidth',3) 91 | % plot([5 1800],[control_curve(1) control_curve(1800)],'o-','color',[0 0 1]) 92 | plot([5 1800],[lesion_curve(1) lesion_curve(1800)],'o-','color',[1 0 0]) % pick a early and a late point as recent and remote memory 93 | 94 | set(gca, 'FontSize', 12) 95 | set(gca, 'FontSize', 12) 96 | xlabel('Epoch', 'FontSize',12) 97 | ylabel('Memory Score', 'FontSize',12) 98 | xlim([0 nepoch]) 99 | ylim([0 1]) 100 | set(gca,'linewidth',1.5) 101 | 102 | 103 | figure(4) 104 | 105 | hold on; 106 | plot(1:1:nepoch,control_curve_early_stop,'k-','LineWidth',3) 107 | plot(1:1:nepoch,lesion_curve_early_stop,'c-','LineWidth',3) 108 | % plot([5 1800],[control_curve_early_stop(1) control_curve_early_stop(1800)],'o-','color',[0 0 1]) 109 | plot([5 1800],[lesion_curve_early_stop(1) lesion_curve_early_stop(1800)],'o-','color',[1 0 0]) 110 | ylim([0 1]) 111 | 112 | set(gca, 'FontSize', 12) 113 | set(gca, 'FontSize', 12) 114 | xlabel('Epoch', 'FontSize',12) 115 | ylabel('Memory Score', 'FontSize',12) 116 | xlim([0 nepoch]) 117 | ylim([0 1]) 118 | set(gca,'linewidth',1.5) 119 | 120 | figure(5) 121 | 122 | hold on; 123 | plot(1:1:nepoch,control_Eg_curve,'g-','LineWidth',3) 124 | % plot([5 1800],[control_Eg_curve(1) 
control_Eg_curve(1800)],'o-','color',[0 0 1]) 125 | ylim([0 1]) 126 | 127 | set(gca, 'FontSize', 12) 128 | set(gca, 'FontSize', 12) 129 | xlabel('Epoch', 'FontSize',12) 130 | ylabel('Generalization Score', 'FontSize',12) 131 | xlim([0 nepoch]) 132 | ylim([-1 1]) 133 | set(gca,'linewidth',1.5) 134 | 135 | figure(6) 136 | hold on; 137 | plot(1:1:nepoch,control_Eg_curve_early_stop,'g-','LineWidth',3) 138 | % plot([5 1800],[control_Eg_curve_early_stop(1) control_Eg_curve_early_stop(1800)],'o-','color',[0 0 1]) 139 | ylim([-1 1]) 140 | set(gca, 'FontSize', 12) 141 | set(gca, 'FontSize', 12) 142 | xlabel('Epoch', 'FontSize',12) 143 | ylabel('Generalization Score', 'FontSize',12) 144 | xlim([0 nepoch]) 145 | ylim([-1 1]) 146 | set(gca,'linewidth',1.5) 147 | 148 | end 149 | 150 | figure(7) 151 | plot(Et_lesion_remote,Eg_lesion_remote,'ko') 152 | xlabel('Memory Score (Notebook lesioned)', 'FontSize',12) 153 | ylabel('Generalization Score', 'FontSize',12) 154 | 155 | %% save figures 156 | 157 | % figure(3) 158 | % set(gcf,'position',[100,100,350,290]) 159 | % saveas(gcf,strcat('Fig_3a','.pdf')); 160 | % figure(4) 161 | % set(gcf,'position',[100,100,350,290]) 162 | % saveas(gcf,strcat('Fig_3b','.pdf')); 163 | % figure(5) 164 | % set(gcf,'position',[100,100,350,290]) 165 | % saveas(gcf,strcat('Fig_3c','.pdf')); 166 | % figure(6) 167 | % set(gcf,'position',[100,100,350,290]) 168 | % saveas(gcf,strcat('Fig_3d','.pdf')); 169 | % figure(7) 170 | % set(gcf,'position',[100,100,350,290]) 171 | % saveas(gcf,strcat('Fig_3f','.pdf')); 172 | 173 | 174 | -------------------------------------------------------------------------------- /Fig_3/Fig_3e.m: -------------------------------------------------------------------------------- 1 | % code for Student-Teacher-Notebook framework 2 | % diversity of amnesia curves 3 | 4 | close all 5 | clear all 6 | 7 | saveplot = false; 8 | nepoch = 2000; 9 | learnrate = 0.005; 10 | N_x_t = 100; 11 | N_y_t = 1; 12 | P=100; 13 | M = 5000; %num of units in notebook 14 | 15 | Et_big_early_stop = []; 16 | Eg_big_early_stop = []; 17 | Et_big = []; 18 | Eg_big = []; 19 | 20 | SNR_vec = [0.01 0.1 0.3 1 8 inf]; 21 | 22 | Notebook_Train = (P-1)/(M-1); % Analytial solution for notebook training error 23 | 24 | % Simulation 1, varying SNR 25 | 26 | for count = 1:size(SNR_vec,2) 27 | 28 | SNR = SNR_vec(count); 29 | Eg = []; 30 | Et = []; 31 | lr = learnrate; 32 | %use normalized variances 33 | if SNR == inf 34 | variance_w = 1; 35 | variance_e = 0; 36 | else 37 | variance_w = SNR/(SNR + 1); 38 | variance_e = 1/(SNR + 1); 39 | end 40 | 41 | alpha = P/N_x_t; % number of examples divided by input dimension 42 | 43 | % Analytical curves for training and testing errors, see supplementary 44 | % material for derivations 45 | for t = 1:1:nepoch 46 | 47 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 48 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 49 | 50 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 51 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 52 | 53 | end 54 | 55 | % Early stopping curves 56 | [min_Eg, pos] = min(Eg); 57 | Eg_ES = Eg; 58 | Eg_ES(pos+1:end) = Eg(pos); 59 | 60 | Et_ES = Et; 61 | Et_ES(pos+1:end) = 
Et(pos); 62 | 63 | % Pick better training error and create memory and generalization 64 | % scores 65 | better_Et = min(Et,ones(1,nepoch)*Notebook_Train); 66 | control_memory_score = (Et(1) - better_Et)/Et(1); 67 | lesion_memory_score = (Et(1) - Et)/Et(1); 68 | 69 | %with early stopping 70 | better_Et_ES = min(Et_ES,ones(1,nepoch)*Notebook_Train); 71 | control_memory_score_ES = (Et_ES(1) - better_Et_ES)/Et_ES(1); 72 | lesion_memory_score_ES = (Et_ES(1) - Et_ES)/Et_ES(1); 73 | 74 | 75 | control_generalization_score = (Eg(1) - Eg)/Eg(1); 76 | control_generalization_score_ES = (Eg_ES(1) - Eg_ES)/Eg_ES(1); 77 | 78 | figure(1) 79 | 80 | hold on; 81 | if SNR == inf 82 | % picking an early and late epoch as "recent" and 'remote' 83 | plot([1 1800],[control_memory_score_ES(1) control_memory_score_ES(1800)],'o-','color',[0 0 0]) 84 | end 85 | if SNR ~= inf 86 | plot([1 1800],[lesion_memory_score_ES(1) lesion_memory_score_ES(1800)],'o-','color',[0 0 0]) 87 | end 88 | set(gca, 'FontSize', 12) 89 | set(gca, 'FontSize', 12) 90 | xlabel('Epoch', 'FontSize',12) 91 | ylabel('Memory Score', 'FontSize',12) 92 | xlim([-300 nepoch + 100]) 93 | ylim([-0.1 1.15]) 94 | set(gca,'linewidth',1.5) 95 | 96 | end 97 | 98 | % Simulation 2, varying prior consolidation 99 | 100 | for start_epoch = [8 20 40 60 2000] % amount of prior learning epochs 101 | SNR = 50; 102 | P = 300; 103 | Eg = []; 104 | Et = []; 105 | lr = learnrate; 106 | 107 | if SNR == inf 108 | variance_w = 1; 109 | variance_e = 0; 110 | else 111 | variance_w = SNR/(SNR + 1); 112 | variance_e = 1/(SNR + 1); 113 | end 114 | 115 | alpha = P/N_x_t; % number of examples divided by input dimension 116 | 117 | 118 | % Analytical curves for training and testing errors 119 | for t = 1:1:(nepoch + start_epoch) 120 | 121 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 122 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 123 | 124 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 125 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 126 | 127 | end 128 | 129 | % Early stopping curves 130 | [min_Eg, pos] = min(Eg); 131 | Eg_ES = Eg; 132 | Eg_ES(pos+1:end) = Eg(pos); 133 | 134 | Et_ES = Et; 135 | Et_ES(pos+1:end) = Et(pos); 136 | 137 | % Pick better training error value and create memory and generalization 138 | % scores 139 | better_Et = min(Et,ones(1,nepoch + start_epoch)*Notebook_Train); 140 | control_memory_score = (Et(1) - better_Et)/Et(1); 141 | lesion_memory_score = (Et(1) - Et)/Et(1); 142 | 143 | 144 | better_Et_ES = min(Et_ES,ones(1,nepoch + start_epoch)*Notebook_Train); 145 | control_memory_score_ES = (Et_ES(1) - better_Et_ES)/Et_ES(1); 146 | lesion_memory_score_ES = (Et_ES(1) - Et_ES)/Et_ES(1); 147 | 148 | 149 | control_generalization_score = (Eg(1) - Eg)/Eg(1); 150 | control_generalization_score_ES = (Eg_ES(1) - Eg_ES)/Eg_ES(1); % this assumes only using student for generalization 151 | lesion_Eg_curve_early_stop = (Eg_ES(1) - Eg_ES)/Eg_ES(1); 152 | 153 | 154 | figure(1) 155 | 156 | hold on; 157 | 158 | %"Recent" memory performance with prior rule-consistent consolidation can 159 | %be captured by the generalization error. 
That is, given certain amount 160 | %of prior learning, how well can the network predict new examples? 161 | %"Remote" memory performance is the training error at convergence. 162 | 163 | plot([1 1800],[lesion_Eg_curve_early_stop(start_epoch) lesion_memory_score_ES(end)],'o--','color',[0 0 0]) 164 | 165 | set(gca, 'FontSize', 12) 166 | set(gca, 'FontSize', 12) 167 | xlabel('Epoch', 'FontSize',12) 168 | ylabel('Memory Score', 'FontSize',12) 169 | xlim([-300 nepoch + 100]) 170 | ylim([-0.1 1.15]) 171 | set(gca,'linewidth',1.5) 172 | 173 | end 174 | 175 | 176 | f = gca; 177 | f.XTick = [1 1800]; 178 | f.XTickLabel = [{'Recent'} {'Remote'}]; 179 | set(gcf,'position',[100,100,400,600]) 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /Fig_4/Fig_4d_to_g.m: -------------------------------------------------------------------------------- 1 | %Generalization and memorization performance as a 2 | %function of alpha and SNR, for student-only, notebook-only, and combined 3 | %system 4 | 5 | close all; clear all; clc 6 | 7 | %% Batch 8 | 9 | nepoch = 10000; 10 | lr = 0.01; 11 | SNR_log_interval = -2:0.1:3; 12 | SNR_vec =10.^SNR_log_interval; 13 | alpha_vec= 0.1:0.1:5; 14 | 15 | Et_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 16 | Eg_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 17 | Eg_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 18 | Et_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 19 | 20 | for n = 1:length(SNR_vec) 21 | for m = 1:length(alpha_vec) 22 | SNR = SNR_vec(n); 23 | alpha = alpha_vec(m); 24 | 25 | if SNR == inf 26 | variance_w = 1; 27 | variance_e = 0; 28 | else 29 | variance_w = SNR/(SNR + 1); 30 | variance_e = 1/(SNR + 1); 31 | end 32 | 33 | 34 | parfor t = 1:nepoch 35 | 36 | 37 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 38 | Et_no_ES(n,m,t) = (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e; 39 | 40 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 41 | Eg_no_ES(n,m,t) = variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR); 42 | 43 | 44 | end 45 | 46 | [min_Eg, ES_time] = min(Eg_no_ES(n,m,:)); 47 | 48 | Eg_ES(n,m,:) = Eg_no_ES(n,m,:); 49 | Eg_ES(n,m,ES_time+1:end) = Eg_no_ES(n,m,ES_time); 50 | 51 | Et_ES(n,m,:) = Et_no_ES(n,m,:); 52 | Et_ES(n,m,ES_time+1:end) = Et_no_ES(n,m,ES_time); 53 | end 54 | 55 | end 56 | 57 | %% Optimal online student 58 | 59 | Eg_opt_online = zeros(length(SNR_vec),length(alpha_vec)); 60 | 61 | 62 | 63 | parfor n = 1:length(SNR_vec) 64 | SNR = SNR_vec(n); 65 | variance_e = 1/(SNR + 1); 66 | [alpha, Eg] = ode45(@(alpha,Eg)eg_prime(alpha,Eg,variance_e),(0.1:0.1:5),1); 67 | Eg_new = interp1(alpha,Eg,alpha_vec); 68 | Eg_opt_online(n,:) = Eg_new; 69 | 70 | end 71 | 72 | figure (1) 73 | hold on 74 | plot(alpha_vec,Eg_opt_online(end,:),'b-') 75 | plot(alpha_vec,Eg_ES(end,:,end),'r-') 76 | plot(alpha_vec,ones(1,length(alpha_vec))*2,'m') 77 | ylim([-0.05 2.1]) 78 | set(gcf,'position',[500 500 420 420]) 79 | ax = gca; 80 | ax.XTick = [0 1 2 3 4 5]; 81 | ax.YTick = [0 0.5 1 1.5 2]; 82 | %saveas(gcf,'Fig_4d.pdf'); 83 | 84 | figure (2) 85 | imagesc(flip(Eg_opt_online(:,:) - Eg_ES(:,:,end))) 86 | colormap(redblue) 87 | caxis([-0.3 0.3]) 88 | 
colorbar 89 | set(gcf,'position',[500 500 420 325]) 90 | %saveas(gcf,'Fig_4e.pdf'); 91 | 92 | figure (3) 93 | hold on 94 | plot(alpha_vec,Eg_no_ES(25,:,end),'b-') 95 | plot(alpha_vec,Eg_ES(25,:,end),'r-') 96 | ylim([0 2.5]) 97 | set(gcf,'position',[500 500 420 420]) 98 | %saveas(gcf,'Fig_4f.pdf'); 99 | 100 | figure (4) 101 | imagesc(flip(Eg_ES(:,:,end)-Eg_no_ES(:,:,end))) 102 | colormap(redblue) 103 | caxis([-1 1]) 104 | colorbar 105 | set(gcf,'position',[500 500 420 325]) 106 | %saveas(gcf,'Fig_4g.pdf'); 107 | 108 | function dEg = eg_prime(alpha,Eg,variance_e) 109 | dEg = 2*variance_e - Eg - variance_e^2/Eg; 110 | end 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /Fig_5/Fig_5_e_f.m: -------------------------------------------------------------------------------- 1 | %Fig.5 e and f, noisy, complex, and partial observable teachers 2 | tic 3 | close all 4 | clear all 5 | 6 | r_n = 1000; % number of repeats 7 | nepoch = 500; 8 | learnrate = 0.1; 9 | N_x_t = 100; 10 | N_y_t = 1; 11 | P=100; 12 | p_test = 1000; 13 | 14 | %Set student dimensions 15 | N_x_s = N_x_t; 16 | N_y_s = N_y_t; 17 | 18 | % N_x_t_wide = 117 for a partial observability level matches the sin complex teacher (SNR = 5.74) 19 | % For Fig.5e, set N_x_t_wide = 100 for Fig. 5f. See supplementary material, 20 | % complex teacher section. 21 | N_x_t_wide = N_x_t + 17; 22 | 23 | 24 | Error_vector_big_noisy = zeros(r_n,nepoch); 25 | gError_vector_big_noisy = zeros(r_n,nepoch); 26 | 27 | Error_vector_big_com = zeros(r_n,nepoch); 28 | gError_vector_big_com = zeros(r_n,nepoch); 29 | 30 | Error_vector_big_partial = zeros(r_n,nepoch); 31 | gError_vector_big_partial = zeros(r_n,nepoch); 32 | 33 | 34 | 35 | 36 | parfor r = 1:r_n 37 | 38 | rng(r) 39 | 40 | Error_vector_noisy = zeros(nepoch,1); 41 | gError_vector_noisy = zeros(nepoch,1); 42 | 43 | 44 | Error_vector_com = zeros(nepoch,1); 45 | gError_vector_com = zeros(nepoch,1); 46 | 47 | Error_vector_partial = zeros(nepoch,1); 48 | gError_vector_partial = zeros(nepoch,1); 49 | 50 | 51 | 52 | %% Teacher Network 53 | 54 | w_t = normrnd(0,1^0.5,[N_x_t,N_y_t]); 55 | w_t_wide = normrnd(0,1^0.5,[N_x_t_wide,N_y_t]); 56 | 57 | %training data 58 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); %input patterns for noisy and complex teachers 59 | x_t_input_wide = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t_wide]); %input patterns for partially observable teacher 60 | y_t_output_wide = x_t_input_wide*w_t_wide; %output of the partially observable teacher 61 | observable_mask = randperm(N_x_t_wide,100); % define the region of teacher input observable to the student 62 | x_t_input_observable = x_t_input_wide(:,observable_mask); 63 | y_t_output_complex = sin(x_t_input*w_t); % complex teacher output, using sine function for Fig. 5e, remove sine function for Fig. 5f. 64 | 65 | % Below are terms for calculating equivalent SNR of a complex teacher. 66 | % See supplementary material, complex teacher section for details. 67 | x_t_input_big = normrnd(0,(1/N_x_t)^0.5,[P*1000,N_x_t]); 68 | y_t_output_complex_big = sin(x_t_input_big*w_t); 69 | % Variance of the optimal linear weight for fitting the training data by a complex 70 | % teacher. 
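% The two quantities below come from an ordinary least-squares fit of the sine
% teacher on a large auxiliary sample (x_t_input_big): the linear weights
% inv(X'*X)*(X'*y) give the best linear approximation to sin(X*w_t), and
% variance_c is the mean squared residual of that fit on the training inputs.
% These two variances parameterize the matched noisy teacher defined just
% below (w_t_opt and the noise terms).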
71 | variance_w_opt = var(inv(x_t_input_big'*x_t_input_big)*(x_t_input_big'*y_t_output_complex_big)); 72 | % variance of the residue after linear fitting 73 | variance_c = mean((y_t_output_complex - x_t_input*inv(x_t_input_big'*x_t_input_big)*(x_t_input_big'*y_t_output_complex_big)).^2); 74 | 75 | 76 | w_t_opt = normrnd(0,variance_w_opt^0.5,[N_x_t,N_y_t]); % set weight variance for the noisy teacher 77 | noise = normrnd(0,variance_c^0.5,[P,N_y_t]); % set training data noise variance for the noisy teacher 78 | noise1 = normrnd(0,variance_c^0.5,[p_test,N_y_t]);% set testing data noise variance for the noisy teacher 79 | y_t_output =x_t_input*w_t_opt + noise; %output of noisy teacher 80 | 81 | %test sets 82 | x_t_input_new = normrnd(0,(1/N_x_t)^0.5,[p_test,N_x_t]); 83 | y_t_output_new = x_t_input_new*w_t_opt + noise1; 84 | y_t_output_new_complex = sin(x_t_input_new*w_t); 85 | 86 | x_t_input_wide_new = normrnd(0,(1/N_x_t)^0.5,[p_test,N_x_t_wide]); 87 | x_t_input_wide_new_observable = x_t_input_wide_new(:,observable_mask); 88 | y_t_output_wide_new = x_t_input_wide_new*w_t_wide; 89 | 90 | 91 | %% Student Network 92 | 93 | w_s_noisy = normrnd(0,0^0.5,[N_x_s,N_y_s]); 94 | w_s_complex = normrnd(0,0^0.5,[N_x_s,N_y_s]); 95 | w_s_partial = normrnd(0,0^0.5,[N_x_s,N_y_s]); 96 | 97 | 98 | 99 | 100 | for m = 1:nepoch 101 | 102 | 103 | 104 | y_s_output_noisy = x_t_input*w_s_noisy; 105 | y_s_output_new_noisy = x_t_input_new*w_s_noisy; 106 | 107 | y_s_output_batch_com = x_t_input*w_s_complex; 108 | y_s_output_new_batch_com = x_t_input_new*w_s_complex; 109 | 110 | y_s_output_batch_partial = x_t_input_observable*w_s_partial; 111 | y_s_output_new_batch_partial = x_t_input_wide_new_observable*w_s_partial; 112 | 113 | 114 | %noisy teacher train error 115 | Error_noisy = y_t_output - y_s_output_noisy; 116 | MSE_noisy = sum(Error_noisy.^2)/P; 117 | Error_vector_noisy(m) = MSE_noisy; 118 | 119 | %noisy teacher generalization error 120 | gError_noisy = y_t_output_new - y_s_output_new_noisy; 121 | gMSE_noisy = sum(gError_noisy.^2)/p_test; 122 | gError_vector_noisy(m) = gMSE_noisy; 123 | 124 | 125 | %complex teacher train error 126 | Error_com = y_t_output_complex - y_s_output_batch_com; 127 | MSE_com = sum(Error_com.^2)/P; 128 | Error_vector_com(m) = MSE_com; 129 | 130 | %complex teacher generalization error 131 | gError_com = y_t_output_new_complex - y_s_output_new_batch_com; 132 | gMSE_com = sum(gError_com.^2)/p_test; 133 | gError_vector_com(m) = gMSE_com; 134 | 135 | %partial observable teacher train error 136 | Error_partial = y_t_output_wide - y_s_output_batch_partial; 137 | Cost_partial = sum(Error_partial.^2)/P; 138 | Error_vector_partial(m) = Cost_partial; 139 | 140 | %partial observable teacher generalization error 141 | gError_partial = y_t_output_wide_new - y_s_output_new_batch_partial; 142 | gCost_partial = sum(gError_partial.^2)/p_test; 143 | gError_vector_partial(m) = gCost_partial; 144 | 145 | % Weight updates through gradient descent for teach teacher 146 | w_delta_noisy = (x_t_input'*y_t_output - x_t_input'*x_t_input*w_s_noisy); 147 | w_s_noisy = w_s_noisy + learnrate*w_delta_noisy; 148 | 149 | w_delta_com = (x_t_input'*y_t_output_complex - x_t_input'*x_t_input*w_s_complex); 150 | w_s_complex = w_s_complex + learnrate*w_delta_com; 151 | 152 | w_delta_partial = (x_t_input_observable'*y_t_output_wide - x_t_input_observable'* x_t_input_observable*w_s_partial); 153 | w_s_partial = w_s_partial + learnrate*w_delta_partial; 154 | end 155 | 156 | Error_vector_big_noisy(r,:) = Error_vector_noisy; 157 | 
gError_vector_big_noisy(r,:) = gError_vector_noisy; 158 | 159 | Error_vector_big_com(r,:) = Error_vector_com; 160 | gError_vector_big_com(r,:) = gError_vector_com; 161 | 162 | Error_vector_big_partial(r,:) = Error_vector_partial; 163 | gError_vector_big_partial(r,:) = gError_vector_partial; 164 | 165 | end 166 | 167 | toc 168 | 169 | color_scheme = [137 152 193; 245 143 136]/255; 170 | line_w = 1; 171 | font_s = 12; 172 | 173 | figure(1) 174 | hold on 175 | 176 | 177 | plot(1:nepoch,mean(Error_vector_big_noisy)/mean(Error_vector_big_noisy(:,1)),'r-') 178 | plot(1:nepoch,mean(gError_vector_big_noisy)/mean(gError_vector_big_noisy(:,1)),'r-') 179 | 180 | 181 | plot(1:nepoch,mean(Error_vector_big_com)/mean(Error_vector_big_com(:,1)),'k-') 182 | plot(1:nepoch,mean(gError_vector_big_com)/mean(gError_vector_big_com(:,1)),'k-') 183 | 184 | plot(1:nepoch,mean(Error_vector_big_partial)/mean(Error_vector_big_partial(:,1)),'g-') 185 | plot(1:nepoch,mean(gError_vector_big_partial)/mean(gError_vector_big_partial(:,1)),'g-') 186 | 187 | 188 | 189 | xt = get(gca, 'XTick'); 190 | set(gca, 'FontSize', font_s) 191 | yt = get(gca, 'YTick'); 192 | set(gca, 'FontSize', font_s) 193 | xlabel('Epoch','Color','k') 194 | ylabel('Error','Color','k') 195 | set(gcf,'position',[100,100,360,225]) 196 | xlim([0 nepoch]) 197 | 198 | % print(gcf,'complex_teacher_low_SNR.png','-dpng','-r600'); 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Organizing memories for generalization in complementary learning systems 2 | --- 3 | Code for reproducing figures in: 4 |
5 |
Organizing memories for generalization in complementary learning systems 6 |
Weinan Sun, Madhu Advani, Nelson Spruston, Andrew Saxe, James E Fitzgerald 7 |
bioRxiv 2021.10.13.463791; doi: https://doi.org/10.1101/2021.10.13.463791 8 | 9 |
MATLAB function required: 10 | Redblue colormap 11 |
https://www.mathworks.com/matlabcentral/fileexchange/25536-red-blue-colormap 12 | -------------------------------------------------------------------------------- /Supplements/Fig_S1_KNN.m: -------------------------------------------------------------------------------- 1 | %Generalization error of KNN regression 2 | 3 | close all; 4 | clear all; 5 | 6 | 7 | R = 10; % number of repeats 8 | range_N = [3 10 30 100 300 1000 3000 10000]; % # of dimensions 9 | 10 | mse_train = zeros(R,length(range_N)); 11 | mse_test = zeros(R,length(range_N)); 12 | 13 | N_y_t = 1; %output dimension 14 | P_test = 1000; 15 | 16 | SNR = inf; 17 | 18 | if SNR == inf 19 | variance_w = 1; 20 | variance_e = 0; 21 | else 22 | variance_w = SNR/(SNR + 1); 23 | variance_e = 1/(SNR + 1); 24 | end 25 | 26 | 27 | K=1; % number of nearest neighbors 28 | 29 | parfor repeat = 1:R 30 | 31 | mse_train_row = zeros(1,length(range_N)); 32 | mse_test_row = zeros(1,length(range_N)); 33 | 34 | for counter = 1:length(range_N) 35 | N_x_t = range_N(counter); 36 | P = N_x_t; % keeping alpha = 1 37 | train_error = zeros (1,P); 38 | test_error = zeros(1,P_test); 39 | 40 | w_t = normrnd(0,variance_w^0.5,[N_x_t,N_y_t]); 41 | 42 | %Generate patterns 43 | 44 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); 45 | noise = normrnd(0,variance_e^0.5,[P,N_y_t]); 46 | y_t_output = x_t_input*w_t + noise; 47 | 48 | 49 | x_t_input_new = normrnd(0,(1/N_x_t)^0.5,[P_test,N_x_t]); 50 | noise1 = normrnd(0,variance_e^0.5,[P_test,N_y_t]); 51 | y_t_output_new = x_t_input_new*w_t+ noise1; 52 | 53 | 54 | %Distance matrix 55 | 56 | D_train = pdist2(x_t_input,x_t_input); 57 | D_test = pdist2(x_t_input_new,x_t_input); 58 | 59 | %Train error 60 | for n = 1:P 61 | [B,I] = mink(D_train(n,:),K); 62 | retrieved_x = mean(x_t_input(I,:)); 63 | retrieved_y = mean(y_t_output(I)); 64 | train_error(n) = (y_t_output(n) - retrieved_y)^2; 65 | end 66 | 67 | %Test error 68 | for n = 1:P_test 69 | [B,I] = mink(D_test(n,:),K); 70 | retrieved_x = mean(x_t_input(I,:)); 71 | retrieved_y = mean(y_t_output(I)); 72 | test_error(n) = (y_t_output_new(n) - retrieved_y)^2; 73 | end 74 | 75 | mse_train_row(counter) = mean(train_error); 76 | mse_test_row(counter) = mean(test_error); 77 | 78 | end 79 | mse_train(repeat,:) = mse_train_row 80 | mse_test(repeat,:) = mse_test_row 81 | end 82 | 83 | 84 | figure(1) 85 | 86 | plot(log10(range_N),mean(mse_test,1),'ro-','LineWidth',3) 87 | 88 | xt = get(gca, 'XTick'); 89 | set(gca, 'FontSize', 20) 90 | yt = get(gca, 'YTick'); 91 | set(gca, 'FontSize', 20) 92 | ylim([0 2.5]) -------------------------------------------------------------------------------- /Supplements/Fig_S2_SNR_alpha.m: -------------------------------------------------------------------------------- 1 | %Generalization and memorization performance as a 2 | %function of alpha and SNR. 
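% Note: the heatmaps below rely on the 'redblue' diverging colormap from the
% MATLAB File Exchange (see README).  If it is not on the path, a rough
% stand-in (an illustrative substitute, not part of the original pipeline)
% could be built by interpolation, e.g.
%   rb = interp1([0 0.5 1], [0 0 1; 1 1 1; 1 0 0], linspace(0, 1, 256));
%   colormap(rb)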
3 | 4 | close all; clear all; clc 5 | 6 | nepoch = 10000; 7 | lr = 0.01; 8 | SNR_log_interval = -2:0.1:3; 9 | SNR_vec =10.^SNR_log_interval; 10 | alpha_vec= 0.1:0.1:5; 11 | 12 | Et_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 13 | Eg_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 14 | Eg_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 15 | Et_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch); 16 | 17 | for n = 1:length(SNR_vec) 18 | for m = 1:length(alpha_vec) 19 | SNR = SNR_vec(n); 20 | alpha = alpha_vec(m); 21 | 22 | if SNR == inf 23 | variance_w = 1; 24 | variance_e = 0; 25 | else 26 | variance_w = SNR/(SNR + 1); 27 | variance_e = 1/(SNR + 1); 28 | end 29 | 30 | 31 | parfor t = 1:nepoch 32 | 33 | 34 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 35 | Et_no_ES(n,m,t) = (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e; 36 | 37 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 38 | Eg_no_ES(n,m,t) = variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR); 39 | 40 | 41 | end 42 | 43 | [min_Eg, ES_time] = min(Eg_no_ES(n,m,:)); 44 | 45 | Eg_ES(n,m,:) = Eg_no_ES(n,m,:); 46 | Eg_ES(n,m,ES_time+1:end) = Eg_no_ES(n,m,ES_time); 47 | 48 | Et_ES(n,m,:) = Et_no_ES(n,m,:); 49 | Et_ES(n,m,ES_time+1:end) = Et_no_ES(n,m,ES_time); 50 | end 51 | 52 | end 53 | 54 | figure(1) 55 | imagesc(flip((Et_no_ES(:,:,1)-Et_no_ES(:,:,end))./Et_no_ES(:,:,1))) 56 | colormap(redblue) 57 | caxis([-1 1]) 58 | colorbar 59 | set(gcf,'position',[500 500 440 325]) 60 | 61 | figure(2) 62 | imagesc(flip((Eg_no_ES(:,:,1)-Eg_no_ES(:,:,end))./Eg_no_ES(:,:,1))) 63 | colormap(redblue) 64 | caxis([-1 1]) 65 | colorbar 66 | set(gcf,'position',[500 500 440 325]) 67 | 68 | figure(3) 69 | imagesc(flip((Et_ES(:,:,1)-Et_ES(:,:,end))./Et_ES(:,:,1))) 70 | colormap(redblue) 71 | caxis([-1 1]) 72 | colorbar 73 | set(gcf,'position',[500 500 440 325]) 74 | 75 | figure(4) 76 | imagesc(flip(Eg_ES(:,:,1)-Eg_ES(:,:,end))./Eg_ES(:,:,1)) 77 | colormap(redblue) 78 | caxis([-1 1]) 79 | colorbar 80 | set(gcf,'position',[500 500 440 325]) 81 | 82 | 83 | -------------------------------------------------------------------------------- /Supplements/Fig_S4/Fig_MLE_Speed_VS_GroundT.m: -------------------------------------------------------------------------------- 1 | %% Clear workspace and initialize random number generator 2 | 3 | clear all; 4 | close all; 5 | RandStream.setGlobalStream(RandStream('mt19937ar','seed',sum(100*clock))); 6 | 7 | %% Construct an ensemble of teacher-generated student data. 
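% Each dataset below is drawn from the linear teacher model
%   y = X*wbar + eta,  with X(i,j) ~ N(0, 1/N),
%   wbar ~ N(0, S/(1+S)*I)  and  eta ~ N(0, 1/(1+S)*I),
% so the weight and noise variances sum to one and their ratio is the nominal
% SNR S.  The empirical_SNRs computed afterwards re-estimate this ratio from
% the sample variances of y and eta.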
8 | 9 | num_teachers = 10; 10 | N = 100; 11 | P = 100; 12 | std_X = sqrt(1/N); 13 | 14 | SNRs = [0.05 4 100]; 15 | 16 | num_SNRs = length(SNRs); 17 | 18 | wbars = NaN(N,num_SNRs,num_teachers); 19 | Xs = NaN(P,N,num_SNRs,num_teachers); 20 | etas = NaN(P,num_SNRs,num_teachers); 21 | ys = NaN(P,num_SNRs,num_teachers); 22 | for s = 1:num_SNRs; 23 | S = SNRs(s); 24 | std_noise = sqrt(1/(1+S)); 25 | std_weights = sqrt(S/(1+S)); 26 | for t = 1:num_teachers; 27 | wbars(:,s,t) = std_weights*randn(N,1); 28 | 29 | Xs(:,:,s,t) = std_X*randn(P,N); 30 | etas(:,s,t) = std_noise*randn(P,1); 31 | ys(:,s,t) = Xs(:,:,s,t)*wbars(:,s,t)+etas(:,s,t); 32 | end; 33 | end; 34 | 35 | var_ys = squeeze(var(ys,0,1)); 36 | var_ws = squeeze(var(wbars,0,1)); 37 | var_noises = squeeze(var(etas,0,1)); 38 | empirical_SNRs = (var_ys-var_noises)./var_noises; 39 | 40 | mean_empirical_SNRs = mean(empirical_SNRs,2); 41 | std_empirical_SNRs = std(empirical_SNRs,0,2); 42 | 43 | %% Do an SVD decomposition on the data matrix. 44 | 45 | 46 | Us = NaN(P,P,num_SNRs,num_teachers); 47 | Ss = NaN(P,N,num_SNRs,num_teachers); 48 | Vs = NaN(N,N,num_SNRs,num_teachers); 49 | Lambdas = NaN(P,P,num_SNRs,num_teachers); 50 | inv_Lambdas = NaN(P,P,num_SNRs,num_teachers); 51 | for s = 1:num_SNRs; 52 | strcat('On SNR:',num2str(s)) 53 | 54 | for t = 1:num_teachers; 55 | [Us(:,:,s,t) Ss(:,:,s,t) Vs(:,:,s,t)] = svd(Xs(:,:,s,t)); 56 | Lambdas(:,:,s,t) = Ss(:,:,s,t)*Ss(:,:,s,t)'; 57 | inv_Lambdas(:,:,s,t) = pinv(Lambdas(:,:,s,t)); 58 | end; 59 | end; 60 | 61 | %% Evaluate the log-likelihood function for each dataset using the SVD decomposition 62 | % More numerically efficient version. 63 | 64 | tic; 65 | num_components = min(N,P); 66 | 67 | 68 | SNR_vector = 2.^(-6:0.02:7); 69 | num_step_SNR = length(SNR_vector); 70 | 71 | log_likelihood_function = NaN(num_step_SNR,num_SNRs,num_teachers); 72 | max_log_likelihood = NaN(num_SNRs,num_teachers); 73 | estimated_SNRs = NaN(num_SNRs,num_teachers); 74 | for s = 1:num_SNRs; 75 | strcat('On SNR:',num2str(s)) 76 | toc 77 | for t = 1:num_teachers; 78 | y_modes = Us(:,:,s,t)'*ys(:,s,t); 79 | for i = 1:num_step_SNR; 80 | inv_C_matrix_modes = (1+SNR_vector(i))*eye(P); 81 | A1 = (P-num_components)/2*log(1+SNR_vector(i)); 82 | A2 = 0; 83 | for j = 1:num_components; 84 | inv_C_matrix_modes(j,j) = (1+SNR_vector(i))/(1+SNR_vector(i)*Lambdas(j,j,s,t)); 85 | A1 = A1 + 1/2*log(inv_C_matrix_modes(j,j)); 86 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j); 87 | end; 88 | if gt(P,num_components) 89 | for j = 1:P; 90 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j); 91 | end; 92 | end; 93 | %inv_C_matrix_modes = (1+SNR_vector(i))*inv(eye(P)+SNR_vector(i)*Lambdas(:,:,s,t)); 94 | %log_likelihood_function(i,s,t) = 1/2*log(det(inv_C_matrix_modes)) - 1/2*y_modes'*inv_C_matrix_modes*y_modes; 95 | log_likelihood_function(i,s,t) = A1 + A2; 96 | %log_likelihood_function(i,s,t) = -1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t); 97 | %log_likelihood_function(i,s,t) = log(det(inv_C_matrix)); 98 | end; 99 | [max_log_likelihood(s,t) tmp_idx] = max(log_likelihood_function(:,s,t)); 100 | estimated_SNRs(s,t) = SNR_vector(tmp_idx); 101 | end; 102 | end; 103 | 104 | 105 | nepoch = 2000; 106 | learnrate = 0.005; 107 | N_x_t = 100; 108 | N_y_t = 1; 109 | P=100; 110 | 111 | for count = 1:size(SNRs,2) 112 | disp(count) 113 | SNR = SNRs(count); 114 | Eg = []; 115 | Et = []; 116 | lr = learnrate; 117 | 118 | %use normalized variances 119 | if SNR == inf 120 | variance_w = 1; 121 | variance_e = 0; 122 | else 123 | variance_w = SNR/(SNR + 1); 124 | 
variance_e = 1/(SNR + 1); 125 | end 126 | 127 | alpha = P/N_x_t; % number of examples divided by input dimension 128 | 129 | % Analytical curves for training and testing errors, see supplementary 130 | % material for derivations 131 | 132 | for t = 1:1:nepoch 133 | 134 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 135 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 136 | 137 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 138 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 139 | 140 | end 141 | [mm, pp] = min(Eg); 142 | Eg_early_stop = Eg; 143 | Eg_early_stop(pp+1:end) = Eg(pp); 144 | 145 | Et_early_stop = Et; 146 | Et_early_stop(pp+1:end) = Et(pp); 147 | 148 | ES_time = []; 149 | for ind = 1:length(estimated_SNRs(count,:)) 150 | Eg_MLE = []; 151 | Et_MLE = []; 152 | 153 | SNR = estimated_SNRs(count,ind); 154 | 155 | if SNR == inf 156 | variance_w = 1; 157 | variance_e = 0; 158 | else 159 | variance_w = SNR/(SNR + 1); 160 | variance_e = 1/(SNR + 1); 161 | end 162 | 163 | 164 | for t = 1:1:nepoch 165 | 166 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 167 | Et_MLE = [Et_MLE (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 168 | 169 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 170 | Eg_MLE = [Eg_MLE variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 171 | 172 | end 173 | [mm, pp] = min(Eg_MLE); 174 | ES_time = [ES_time pp]; 175 | 176 | 177 | end 178 | % Early stopping curves 179 | 180 | figure(count) 181 | 182 | hold on; 183 | 184 | for i =1:length(ES_time) 185 | min_p = ES_time(i); 186 | Eg_early_stop_MLE = Eg; 187 | Eg_early_stop_MLE(min_p+1:end) = Eg(min_p); 188 | Et_early_stop_MLE = Et; 189 | Et_early_stop_MLE(min_p+1:end) = Et(min_p); 190 | 191 | plot(1:1:nepoch,Et_early_stop_MLE,'-','color',[0 1 1 0.5],'LineWidth',1) 192 | plot(1:1:nepoch,Eg_early_stop_MLE,'-','color',[1 0 1 0.5],'LineWidth',1) 193 | end 194 | plot(1:1:nepoch,Et_early_stop,'-','color',[0 0 1],'LineWidth',3) 195 | plot(1:1:nepoch,Eg_early_stop,'-','color',[1 0 0],'LineWidth',3) 196 | 197 | 198 | set(gca, 'FontSize',12) 199 | set(gca, 'FontSize', 12) 200 | xlabel('Epoch') 201 | ylabel('Error') 202 | set(gca,'linewidth',1.5) 203 | ylim([0 2]) 204 | set(gcf,'position',[100,100,350,290]) 205 | saveas(gcf,strcat('MLE SNR=',num2str(SNRs(count)),'.pdf')); 206 | 207 | 208 | end 209 | 210 | %loading estimated SNR values from the learning rate approach (obtained by running the associated code for main Fig.2m,n) 211 | file = load('N_10000.mat'); 212 | field_names = fieldnames(file); 213 | first_field_name = field_names{1}; 214 | Est_SNRs = file.(first_field_name); 215 | 216 | for count = 1:size(SNRs,2) 217 | disp(count) 218 | SNR = SNRs(count); 219 | Eg = []; 220 | Et = []; 221 | lr = learnrate; 222 | 223 | %use normalized variances 224 | if SNR == inf 225 | variance_w = 1; 226 | variance_e = 0; 227 | else 228 | 
variance_w = SNR/(SNR + 1); 229 | variance_e = 1/(SNR + 1); 230 | end 231 | 232 | alpha = P/N_x_t; % number of examples divided by input dimension 233 | 234 | % Analytical curves for training and testing errors, see supplementary 235 | % material for derivations 236 | 237 | for t = 1:1:nepoch 238 | 239 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 240 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 241 | 242 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 243 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 244 | 245 | end 246 | [mm, pp] = min(Eg); 247 | Eg_early_stop = Eg; 248 | Eg_early_stop(pp+1:end) = Eg(pp); 249 | 250 | Et_early_stop = Et; 251 | Et_early_stop(pp+1:end) = Et(pp); 252 | 253 | ES_time = []; 254 | for ind = 1:10 255 | Eg_LS = []; 256 | Et_LS = []; 257 | SNR_est = Est_SNRs(1+(count-1)*10:10+(count-1)*10); 258 | SNR = SNR_est(ind); 259 | 260 | if SNR == inf 261 | variance_w = 1; 262 | variance_e = 0; 263 | else 264 | variance_w = SNR/(SNR + 1); 265 | variance_e = 1/(SNR + 1); 266 | end 267 | 268 | 269 | for t = 1:1:nepoch 270 | 271 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 272 | Et_LS = [Et_LS (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 273 | 274 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 275 | Eg_LS = [Eg_LS variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 276 | 277 | end 278 | [mm, pp] = min(Eg_LS); 279 | ES_time = [ES_time pp]; 280 | 281 | 282 | end 283 | % Early stopping curves 284 | 285 | figure(count+3) 286 | 287 | hold on; 288 | 289 | for i =1:length(ES_time) 290 | min_p = ES_time(i); 291 | Eg_early_stop_MLE = Eg; 292 | Eg_early_stop_MLE(min_p+1:end) = Eg(min_p); 293 | 294 | Et_early_stop_MLE = Et; 295 | Et_early_stop_MLE(min_p+1:end) = Et(min_p); 296 | 297 | plot(1:1:nepoch,Et_early_stop_MLE,'-','color',[0 1 1 0.5],'LineWidth',1) 298 | plot(1:1:nepoch,Eg_early_stop_MLE,'-','color',[1 0 1 0.5],'LineWidth',1) 299 | end 300 | plot(1:1:nepoch,Et_early_stop,'-','color',[0 0 1],'LineWidth',3) 301 | plot(1:1:nepoch,Eg_early_stop,'-','color',[1 0 0],'LineWidth',3) 302 | 303 | 304 | set(gca, 'FontSize',12) 305 | set(gca, 'FontSize', 12) 306 | xlabel('Epoch') 307 | ylabel('Error') 308 | set(gca,'linewidth',1.5) 309 | ylim([0 2]) 310 | set(gcf,'position',[100,100,350,290]) 311 | saveas(gcf,strcat('Learning_speed SNR=',num2str(SNRs(count)),'_',first_field_name,'.pdf')); 312 | 313 | 314 | end 315 | 316 | %% save figures 317 | % figure(1) 318 | % set(gcf,'position',[100,100,350,290]) 319 | % saveas(gcf,strcat('Fig_2g','.pdf')); 320 | % figure(2) 321 | % set(gcf,'position',[100,100,350,290]) 322 | % saveas(gcf,strcat('Fig_2h','.pdf')); 323 | 324 | 325 | 326 | -------------------------------------------------------------------------------- /Supplements/Fig_S4/N_100.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neuroai/Go-CLS_v2/a5f90b73c61088f0a154414af05ea1e54122c06d/Supplements/Fig_S4/N_100.mat -------------------------------------------------------------------------------- /Supplements/Fig_S4/N_10000.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuroai/Go-CLS_v2/a5f90b73c61088f0a154414af05ea1e54122c06d/Supplements/Fig_S4/N_10000.mat -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .shufflenet import * 9 | from .shufflenetv2 import * 10 | from .resnet import * 11 | from .resnext import * 12 | from .preact_resnet import * 13 | from .mobilenet import * 14 | from .mobilenetv2 import * 15 | from .efficientnet import * 16 | from .regnet import * 17 | from .dla_simple import * 18 | from .dla import * 19 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = 
self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | def test(): 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/dla.py: -------------------------------------------------------------------------------- 1 | '''DLA in PyTorch. 2 | 3 | Reference: 4 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d( 17 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 20 | stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion*planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion*planes, 27 | kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Root(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size=1): 41 | super(Root, self).__init__() 42 | self.conv = nn.Conv2d( 43 | in_channels, out_channels, kernel_size, 44 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 45 | self.bn = nn.BatchNorm2d(out_channels) 46 | 47 | def forward(self, xs): 48 | x = torch.cat(xs, 1) 49 | out = F.relu(self.bn(self.conv(x))) 50 | return out 51 | 52 | 53 | class Tree(nn.Module): 54 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 55 | super(Tree, self).__init__() 56 | self.level = level 57 | if level == 1: 58 | self.root = Root(2*out_channels, out_channels) 59 | self.left_node = block(in_channels, out_channels, stride=stride) 60 | self.right_node = block(out_channels, out_channels, stride=1) 61 | else: 62 | self.root = Root((level+2)*out_channels, 
out_channels) 63 | for i in reversed(range(1, level)): 64 | subtree = Tree(block, in_channels, out_channels, 65 | level=i, stride=stride) 66 | self.__setattr__('level_%d' % i, subtree) 67 | self.prev_root = block(in_channels, out_channels, stride=stride) 68 | self.left_node = block(out_channels, out_channels, stride=1) 69 | self.right_node = block(out_channels, out_channels, stride=1) 70 | 71 | def forward(self, x): 72 | xs = [self.prev_root(x)] if self.level > 1 else [] 73 | for i in reversed(range(1, self.level)): 74 | level_i = self.__getattr__('level_%d' % i) 75 | x = level_i(x) 76 | xs.append(x) 77 | x = self.left_node(x) 78 | xs.append(x) 79 | x = self.right_node(x) 80 | xs.append(x) 81 | out = self.root(xs) 82 | return out 83 | 84 | 85 | class DLA(nn.Module): 86 | def __init__(self, block=BasicBlock, num_classes=10): 87 | super(DLA, self).__init__() 88 | self.base = nn.Sequential( 89 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 90 | nn.BatchNorm2d(16), 91 | nn.ReLU(True) 92 | ) 93 | 94 | self.layer1 = nn.Sequential( 95 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 96 | nn.BatchNorm2d(16), 97 | nn.ReLU(True) 98 | ) 99 | 100 | self.layer2 = nn.Sequential( 101 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 102 | nn.BatchNorm2d(32), 103 | nn.ReLU(True) 104 | ) 105 | 106 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 107 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 108 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 109 | self.layer6 = Tree(block, 256, 512, level=1, stride=2) 110 | self.linear = nn.Linear(512, num_classes) 111 | 112 | def forward(self, x): 113 | out = self.base(x) 114 | out = self.layer1(out) 115 | out = self.layer2(out) 116 | out = self.layer3(out) 117 | out = self.layer4(out) 118 | out = self.layer5(out) 119 | out = self.layer6(out) 120 | out = F.avg_pool2d(out, 4) 121 | out = out.view(out.size(0), -1) 122 | out = self.linear(out) 123 | return out 124 | 125 | 126 | def test(): 127 | net = DLA() 128 | print(net) 129 | x = torch.randn(1, 3, 32, 32) 130 | y = net(x) 131 | print(y.size()) 132 | 133 | 134 | if __name__ == '__main__': 135 | test() 136 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/dla_simple.py: -------------------------------------------------------------------------------- 1 | '''Simplified version of DLA in PyTorch. 2 | 3 | Note this implementation is not identical to the original paper version. 4 | But it seems works fine. 5 | 6 | See dla.py for the original paper version. 7 | 8 | Reference: 9 | Deep Layer Aggregation. 
https://arxiv.org/abs/1707.06484 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 25 | stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | nn.Conv2d(in_planes, self.expansion*planes, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Root(nn.Module): 45 | def __init__(self, in_channels, out_channels, kernel_size=1): 46 | super(Root, self).__init__() 47 | self.conv = nn.Conv2d( 48 | in_channels, out_channels, kernel_size, 49 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 50 | self.bn = nn.BatchNorm2d(out_channels) 51 | 52 | def forward(self, xs): 53 | x = torch.cat(xs, 1) 54 | out = F.relu(self.bn(self.conv(x))) 55 | return out 56 | 57 | 58 | class Tree(nn.Module): 59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 60 | super(Tree, self).__init__() 61 | self.root = Root(2*out_channels, out_channels) 62 | if level == 1: 63 | self.left_tree = block(in_channels, out_channels, stride=stride) 64 | self.right_tree = block(out_channels, out_channels, stride=1) 65 | else: 66 | self.left_tree = Tree(block, in_channels, 67 | out_channels, level=level-1, stride=stride) 68 | self.right_tree = Tree(block, out_channels, 69 | out_channels, level=level-1, stride=1) 70 | 71 | def forward(self, x): 72 | out1 = self.left_tree(x) 73 | out2 = self.right_tree(out1) 74 | out = self.root([out1, out2]) 75 | return out 76 | 77 | 78 | class SimpleDLA(nn.Module): 79 | def __init__(self, block=BasicBlock, num_classes=10): 80 | super(SimpleDLA, self).__init__() 81 | self.base = nn.Sequential( 82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 83 | nn.BatchNorm2d(16), 84 | nn.ReLU(True) 85 | ) 86 | 87 | self.layer1 = nn.Sequential( 88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 89 | nn.BatchNorm2d(16), 90 | nn.ReLU(True) 91 | ) 92 | 93 | self.layer2 = nn.Sequential( 94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(32), 96 | nn.ReLU(True) 97 | ) 98 | 99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 102 | self.layer6 = Tree(block, 256, 512, level=1, stride=2) 103 | self.linear = nn.Linear(512, num_classes) 104 | 105 | def forward(self, x): 106 | out = self.base(x) 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 | out = self.layer4(out) 111 | out = self.layer5(out) 112 | out = self.layer6(out) 113 | out = F.avg_pool2d(out, 4) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | 119 | def test(): 120 | net = SimpleDLA() 121 | print(net) 122 | x = torch.randn(1, 3, 32, 32) 123 | y = net(x) 
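    # With the default num_classes=10 and this (1, 3, 32, 32) input, the
    # print below should report torch.Size([1, 10]).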
124 | print(y.size()) 125 | 126 | 127 | if __name__ == '__main__': 128 | test() 129 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 
'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/efficientnet.py: -------------------------------------------------------------------------------- 1 | '''EfficientNet in PyTorch. 2 | 3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def swish(x): 13 | return x * x.sigmoid() 14 | 15 | 16 | def drop_connect(x, drop_ratio): 17 | keep_ratio = 1.0 - drop_ratio 18 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 19 | mask.bernoulli_(keep_ratio) 20 | x.div_(keep_ratio) 21 | x.mul_(mask) 22 | return x 23 | 24 | 25 | class SE(nn.Module): 26 | '''Squeeze-and-Excitation block with Swish.''' 27 | 28 | def __init__(self, in_channels, se_channels): 29 | super(SE, self).__init__() 30 | self.se1 = nn.Conv2d(in_channels, se_channels, 31 | kernel_size=1, bias=True) 32 | self.se2 = nn.Conv2d(se_channels, in_channels, 33 | kernel_size=1, bias=True) 34 | 35 | def forward(self, x): 36 | out = F.adaptive_avg_pool2d(x, (1, 1)) 37 | out = swish(self.se1(out)) 38 | out = self.se2(out).sigmoid() 39 | out = x * out 40 | return out 41 | 42 | 43 | class Block(nn.Module): 44 | '''expansion + depthwise + pointwise + squeeze-excitation''' 45 | 46 | def __init__(self, 47 | in_channels, 48 | out_channels, 49 | kernel_size, 50 | stride, 51 | expand_ratio=1, 52 | se_ratio=0., 53 | drop_rate=0.): 54 | super(Block, self).__init__() 55 | self.stride = stride 56 | self.drop_rate = drop_rate 57 | self.expand_ratio = expand_ratio 58 | 59 | # Expansion 60 | channels = expand_ratio * in_channels 61 | self.conv1 = nn.Conv2d(in_channels, 62 | channels, 63 | kernel_size=1, 64 | stride=1, 65 | padding=0, 66 | bias=False) 67 | self.bn1 = nn.BatchNorm2d(channels) 68 | 69 | # Depthwise conv 70 | self.conv2 = nn.Conv2d(channels, 71 | channels, 72 | kernel_size=kernel_size, 73 | stride=stride, 74 | padding=(1 if kernel_size == 3 else 2), 75 | groups=channels, 76 | bias=False) 77 | self.bn2 = nn.BatchNorm2d(channels) 78 | 79 | # SE layers 80 | se_channels = int(in_channels * se_ratio) 81 | self.se = SE(channels, se_channels) 82 | 83 | # Output 84 | self.conv3 = nn.Conv2d(channels, 85 | out_channels, 86 | kernel_size=1, 87 | stride=1, 88 | padding=0, 89 | bias=False) 90 | self.bn3 = nn.BatchNorm2d(out_channels) 91 | 92 | # Skip connection if in and out shapes are the same (MV-V2 style) 93 | self.has_skip = (stride == 1) and (in_channels == out_channels) 94 | 95 | def forward(self, x): 96 | out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x))) 97 | out = swish(self.bn2(self.conv2(out))) 98 | out = self.se(out) 99 | out = self.bn3(self.conv3(out)) 100 | if self.has_skip: 101 | if self.training and self.drop_rate > 0: 102 | out = drop_connect(out, self.drop_rate) 103 | out = out + x 104 | return out 105 | 106 | 107 | class EfficientNet(nn.Module): 108 | def __init__(self, cfg, num_classes=10): 109 | super(EfficientNet, self).__init__() 110 | self.cfg = cfg 111 | self.conv1 = nn.Conv2d(3, 112 | 32, 113 | kernel_size=3, 114 | stride=1, 115 | padding=1, 116 | bias=False) 117 | 
self.bn1 = nn.BatchNorm2d(32) 118 | self.layers = self._make_layers(in_channels=32) 119 | self.linear = nn.Linear(cfg['out_channels'][-1], num_classes) 120 | 121 | def _make_layers(self, in_channels): 122 | layers = [] 123 | cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size', 124 | 'stride']] 125 | b = 0 126 | blocks = sum(self.cfg['num_blocks']) 127 | for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg): 128 | strides = [stride] + [1] * (num_blocks - 1) 129 | for stride in strides: 130 | drop_rate = self.cfg['drop_connect_rate'] * b / blocks 131 | layers.append( 132 | Block(in_channels, 133 | out_channels, 134 | kernel_size, 135 | stride, 136 | expansion, 137 | se_ratio=0.25, 138 | drop_rate=drop_rate)) 139 | in_channels = out_channels 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | out = swish(self.bn1(self.conv1(x))) 144 | out = self.layers(out) 145 | out = F.adaptive_avg_pool2d(out, 1) 146 | out = out.view(out.size(0), -1) 147 | dropout_rate = self.cfg['dropout_rate'] 148 | if self.training and dropout_rate > 0: 149 | out = F.dropout(out, p=dropout_rate) 150 | out = self.linear(out) 151 | return out 152 | 153 | 154 | def EfficientNetB0(): 155 | cfg = { 156 | 'num_blocks': [1, 2, 2, 3, 3, 4, 1], 157 | 'expansion': [1, 6, 6, 6, 6, 6, 6], 158 | 'out_channels': [16, 24, 40, 80, 112, 192, 320], 159 | 'kernel_size': [3, 3, 5, 3, 5, 5, 3], 160 | 'stride': [1, 2, 2, 2, 1, 2, 1], 161 | 'dropout_rate': 0.2, 162 | 'drop_connect_rate': 0.2, 163 | } 164 | return EfficientNet(cfg) 165 | 166 | 167 | def test(): 168 | net = EfficientNetB0() 169 | x = torch.randn(2, 3, 32, 32) 170 | y = net(x) 171 | print(y.shape) 172 | 173 | 174 | if __name__ == '__main__': 175 | test() 176 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, 
self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 
2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 | # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | 
out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 
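        # With 32x32 CIFAR-style inputs the feature-map sizes are 32, 32, 16,
        # 8 and 4 after conv1 and the four stages, so the F.avg_pool2d(out, 4)
        # in forward() leaves a single 512*expansion-dimensional feature per
        # image for the linear readout.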
74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/regnet.py: -------------------------------------------------------------------------------- 1 | '''RegNet in PyTorch. 2 | 3 | Paper: "Designing Network Design Spaces". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class SE(nn.Module): 13 | '''Squeeze-and-Excitation block.''' 14 | 15 | def __init__(self, in_planes, se_planes): 16 | super(SE, self).__init__() 17 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True) 18 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True) 19 | 20 | def forward(self, x): 21 | out = F.adaptive_avg_pool2d(x, (1, 1)) 22 | out = F.relu(self.se1(out)) 23 | out = self.se2(out).sigmoid() 24 | out = x * out 25 | return out 26 | 27 | 28 | class Block(nn.Module): 29 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio): 30 | super(Block, self).__init__() 31 | # 1x1 32 | w_b = int(round(w_out * bottleneck_ratio)) 33 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(w_b) 35 | # 3x3 36 | num_groups = w_b // group_width 37 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3, 38 | stride=stride, padding=1, groups=num_groups, bias=False) 39 | self.bn2 = nn.BatchNorm2d(w_b) 40 | # se 41 | self.with_se = se_ratio > 0 42 | if self.with_se: 43 | w_se = int(round(w_in * se_ratio)) 44 | self.se = SE(w_b, w_se) 45 | # 1x1 46 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(w_out) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or w_in != w_out: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(w_in, w_out, 53 | kernel_size=1, stride=stride, bias=False), 54 | nn.BatchNorm2d(w_out) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | if self.with_se: 61 | out = self.se(out) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | 
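# The RegNet body is assembled purely from its `cfg` dict: stage i stacks
# cfg['depths'][i] Blocks of width cfg['widths'][i], and only the first block
# of a stage applies cfg['strides'][i]. For instance, the RegNetX_200MF
# config defined near the bottom of this file gives 1+1+4+7 = 13 blocks and a
# final width of 368, which is the input size of the Linear head.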
class RegNet(nn.Module): 69 | def __init__(self, cfg, num_classes=10): 70 | super(RegNet, self).__init__() 71 | self.cfg = cfg 72 | self.in_planes = 64 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 74 | stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(0) 77 | self.layer2 = self._make_layer(1) 78 | self.layer3 = self._make_layer(2) 79 | self.layer4 = self._make_layer(3) 80 | self.linear = nn.Linear(self.cfg['widths'][-1], num_classes) 81 | 82 | def _make_layer(self, idx): 83 | depth = self.cfg['depths'][idx] 84 | width = self.cfg['widths'][idx] 85 | stride = self.cfg['strides'][idx] 86 | group_width = self.cfg['group_width'] 87 | bottleneck_ratio = self.cfg['bottleneck_ratio'] 88 | se_ratio = self.cfg['se_ratio'] 89 | 90 | layers = [] 91 | for i in range(depth): 92 | s = stride if i == 0 else 1 93 | layers.append(Block(self.in_planes, width, 94 | s, group_width, bottleneck_ratio, se_ratio)) 95 | self.in_planes = width 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | out = F.relu(self.bn1(self.conv1(x))) 100 | out = self.layer1(out) 101 | out = self.layer2(out) 102 | out = self.layer3(out) 103 | out = self.layer4(out) 104 | out = F.adaptive_avg_pool2d(out, (1, 1)) 105 | out = out.view(out.size(0), -1) 106 | out = self.linear(out) 107 | return out 108 | 109 | 110 | def RegNetX_200MF(): 111 | cfg = { 112 | 'depths': [1, 1, 4, 7], 113 | 'widths': [24, 56, 152, 368], 114 | 'strides': [1, 1, 2, 2], 115 | 'group_width': 8, 116 | 'bottleneck_ratio': 1, 117 | 'se_ratio': 0, 118 | } 119 | return RegNet(cfg) 120 | 121 | 122 | def RegNetX_400MF(): 123 | cfg = { 124 | 'depths': [1, 2, 7, 12], 125 | 'widths': [32, 64, 160, 384], 126 | 'strides': [1, 1, 2, 2], 127 | 'group_width': 16, 128 | 'bottleneck_ratio': 1, 129 | 'se_ratio': 0, 130 | } 131 | return RegNet(cfg) 132 | 133 | 134 | def RegNetY_400MF(): 135 | cfg = { 136 | 'depths': [1, 2, 7, 12], 137 | 'widths': [32, 64, 160, 384], 138 | 'strides': [1, 1, 2, 2], 139 | 'group_width': 16, 140 | 'bottleneck_ratio': 1, 141 | 'se_ratio': 0.25, 142 | } 143 | return RegNet(cfg) 144 | 145 | 146 | def test(): 147 | net = RegNetX_200MF() 148 | print(net) 149 | x = torch.randn(2, 3, 32, 32) 150 | y = net(x) 151 | print(y.shape) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion*planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=200): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 79 | stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.linear = nn.Linear(512*block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = nn.AdaptiveAvgPool2d(1)(out) 102 | out = out.view(out.size(0), -1) 103 | out = self.linear(out) 104 | return out 105 | 106 | 107 | def ResNet18(): 108 | return ResNet(BasicBlock, [2, 2, 2, 2]) 109 | 110 | 111 | def ResNet34(): 112 | return ResNet(BasicBlock, [3, 4, 6, 3]) 113 | 114 | 115 | def ResNet50(): 116 | return 
ResNet(Bottleneck, [3, 4, 6, 3]) 117 | 118 | 119 | def ResNet101(): 120 | return ResNet(Bottleneck, [3, 4, 23, 3]) 121 | 122 | 123 | def ResNet152(): 124 | return ResNet(Bottleneck, [3, 8, 36, 3]) 125 | 126 | 127 | def test(): 128 | net = ResNet18() 129 | y = net(torch.randn(1, 3, 32, 32)) 130 | print(y.size()) 131 | 132 | # test() 133 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def _make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 
62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]''' 17 | N,C,H,W = x.size() 18 | g = self.groups 19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride, groups): 24 | super(Bottleneck, self).__init__() 25 | self.stride = stride 26 | 27 | mid_planes = out_planes//4 # integer division: channel and group counts must be ints in Python 3 28 | g = 1 if in_planes==24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 48 | return out 49 | 50 | 51 | class ShuffleNet(nn.Module): 52 | def __init__(self, cfg): 53 | super(ShuffleNet, self).__init__() 54 | out_planes = cfg['out_planes'] 55 | num_blocks = cfg['num_blocks'] 56 | groups = cfg['groups'] 57 | 58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': [240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/shufflenetv2.py:
-------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | 3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5): 34 | super(BasicBlock, self).__init__() 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | out = F.relu(self.bn3(self.conv3(out))) 53 | out = torch.cat([x1, out], 1) 54 | out = self.shuffle(out) 55 | return out 56 | 57 | 58 | class DownBlock(nn.Module): 59 | def __init__(self, in_channels, out_channels): 60 | super(DownBlock, self).__init__() 61 | mid_channels = out_channels // 2 62 | # left 63 | self.conv1 = nn.Conv2d(in_channels, in_channels, 64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 65 | self.bn1 = nn.BatchNorm2d(in_channels) 66 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 67 | kernel_size=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(mid_channels) 69 | # right 70 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 71 | kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(mid_channels) 73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 75 | self.bn4 = nn.BatchNorm2d(mid_channels) 76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn5 = nn.BatchNorm2d(mid_channels) 79 | 80 | self.shuffle = ShuffleBlock() 81 | 82 | def forward(self, x): 83 | # left 84 | out1 = self.bn1(self.conv1(x)) 85 | out1 = F.relu(self.bn2(self.conv2(out1))) 86 | # right 87 | out2 = F.relu(self.bn3(self.conv3(x))) 88 | out2 = self.bn4(self.conv4(out2)) 89 | out2 = F.relu(self.bn5(self.conv5(out2))) 90 | # concat 91 | out = torch.cat([out1, out2], 1) 92 | out = self.shuffle(out) 93 | return out 94 | 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__(self, net_size): 98 | super(ShuffleNetV2, self).__init__() 99 | out_channels = configs[net_size]['out_channels'] 100 | num_blocks = configs[net_size]['num_blocks'] 101 | 102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 
103 | stride=1, padding=1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(24) 105 | self.in_channels = 24 106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 110 | kernel_size=1, stride=1, padding=0, bias=False) 111 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 112 | self.linear = nn.Linear(out_channels[3], 10) 113 | 114 | def _make_layer(self, out_channels, num_blocks): 115 | layers = [DownBlock(self.in_channels, out_channels)] 116 | for i in range(num_blocks): 117 | layers.append(BasicBlock(out_channels)) 118 | self.in_channels = out_channels 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 124 | out = self.layer1(out) 125 | out = self.layer2(out) 126 | out = self.layer3(out) 127 | out = F.relu(self.bn2(self.conv2(out))) 128 | out = F.avg_pool2d(out, 4) 129 | out = out.view(out.size(0), -1) 130 | out = self.linear(out) 131 | return out 132 | 133 | 134 | configs = { 135 | 0.5: { 136 | 'out_channels': (48, 96, 192, 1024), 137 | 'num_blocks': (3, 7, 3) 138 | }, 139 | 140 | 1: { 141 | 'out_channels': (116, 232, 464, 1024), 142 | 'num_blocks': (3, 7, 3) 143 | }, 144 | 1.5: { 145 | 'out_channels': (176, 352, 704, 1024), 146 | 'num_blocks': (3, 7, 3) 147 | }, 148 | 2: { 149 | 'out_channels': (224, 488, 976, 2048), 150 | 'num_blocks': (3, 7, 3) 151 | } 152 | } 153 | 154 | 155 | def test(): 156 | net = ShuffleNetV2(net_size=0.5) 157 | x = torch.randn(3, 3, 32, 32) 158 | y = net(x) 159 | print(y.shape) 160 | 161 | 162 | # test() 163 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /Supplements/Fig_S5/pytorchtools.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class EarlyStopping: 5 | """Early stops the training if validation loss doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How long to wait after last time validation loss improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each validation loss improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to. 16 | Default: 'checkpoint.pt' 17 | trace_func (function): trace print function. 18 | Default: print 19 | """ 20 | self.patience = patience 21 | self.verbose = verbose 22 | self.counter = 0 23 | self.best_score = None 24 | self.early_stop = False 25 | self.val_loss_min = np.Inf 26 | self.delta = delta 27 | self.path = path 28 | self.trace_func = trace_func 29 | def __call__(self, val_loss, model): 30 | 31 | score = -val_loss 32 | 33 | if self.best_score is None: 34 | self.best_score = score 35 | self.save_checkpoint(val_loss, model) 36 | elif score < self.best_score + self.delta: 37 | self.counter += 1 38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 39 | if self.counter >= self.patience: 40 | self.early_stop = True 41 | else: 42 | self.best_score = score 43 | self.save_checkpoint(val_loss, model) 44 | self.counter = 0 45 | 46 | def save_checkpoint(self, val_loss, model): 47 | '''Saves model when validation loss decrease.''' 48 | if self.verbose: 49 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 50 | torch.save(model.state_dict(), self.path) 51 | self.val_loss_min = val_loss 52 | -------------------------------------------------------------------------------- /Supplements/Fig_S6i_j.m: -------------------------------------------------------------------------------- 1 | close all 2 | clear all 3 | 4 | nepoch = 2000; 5 | learnrate = 0.005; 6 | N_x_t = 100; 7 | N_y_t = 1; 8 | P=100; 9 | M = 5000; % num of units in notebook 10 | 11 | Et_big_early_stop = []; 12 | Eg_big_early_stop = []; 13 | Et_big = []; 14 | Eg_big = []; 15 | 16 | Fig3j_data = zeros(2,2); 17 | 18 | Fig3k_data = zeros(2,2); 19 | 20 | 21 | SNR_vec = [0.6 1000]; 22 | 23 | 24 | Notebook_Train = (P-1)/(M-1); % Analytical solution for notebook training error 25 | 26 | for count = 1:size(SNR_vec,2) 27 | SNR = SNR_vec(count); 28 | % Analytical solution 29 | Eg = []; 30 | Et = []; 31 | lr = learnrate; 32 | 33 | % use normalized variances 34 | if SNR == inf 35 | variance_w = 1; 36 | variance_e = 0; 37 | else 38 | variance_w = SNR/(SNR + 1); 39 | variance_e = 1/(SNR + 1); 40 | end 41 | 42 | alpha = P/N_x_t; % number of examples divided by input dimension 43 | 44 | % Analytical curves for training and testing errors 45 | for t = 1:1:nepoch 46 | 47 | % analytical integral for the training error 48 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ; 49 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e]; 50 | 51 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR)); 52 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)]; 53 | 54 | end 55 | 56 | % Early stopping curves 57 | [mm, pp] = min(Eg); 58 | Eg_early_stop = Eg; 59 | Eg_early_stop(pp+1:end) = Eg(pp); 60 | 61 | Et_early_stop = Et; 62 | Et_early_stop(pp+1:end) = Et(pp); 63 | 64 | % Take the lower of the student and notebook training errors, then compute 65 | % memory and generalization scores 66 | better_train_no_early_stop = min(Et,ones(1,nepoch)*Notebook_Train); 67 | control_curve = (Et(1) - better_train_no_early_stop)/Et(1); 68 | lesion_curve = (Et(1) - Et)/Et(1); 69 | 70 | 71 | better_train_yes_early_stop = min(Et_early_stop,ones(1,nepoch)*Notebook_Train); 72 | control_curve_early_stop = (Et_early_stop(1) - better_train_yes_early_stop)/Et_early_stop(1); 73 | lesion_curve_early_stop = (Et_early_stop(1) - Et_early_stop)/Et_early_stop(1); 74 | 75 | 76 | control_Eg_curve = (Eg(1) - Eg)/Eg(1); 77 | control_Eg_curve_early_stop = (Eg_early_stop(1) - Eg_early_stop)/Eg_early_stop(1); 78 | 79 | Fig3j_data(count,1) = control_curve_early_stop(end); 80 | Fig3j_data(count,2) = control_Eg_curve_early_stop(end); 81 | 82 | Fig3k_data(count,1) = control_curve_early_stop(end); 83 | Fig3k_data(count,2) = lesion_curve_early_stop(end); 84 | 85 | end 86 | 87 | 88 | figure(1) 89 | 90 | x = [[0.7,1.3];...
91 | [2.7, 3.3]]; 92 | 93 | data = Fig3j_data; 94 | 95 | 96 | f=bar(x,data*100); 97 | f(1).BarWidth = 3.2; 98 | xlim([0 4]) 99 | ylim([0 120]) 100 | hold on 101 | 102 | ax = gca; 103 | ax.XTick = [1 3]; 104 | ax.YTick = [0 20 40 60 80 100]; 105 | xax = ax.XAxis; 106 | set(xax,'TickDirection','out') 107 | set(gca,'box','off') 108 | set(gcf,'position',[600,400,340,210]) 109 | set(gca, 'FontSize', 16) 110 | 111 | 112 | ax.XTickLabel=[{'Discriminator'} {'Generalizer'}]; 113 | 114 | 115 | figure(2) 116 | 117 | x = [[0.7,1.3];... 118 | [2.7, 3.3]]; 119 | 120 | data = Fig3k_data; 121 | 122 | 123 | f=bar(x,data*100); 124 | f(1).BarWidth = 3.2; 125 | xlim([0 4]) 126 | ylim([0 120]) 127 | hold on 128 | 129 | ax = gca; 130 | ax.XTick = [1 3]; 131 | ax.YTick = [0 20 40 60 80 100]; 132 | 133 | xax = ax.XAxis; 134 | set(xax,'TickDirection','out') 135 | set(gca,'box','off') 136 | set(gcf,'position',[600,100,340,210]) 137 | set(gca, 'FontSize', 16) 138 | 139 | 140 | ax.XTickLabel=[{'Discriminator'} {'Generalizer'}]; 141 | -------------------------------------------------------------------------------- /Supplements/Fig_S6l.m: -------------------------------------------------------------------------------- 1 | % code for Student-Teacher-Notebook framework,simulating Sweegers 2 | % et al., 2014 3 | tic 4 | close all 5 | clear all 6 | 7 | r_n = 3; % number of repeats 8 | nepoch = 2000; 9 | learnrate = 0.015; 10 | N_x_t = 100; % teacher input dimension 11 | N_y_t = 1; % teacher output dimension 12 | P=100; % number of training examples 13 | P_test = 1000; % number of testing examples 14 | 15 | variance_s = 0.5; % student weight variance at initialization 16 | 17 | % According to SNR, set variances for teacher's weights (variance_w) and output noise (variance_e) that sum to 1 18 | SNR_vec = [0.05 2]; % set values to 0.05 and 2 to reproduce fig. 
3n 19 | W_s_norm = zeros(size(SNR_vec,2),r_n,nepoch); 20 | 21 | for SNR_counter=1:size(SNR_vec,2) 22 | SNR = SNR_vec(SNR_counter); 23 | 24 | if SNR == inf 25 | variance_w = 1; 26 | variance_e = 0; 27 | else 28 | variance_w = SNR/(SNR + 1); 29 | variance_e = 1/(SNR + 1); 30 | end 31 | 32 | % Student and teacher share the same dimensions 33 | N_x_s = N_x_t; 34 | N_y_s = N_y_t; 35 | 36 | % Notebook parameters 37 | % see Buhmann, Divko, and Schulten, 1989 for details regarding gamma and U terms 38 | 39 | M = 2000; % num of units in notebook 40 | a = 0.05; % notebook sparseness 41 | gamma = 0.6; % inhibition parameter 42 | U = -0.15; % threshold for unit activation 43 | ncycle = 9; % number of recurrent cycles 44 | 45 | 46 | % Matrices for storing train error, test error, reactivation error (driven by notebook) 47 | % Without early stopping 48 | train_error_all = zeros(r_n,nepoch); 49 | test_error_all = zeros(r_n,nepoch); 50 | N_train_error_all = zeros(r_n,nepoch); 51 | N_test_error_all = zeros(r_n,nepoch); 52 | 53 | % With early stopping 54 | train_error_early_stop_all = zeros(r_n,nepoch); 55 | test_error_early_stop_all = zeros(r_n,nepoch); 56 | 57 | 58 | %Run simulation for r_n times 59 | for r = 1:r_n 60 | disp(r) 61 | rng(r); %set random seed for reproducibility 62 | 63 | %Errors 64 | error_train_vector = zeros(nepoch,1); 65 | error_test_vector = zeros(nepoch,1); 66 | error_react_vector = zeros(nepoch,1); 67 | 68 | %% Teacher Network 69 | W_t = normrnd(0,variance_w^0.5,[N_x_t,N_y_t]);% set teacher's weights 70 | noise_train = normrnd(0,variance_e^0.5,[P,N_y_t]); 71 | % Training data 72 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); % inputs 73 | y_t_output = x_t_input*W_t + noise_train; % outputs 74 | 75 | % Testing data 76 | noise_test = normrnd(0,variance_e^0.5,[P_test,N_y_t]); 77 | x_t_input_test = normrnd(0,(1/N_x_t)^0.5,[P_test,N_x_t]); 78 | y_t_output_test = x_t_input_test*W_t + noise_test; 79 | 80 | %% Notebook Network 81 | % Generate P random binary indices with sparseness a 82 | N_patterns = zeros(P,M); 83 | for n=1:P 84 | N_patterns(n,randperm(M,M*a))=1; 85 | end 86 | 87 | %Hebbian learning for notebook recurrent weights 88 | W_N = (N_patterns - a)'*(N_patterns - a)/(M*a*(1-a)); 89 | W_N = W_N - gamma/(a*M);% add global inhibition term, see Buhmann, Divko, and Schulten, 1989 90 | W_N = W_N.*~eye(size(W_N)); % set diagonal weights to zero 91 | 92 | % Hebbian learning for Notebook-Student weights (bidirectional) 93 | % Notebook to student weights 94 | W_N_S_Lin = (N_patterns-a)'*x_t_input/(M*a*(1-a)); 95 | W_N_S_Lout = (N_patterns-a)'*y_t_output/(M*a*(1-a)); 96 | % Student to notebook weights 97 | W_S_N_Lin = x_t_input'*(N_patterns-a); 98 | W_S_N_Lout = y_t_output'*(N_patterns-a); 99 | 100 | %% Student Network 101 | W_s = normrnd(0,variance_s^0.5,[N_x_s,N_y_s]); % set student's initial weights 102 | 103 | %% Generate offline training data from notebook reactivations 104 | N_patterns_reactivated = zeros(P,M,nepoch,'logical'); % array for storing retrieved notebook patterns, pre-calculating all epochs for speed considerations 105 | 106 | parfor m = 1:nepoch 107 | %for m = 1:nepoch % change to regular for-loop without multiple cores 108 | 109 | %% Notebook pattern completion through recurrent dynamics 110 | % Code below simulates hippocampal offline spontaneous 111 | % reactivations by seeding the initial notebook state with a random 112 | % binary pattern, then the notebook goes through a two-step retrieval 113 | % process: (1) Retrieving a pattern using dynamic threshold to 114 | % ensure
a pattern with sparseness a is retrieved. (2) Using the 115 | % retrieved pattern from (1) to seed a second round of pattern 116 | % completion using a fixed-threshold method (along with a global 117 | % inhibition term during encoding), so the retrieved patterns are 118 | % not forced to have a fixed sparseness; in addition, there is a 119 | % "silent state" attractor when the seeding pattern lies far away 120 | % from any of the encoded patterns. 121 | 122 | % Start recurrent cycles with dynamic threshold 123 | Activity_dyn_t = zeros(P, M); 124 | 125 | % First round of pattern completion through recurrent activation cycles given 126 | % random initial input. 127 | for cycle = 1:ncycle 128 | if cycle <=1 129 | clamp = 1; 130 | else 131 | clamp = 0; 132 | end 133 | rand_patt = (rand(P,M)<=a); 134 | % Seeding notebook with random patterns 135 | M_input = Activity_dyn_t + (rand_patt*clamp); 136 | % Seeding notebook with original patterns 137 | %M_input = Activity_dyn_t + (N_patterns*clamp); 138 | M_current = M_input*W_N; 139 | % scale currents between 0 and 1 140 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2)); 141 | M_current = (M_current - min(M_current,[],2)) .* scale; 142 | % find threshold based on desired sparseness 143 | sorted_M_current = sort(M_current,2,'descend'); 144 | t_ind = floor(size(Activity_dyn_t,2) * a); 145 | t_ind(t_ind<1) = 1; 146 | t = sorted_M_current(:,t_ind); % threshold for unit activations 147 | Activity_dyn_t = (M_current >=t); 148 | end 149 | 150 | % Second round of pattern completion, with fixed threshold 151 | Activity_fix_t = zeros(P, M); 152 | for cycle = 1:ncycle 153 | if cycle <=1 154 | clamp = 1; 155 | else 156 | clamp = 0; 157 | end 158 | M_input = Activity_fix_t + Activity_dyn_t*clamp; 159 | M_current = M_input*W_N; 160 | Activity_fix_t = (M_current >= U); % U is the fixed threshold 161 | end 162 | N_patterns_reactivated(:,:,m)=Activity_fix_t; 163 | end 164 | 165 | %% Seeding the notebook with the training inputs to calculate the 166 | % training error mediated by the notebook: the student's inputs cue the 167 | % notebook via the Student-to-Notebook weights; once pattern completion 168 | % finishes, the retrieved pattern drives the student's output unit 169 | % via the Notebook-to-Student output weights.
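% The dynamic-threshold retrievals above and below implement a k-winners-take-all
% rule: on each cycle the threshold is set to the k-th largest input current,
% with k = floor(M*a), so a fraction a of the notebook units becomes active
% (with M = 2000 and a = 0.05 as set above, that is k = 100 active units per pattern).
% The second, fixed-threshold pass in the reactivation loop above instead
% compares currents to the constant U, which is what permits the near-silent
% state when a seed lies far from every stored pattern.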
170 | 171 | Activity_notebook_train = zeros(P, M); 172 | for cycle = 1:ncycle 173 | if cycle <=1 174 | clamp = 1; 175 | else 176 | clamp = 0; 177 | end 178 | seed_patt = x_t_input*W_S_N_Lin; 179 | M_input = Activity_notebook_train + (seed_patt*clamp); 180 | M_current = M_input*W_N; 181 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2)); 182 | M_current = (M_current - min(M_current,[],2)) .* scale; 183 | sorted_M_current = sort(M_current,2,'descend'); 184 | t_ind = floor(size(Activity_notebook_train,2) * a); 185 | t_ind(t_ind<1) = 1; 186 | t = sorted_M_current(:,t_ind); 187 | Activity_notebook_train = (M_current >=t); 188 | end 189 | N_S_output_train = Activity_notebook_train*W_N_S_Lout; 190 | % Notebook training error 191 | delta_N_train = y_t_output - N_S_output_train; 192 | error_N_train = sum(delta_N_train.^2)/P; 193 | % Since notebook errors stay constant throughout training, 194 | % populating each epoch with the same value 195 | error_N_train_vector = ones(nepoch,1)*error_N_train; 196 | N_train_error_all(r,:) = error_N_train_vector; 197 | 198 | % Notebook generalization error 199 | Activity_notebook_test = zeros(P_test, M); 200 | for cycle = 1:ncycle 201 | if cycle <=1 202 | clamp = 1; 203 | else 204 | clamp = 0; 205 | end 206 | seed_patt = x_t_input_test*W_S_N_Lin; 207 | M_input = Activity_notebook_test + (seed_patt*clamp); 208 | M_current = M_input*W_N; 209 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2)); 210 | M_current = (M_current - min(M_current,[],2)) .* scale; 211 | sorted_M_current = sort(M_current,2,'descend'); 212 | t_ind = floor(size(Activity_notebook_test,2) * a); 213 | t_ind(t_ind<1) = 1; 214 | t = sorted_M_current(:,t_ind); 215 | Activity_notebook_test = (M_current >=t); 216 | end 217 | N_S_output_test = Activity_notebook_test*W_N_S_Lout; 218 | % Notebook test error 219 | delta_N_test = y_t_output_test - N_S_output_test; 220 | error_N_test = sum(delta_N_test.^2)/P_test; 221 | % Since notebook errors stay constant throughout training, 222 | % populating each epoch with the same value 223 | error_N_test_vector = ones(nepoch,1)*error_N_test; 224 | N_test_error_all(r,:) = error_N_test_vector; 225 | 226 | 227 | N_patterns_reactivated_test = zeros(P_test,M,'logical'); 228 | %% Student training 229 | for m = 1:nepoch %batch training starts 230 | 231 | W_s_norm(SNR_counter,r,m) = norm(W_s,2); 232 | 233 | N_S_input = N_patterns_reactivated(:,:,m)*W_N_S_Lin; % notebook reactivated student input activity 234 | N_S_output = N_patterns_reactivated(:,:,m)*W_N_S_Lout; % notebook reactivated student output activity 235 | N_S_prediction = N_S_input*W_s; % student output prediction calculated by notebook reactivated input and student weights 236 | S_prediction = x_t_input*W_s; % student output prediction calculated by true training inputs and student weights 237 | S_prediction_test = x_t_input_test*W_s; % student output prediction calculated by true testing inputs and student weights 238 | 239 | % Train error 240 | delta_train = y_t_output - S_prediction; 241 | error_train = sum(delta_train.^2)/P; 242 | error_train_vector(m) = error_train; 243 | 244 | % Generalization error 245 | delta_test = y_t_output_test - S_prediction_test; 246 | error_test = sum(delta_test.^2)/P_test; 247 | error_test_vector(m) = error_test; 248 | 249 | % Gradient descent 250 | w_delta = N_S_input'*N_S_output - N_S_input'*N_S_input*W_s; 251 | W_s = W_s + learnrate*w_delta; 252 | end 253 | 254 | train_error_all(r,:) = error_train_vector; 255 | test_error_all(r,:) = error_test_vector; 256 | 
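% Note on the learning rule in the epoch loop above: the batch update
%   w_delta = N_S_input'*N_S_output - N_S_input'*N_S_input*W_s
% is the negative gradient of the squared reactivation error
%   E(W_s) = 0.5*norm(N_S_output - N_S_input*W_s,'fro')^2,
% since dE/dW_s = -N_S_input'*(N_S_output - N_S_input*W_s). The student is
% therefore trained only on notebook-reactivated input/output pairs, while
% error_train_vector and error_test_vector are evaluated on the true teacher
% data (x_t_input, y_t_output and x_t_input_test, y_t_output_test).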
257 | % Early stopping 258 | [min_v, min_p] = min(error_test_vector); 259 | train_error_early_stop = error_train_vector; 260 | train_error_early_stop (min_p+1:end) = error_train_vector (min_p); 261 | test_error_early_stop = error_test_vector; 262 | test_error_early_stop (min_p+1:end) = error_test_vector (min_p); 263 | train_error_early_stop_all(r,:) = train_error_early_stop; 264 | test_error_early_stop_all(r,:) = test_error_early_stop; 265 | W_s_norm(SNR_counter,r,min_p+1:end) = W_s_norm(SNR_counter,r,min_p); 266 | end 267 | end 268 | toc 269 | 270 | 271 | 272 | 273 | %Weight norm with early stopping 274 | figure(3) 275 | x = [[0.95,2];... 276 | [1.05, 2]]; 277 | 278 | data = [squeeze(mean(W_s_norm(1,:,[1 2000]),2)/mean(W_s_norm(1,:,1),2))';... 279 | squeeze(mean(W_s_norm(2,:,[1 2000]),2)/mean(W_s_norm(1,:,1),2))']; 280 | 281 | 282 | % errlow = data - err; 283 | % errhigh = data + err; 284 | 285 | f=plot(x',data','o-'); 286 | 287 | xlim([0.5 2.4]) 288 | % ylim([-0.2 1.6]) 289 | hold on 290 | 291 | 292 | ax = gca; 293 | ax.XTick = [1 2]; 294 | %ax.YTick = [0 0.4 0.8 1.2 1.6]; 295 | 296 | xax = ax.XAxis; 297 | set(xax,'TickDirection','out') 298 | set(gca,'box','off') 299 | set(gcf,'position',[600,100, 270,300]) 300 | set(gca, 'FontSize', 16) 301 | ax.XTickLabel=[{'Recent'} {'Remote'}]; 302 | ylabel('Connectivity strength') 303 | 304 | 305 | 306 | 307 | 308 | --------------------------------------------------------------------------------
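A minimal usage sketch for the EarlyStopping helper defined in /Supplements/Fig_S5/pytorchtools.py. The tiny linear model, random tensors, and hyperparameters below are illustrative placeholders (they are not taken from the Fig_S5 notebooks); only the call pattern comes from the class itself: construct it once, call it with each epoch's validation loss, stop when early_stop flips, then reload checkpoint.pt.

import torch
import torch.nn as nn
from pytorchtools import EarlyStopping  # /Supplements/Fig_S5/pytorchtools.py

# Toy regression set-up; any model/optimizer/loss with a validation pass works the same way.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
x_train, y_train = torch.randn(64, 10), torch.randn(64, 1)
x_val, y_val = torch.randn(32, 10), torch.randn(32, 1)

early_stopping = EarlyStopping(patience=7, verbose=True)  # saves 'checkpoint.pt' on each improvement

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(x_val), y_val).item()

    early_stopping(val_loss, model)  # increments its counter; sets early_stop after `patience` bad epochs
    if early_stopping.early_stop:
        break

model.load_state_dict(torch.load('checkpoint.pt'))  # restore the weights with the lowest validation loss

The class only tracks the validation loss and writes checkpoints; breaking out of the training loop is left to the caller, as shown above.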