├── Fig_2
│   ├── Fig_2a_to_f.m
│   ├── Fig_2g_h.m
│   ├── Fig_2i_j.ipynb
│   ├── Fig_2k_l.m
│   ├── Fig_2m_n.ipynb
│   └── task
│       ├── TSN.py
│       └── __init__.py
├── Fig_3
│   ├── Fig_3a_to_d_f.m
│   └── Fig_3e.m
├── Fig_4
│   └── Fig_4d_to_g.m
├── Fig_5
│   └── Fig_5_e_f.m
├── README.md
└── Supplements
    ├── Fig_S1_KNN.m
    ├── Fig_S2_SNR_alpha.m
    ├── Fig_S3_correlated_biased.ipynb
    ├── Fig_S4
    │   ├── Fig_MLE_Speed_VS_GroundT.m
    │   ├── N_100.mat
    │   └── N_10000.mat
    ├── Fig_S5
    │   ├── CIFAR-10.ipynb
    │   ├── MNIST.ipynb
    │   ├── Tiny_imagenet.ipynb
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── densenet.py
    │   │   ├── dla.py
    │   │   ├── dla_simple.py
    │   │   ├── dpn.py
    │   │   ├── efficientnet.py
    │   │   ├── googlenet.py
    │   │   ├── lenet.py
    │   │   ├── mobilenet.py
    │   │   ├── mobilenetv2.py
    │   │   ├── pnasnet.py
    │   │   ├── preact_resnet.py
    │   │   ├── regnet.py
    │   │   ├── resnet.py
    │   │   ├── resnext.py
    │   │   ├── senet.py
    │   │   ├── shufflenet.py
    │   │   ├── shufflenetv2.py
    │   │   └── vgg.py
    │   └── pytorchtools.py
    ├── Fig_S6i_j.m
    └── Fig_S6l.m
/Fig_2/Fig_2a_to_f.m:
--------------------------------------------------------------------------------
1 | % Main code for Student-Teacher-Notebook framework
2 | % Weinan Sun 10-10-2021
3 |
4 | tic
5 | close all
6 | clear all
7 |
8 | r_n = 20; % number of repeats
9 | nepoch = 2000;
10 | learnrate = 0.015;
11 | N_x_t = 100; % teacher input dimension
12 | N_y_t = 1; % teacher output dimension
13 | P=100; % number of training examples
14 | P_test = 1000; % number of testing examples
15 |
16 | % Based on the SNR, set the teacher's weight variance (variance_w) and output noise variance (variance_e) so that they sum to 1
17 | SNR = inf;
18 |
19 | if SNR == inf
20 | variance_w = 1;
21 | variance_e = 0;
22 | else
23 | variance_w = SNR/(SNR + 1);
24 | variance_e = 1/(SNR + 1);
25 | end
26 |
27 | % Student and teacher share the same dimensions
28 | N_x_s = N_x_t;
29 | N_y_s = N_y_t;
30 |
31 | % Notebook parameters
32 | % see Buhmann, Divko, and Schulten, 1989 for details regarding gamma and U terms
33 |
34 | M = 2000; % num of units in notebook
35 | a = 0.05; % notebook sparseness
36 | gamma = 0.6; % inhibition parameter
37 | U = -0.15; % threshold for unit activation
38 | ncycle = 9; % number of recurrent cycles
39 |
40 |
41 | % Matrices for storing train error, test error, reactivation error (driven by notebook)
42 | % Without early stopping
43 | train_error_all = zeros(r_n,nepoch); % student train error
44 | test_error_all = zeros(r_n,nepoch); % student test error
45 | N_train_error_all = zeros(r_n,nepoch); % notebook train error
46 | N_test_error_all = zeros(r_n,nepoch); % notebook test error
47 |
48 | % With early stopping
49 | train_error_early_stop_all = zeros(r_n,nepoch);
50 | test_error_early_stop_all = zeros(r_n,nepoch);
51 |
52 | %Run simulation for r_n times
53 | for r = 1:r_n
54 | disp(r)
55 | rng(r); %set random seed for reproducibility
56 |
57 | %Errors
58 | error_train_vector = zeros(nepoch,1);
59 | error_test_vector = zeros(nepoch,1);
60 | error_react_vector = zeros(nepoch,1);
61 |
62 | %% Teacher Network
63 | W_t = normrnd(0,variance_w^0.5,[N_x_t,N_y_t]); % set teacher's weights with variance_w
64 | noise_train = normrnd(0,variance_e^0.5,[P,N_y_t]); % set the variance for label noise
65 | % Training data
66 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); % inputs
67 | y_t_output = x_t_input*W_t + noise_train; % outputs
68 |
69 | % Testing data
70 | noise_test = normrnd(0,variance_e^0.5,[P_test,N_y_t]);
71 | x_t_input_test = normrnd(0,(1/N_x_t)^0.5,[P_test,N_x_t]);
72 | y_t_output_test = x_t_input_test*W_t + noise_test;
73 |
74 | %% Notebook Network
75 | % Generate P random binary patterns (0/1 entries) with sparseness a
76 | N_patterns = zeros(P,M);
77 | for n=1:P
78 | N_patterns(n,randperm(M,M*a))=1;
79 | end
80 |
81 | %Hebbian learning for notebook recurrent weights
82 | W_N = (N_patterns - a)'*(N_patterns - a)/(M*a*(1-a));
83 | W_N = W_N - gamma/(a*M); % add global inhibition term, see Buhmann, Divko, and Schulten, 1989
84 | W_N = W_N.*~eye(size(W_N)); % no self connection
85 |
86 | % Hebbian learning for Notebook-Student weights (bidirectional)
87 |
88 | % Notebook to student weights, for reactivating student
89 | W_N_S_Lin = (N_patterns-a)'*x_t_input/(M*a*(1-a));
90 | W_N_S_Lout = (N_patterns-a)'*y_t_output/(M*a*(1-a));
91 | % Student to notebook weights, for providing partial cues
92 | W_S_N_Lin = x_t_input'*(N_patterns-a);
93 | W_S_N_Lout = y_t_output'*(N_patterns-a);
94 |
95 | %% Student Network
96 | W_s = normrnd(0,0^0.5,[N_x_s,N_y_s]); % set student's weights, with zero weight initialization
97 |
98 | %% Generate offline training data from notebook reactivations
99 | N_patterns_reactivated = zeros(P,M,nepoch,'logical'); % array for storing retrieved notebook patterns, pre-calculating all epochs for speed considerations
100 |
101 | parfor m = 1:nepoch
102 | %for m = 1:nepoch % regular for-loop
103 |
104 | %% Notebook pattern completion through recurrent dynamics
105 | % The code below simulates hippocampal offline spontaneous
106 | % reactivations by seeding the initial notebook state with a random
107 | % binary pattern; the notebook then goes through a two-step retrieval
108 | % process: (1) retrieve a pattern using a dynamic threshold to
109 | % ensure a pattern with sparseness a is retrieved (otherwise a silent
110 | % attractor would dominate retrieval); (2) use the
111 | % retrieved pattern from (1) to seed a second round of pattern
112 | % completion using a fixed-threshold method (along with the global
113 | % inhibition term applied during encoding), so the retrieved patterns are
114 | % not forced to have a fixed sparseness; in addition, there is a
115 | % "silent state" attractor when the seeding pattern lies far away
116 | % from any of the encoded patterns.
117 |
118 | % Start recurrent cycles with dynamic threshold
119 | Activity_dyn_t = zeros(P, M);
120 |
121 | % First round of pattern completion through recurrent activation cycles given
122 | % random initial input.
123 | for cycle = 1:ncycle
124 | if cycle <=1
125 | clamp = 1;
126 | else
127 | clamp = 0;
128 | end
129 | rand_patt = (rand(P,M)<=a); % random seeding activity
130 | % Seeding notebook with random patterns
131 | M_input = Activity_dyn_t + (rand_patt*clamp);
132 | % Seeding notebook with original patterns
133 | % M_input = Activity_dyn_t + (N_patterns*clamp);
134 | M_current = M_input*W_N;
135 | % scale currents between 0 and 1
136 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2));
137 | M_current = (M_current - min(M_current,[],2)) .* scale;
138 | % find threshold based on desired sparseness
139 | sorted_M_current = sort(M_current,2,'descend');
140 | t_ind = floor(size(Activity_dyn_t,2) * a);
141 | t_ind(t_ind<1) = 1;
142 | t = sorted_M_current(:,t_ind); % threshold for unit activations
143 | Activity_dyn_t = (M_current >=t);
144 | end
145 |
146 | % Second round of pattern completion, with a fixed threshold
147 | Activity_fix_t = zeros(P, M);
148 | for cycle = 1:ncycle
149 | if cycle <=1
150 | clamp = 1;
151 | else
152 | clamp = 0;
153 | end
154 | M_input = Activity_fix_t + Activity_dyn_t*clamp;
155 | M_current = M_input*W_N;
156 | Activity_fix_t = (M_current >= U); % U is the fixed threshold
157 | end
158 | N_patterns_reactivated(:,:,m)=Activity_fix_t;
159 | end
160 |
161 | % Calculate the training error mediated by the notebook: seed the notebook
162 | % with the student's training inputs via the Student-to-Notebook input
163 | % weights; once pattern completion finishes, use the retrieved pattern to
164 | % activate the Student's output unit via the Notebook-to-Student output
165 | % weights.
166 |
167 | Activity_notebook_train = zeros(P, M);
168 | for cycle = 1:ncycle
169 | if cycle <=1
170 | clamp = 1;
171 | else
172 | clamp = 0;
173 | end
174 | seed_patt = x_t_input*W_S_N_Lin;
175 | M_input = Activity_notebook_train + (seed_patt*clamp);
176 | M_current = M_input*W_N;
177 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2));
178 | M_current = (M_current - min(M_current,[],2)) .* scale;
179 | sorted_M_current = sort(M_current,2,'descend');
180 | t_ind = floor(size(Activity_notebook_train,2) * a);
181 | t_ind(t_ind<1) = 1;
182 | t = sorted_M_current(:,t_ind);
183 | Activity_notebook_train = (M_current >=t);
184 | end
185 | N_S_output_train = Activity_notebook_train*W_N_S_Lout;
186 | % Notebook training error
187 | delta_N_train = y_t_output - N_S_output_train;
188 | error_N_train = sum(delta_N_train.^2)/P;
189 | % Since notebook errors stay constant throughout training,
190 | % populating each epoch with the same error value
191 | error_N_train_vector = ones(nepoch,1)*error_N_train;
192 | N_train_error_all(r,:) = error_N_train_vector;
193 |
194 | % Notebook generalization error
195 | Activity_notebook_test = zeros(P_test, M);
196 | for cycle = 1:ncycle
197 | if cycle <=1
198 | clamp = 1;
199 | else
200 | clamp = 0;
201 | end
202 | seed_patt = x_t_input_test*W_S_N_Lin;
203 | M_input = Activity_notebook_test + (seed_patt*clamp);
204 | M_current = M_input*W_N;
205 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2));
206 | M_current = (M_current - min(M_current,[],2)) .* scale;
207 | sorted_M_current = sort(M_current,2,'descend');
208 | t_ind = floor(size(Activity_notebook_test,2) * a);
209 | t_ind(t_ind<1) = 1;
210 | t = sorted_M_current(:,t_ind);
211 | Activity_notebook_test = (M_current >=t);
212 | end
213 | N_S_output_test = Activity_notebook_test*W_N_S_Lout;
214 | % Notebook test error
215 | delta_N_test = y_t_output_test - N_S_output_test;
216 | error_N_test = sum(delta_N_test.^2)/P_test;
217 | % populating each epoch with the same error value
218 | error_N_test_vector = ones(nepoch,1)*error_N_test;
219 | N_test_error_all(r,:) = error_N_test_vector;
220 |
221 |
222 | % N_patterns_reactivated_test = zeros(P_test,M,'logical');
223 | %% Student training through offline notebook reactivations at each epoch
224 | for m = 1:nepoch
225 |
226 | N_S_input = N_patterns_reactivated(:,:,m)*W_N_S_Lin; % notebook reactivated student input activity
227 | N_S_output = N_patterns_reactivated(:,:,m)*W_N_S_Lout; % notebook reactivated student output activity
228 | N_S_prediction = N_S_input*W_s; % student output prediction calculated by notebook reactivated input and student weights
229 | S_prediction = x_t_input*W_s; % student output prediction calculated by true training inputs and student weights
230 | S_prediction_test = x_t_input_test*W_s; % student output prediction calculated by true testing inputs and student weights
231 |
232 | % Train error
233 | delta_train = y_t_output - S_prediction;
234 | error_train = sum(delta_train.^2)/P;
235 | error_train_vector(m) = error_train;
236 |
237 | % Generalization error
238 | delta_test = y_t_output_test - S_prediction_test;
239 | error_test = sum(delta_test.^2)/P_test;
240 | error_test_vector(m) = error_test;
241 |
242 | % Gradient descent
243 | w_delta = N_S_input'*N_S_output - N_S_input'*N_S_input*W_s;
244 | W_s = W_s + learnrate*w_delta;
245 | end
246 |
247 | train_error_all(r,:) = error_train_vector;
248 | test_error_all(r,:) = error_test_vector;
249 |
250 | % Early stopping
251 | [min_v, min_p] = min(error_test_vector);
252 | train__error_early_stop = error_train_vector;
253 | train__error_early_stop (min_p+1:end) = error_train_vector (min_p);
254 | test_error_early_stop = error_test_vector;
255 | test_error_early_stop (min_p+1:end) = error_test_vector (min_p);
256 | train_error_early_stop_all(r,:) = train__error_early_stop;
257 | test_error_early_stop_all(r,:) = test_error_early_stop;
258 | end
259 |
260 | toc
261 |
262 |
263 | %% Plotting
264 | color_scheme = [137 152 193; 245 143 136]/255;
265 | line_w = 2;
266 | font_s = 12;
267 |
268 | % Without early stopping
269 | figure(1)
270 | hold on
271 | plot(1:nepoch,mean(train_error_all),'color',color_scheme(1,:),'LineWidth',line_w)
272 | plot(1:nepoch,mean(test_error_all),'color',color_scheme(2,:),'LineWidth',line_w)
273 | plot(1:nepoch,mean(N_train_error_all),'b--','LineWidth',2)
274 | plot(1:nepoch,mean(N_test_error_all),'r--','LineWidth',2)
275 |
276 | set(gca, 'FontSize', font_s)
277 | set(gca, 'FontSize', font_s)
278 | xlabel('Epoch','Color','k')
279 | ylabel('Error','Color','k')
280 | xlim([0 nepoch])
281 | ylim([0 2])
282 | set(gca,'linewidth',1)
283 | set(gcf,'position',[100,100,350,300])
284 |
285 | % Save plot
286 | % saveas(gcf,strcat('Errors_No_ES_','SNR_',num2str(SNR),'_',date,'.pdf'));
287 |
288 | % With early stopping
289 | figure(2)
290 | hold on
291 | plot(1:nepoch,mean(train_error_early_stop_all),'color',color_scheme(1,:),'LineWidth',line_w)
292 | plot(1:nepoch,mean(test_error_early_stop_all),'color',color_scheme(2,:),'LineWidth',line_w)
293 | plot(1:nepoch,mean(N_train_error_all),'b--','LineWidth',line_w)
294 | plot(1:nepoch,mean(N_test_error_all),'r--','LineWidth',line_w)
295 |
296 | set(gca, 'FontSize', font_s)
297 | set(gca, 'FontSize', font_s)
298 | xlabel('Epoch','Color','k')
299 | ylabel('Error','Color','k')
300 | xlim([0 nepoch])
301 | ylim([0 2])
302 | set(gca,'linewidth',1)
303 | set(gcf,'position',[600,100,350,300])
304 |
305 | % Save plot
306 | % saveas(gcf,strcat('Errors_ES_','SNR_',num2str(SNR),'_',date,'.pdf'));
307 |
--------------------------------------------------------------------------------
/Fig_2/Fig_2g_h.m:
--------------------------------------------------------------------------------
1 | % code for Student-Teacher-Notebook framework, analytical curves
2 | % Weinan Sun 10-10-2021
3 | close all
4 | clear all
5 |
6 | nepoch = 2000;
7 | learnrate = 0.005;
8 | N_x_t = 100;
9 | N_y_t = 1;
10 | P=100;
11 | M = 5000; %num of units in notebook
12 |
13 | Et_big_early_stop = [];
14 | Eg_big_early_stop = [];
15 | Et_big = [];
16 | Eg_big = [];
17 |
18 | %SNR values sampled from log2 space
19 | SNR_log_interval = -4:0.5:4;
20 | SNR_vec =2.^SNR_log_interval;
21 |
22 |
23 | Notebook_Train = (P-1)/(M-1); % Analytical solution for notebook training error, see supplementary material for derivations.
24 |
25 | for count = 1:size(SNR_vec,2)
26 | disp(count)
27 | SNR = SNR_vec(count);
28 | Eg = [];
29 | Et = [];
30 | lr = learnrate;
31 |
32 | %use normalized variances
33 | if SNR == inf
34 | variance_w = 1;
35 | variance_e = 0;
36 | else
37 | variance_w = SNR/(SNR + 1);
38 | variance_e = 1/(SNR + 1);
39 | end
40 |
41 | alpha = P/N_x_t; % number of examples divided by input dimension
42 |
43 | % Analytical curves for training and testing errors, see supplementary
44 | % material for derivations
45 |
46 | for t = 1:1:nepoch
47 |
48 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
49 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
50 |
51 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
52 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
53 |
54 | end
55 |
56 | % Early stopping curves
57 | [mm, pp] = min(Eg);
58 | Eg_early_stop = Eg;
59 | Eg_early_stop(pp+1:end) = Eg(pp);
60 |
61 | Et_early_stop = Et;
62 | Et_early_stop(pp+1:end) = Et(pp);
63 |
64 | % Pick the better training error value and convert to memory and generalization
65 | % scores
66 | better_train_no_early_stop = min(Et,ones(1,nepoch)*Notebook_Train);
67 | control_curve = (Et(1) - better_train_no_early_stop)/Et(1);
68 | lesion_curve = (Et(1) - Et)/Et(1);
69 |
70 |
71 | better_train_yes_early_stop = min(Et_early_stop,ones(1,nepoch)*Notebook_Train);
72 | control_curve_early_stop = (Et_early_stop(1) - better_train_yes_early_stop)/Et_early_stop(1);
73 | lesion_curve_early_stop = (Et_early_stop(1) - Et_early_stop)/Et_early_stop(1);
74 |
75 |
76 | control_Eg_curve = (Eg(1) - Eg)/Eg(1);
77 | control_Eg_curve_early_stop = (Eg_early_stop(1) - Eg_early_stop)/Eg_early_stop(1);
78 |
79 |
80 |
81 | figure(1)
82 |
83 | hold on;
84 | plot(1:1:nepoch,Et,'-','color',[0 0 1 count/20],'LineWidth',2)
85 | plot(1:1:nepoch,Eg,'-','color',[1 0 0 count/20],'LineWidth',2)
86 |
87 | set(gca, 'FontSize', 12)
88 | set(gca, 'FontSize', 12)
89 | xlabel('Epoch')
90 | ylabel('Error')
91 | set(gca,'linewidth',1.5)
92 | ylim([0 2.5])
93 |
94 | figure(2)
95 |
96 | hold on;
97 | plot(1:1:nepoch,Et_early_stop,'-','color',[0 0 1 count/20],'LineWidth',2)
98 | plot(1:1:nepoch,Eg_early_stop,'-','color',[1 0 0 count/20],'LineWidth',2)
99 |
100 | set(gca, 'FontSize',12)
101 | set(gca, 'FontSize', 12)
102 | xlabel('Epoch')
103 | ylabel('Error')
104 | set(gca,'linewidth',1.5)
105 | ylim([0 2.5])
106 |
107 |
108 | end
109 |
110 | %% save figures
111 | % figure(1)
112 | % set(gcf,'position',[100,100,350,290])
113 | % saveas(gcf,strcat('Fig_2g','.pdf'));
114 | % figure(2)
115 | % set(gcf,'position',[100,100,350,290])
116 | % saveas(gcf,strcat('Fig_2h','.pdf'));
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/Fig_2/Fig_2k_l.m:
--------------------------------------------------------------------------------
1 | %% Clear workspace and initialize random number generator
2 |
3 | clear all;
4 | RandStream.setGlobalStream(RandStream('mt19937ar','seed',sum(100*clock)));
5 |
6 | %% Construct an ensemble of teacher-generated student data.
7 |
8 | num_teachers = 1;
9 | N = 1000;
10 | P = 1000;
11 | std_X = sqrt(1/N);
12 |
13 | SNR_vec = -4:0.02:4;
14 | SNRs = 2.^SNR_vec;
15 |
16 | num_SNRs = length(SNRs);
17 |
18 | wbars = NaN(N,num_SNRs,num_teachers);
19 | Xs = NaN(P,N,num_SNRs,num_teachers);
20 | etas = NaN(P,num_SNRs,num_teachers);
21 | ys = NaN(P,num_SNRs,num_teachers);
22 | for s = 1:num_SNRs;
23 | S = SNRs(s);
24 | std_noise = sqrt(1/(1+S));
25 | std_weights = sqrt(S/(1+S));
26 | for t = 1:num_teachers;
27 | wbars(:,s,t) = std_weights*randn(N,1);
28 |
29 | Xs(:,:,s,t) = std_X*randn(P,N);
30 | etas(:,s,t) = std_noise*randn(P,1);
31 | ys(:,s,t) = Xs(:,:,s,t)*wbars(:,s,t)+etas(:,s,t);
32 | end;
33 | end;
34 |
35 | var_ys = squeeze(var(ys,0,1));
36 | var_ws = squeeze(var(wbars,0,1));
37 | var_noises = squeeze(var(etas,0,1));
38 | empirical_SNRs = (var_ys-var_noises)./var_noises;
39 |
40 | mean_empirical_SNRs = mean(empirical_SNRs,2);
41 | std_empirical_SNRs = std(empirical_SNRs,0,2);
42 |
43 | %% Plot basic stuff
44 |
45 | %close all;
46 |
47 | % figure(1)
48 | % hold on;
49 | % errorbar(SNRs,mean_empirical_SNRs,std_empirical_SNRs,'LineWidth',3)
50 | % plot(SNRs, empirical_SNRs)
51 | % hold off;
52 | %set(gca,'XScale','log')
53 | %set(gca,'YScale','log')
54 |
55 | %% Evaluate the log-likelihood function for each dataset
56 | % % DON'T USE THIS VERSION FOR LARGE N OR P.
57 | %
58 | % tic;
59 | % min_SNR = 0.02;
60 | % max_SNR = 15;
61 | % step_SNR = 0.02;
62 | % SNR_vector = min_SNR:step_SNR:max_SNR;
63 | % num_step_SNR = length(SNR_vector);
64 | %
65 | % log_likelihood_function = NaN(num_step_SNR,num_SNRs,num_teachers);
66 | % max_log_likelihood = NaN(num_SNRs,num_teachers);
67 | % estimated_SNRs = NaN(num_SNRs,num_teachers);
68 | % for s = 1:num_SNRs;
69 | % strcat('On SNR:',num2str(s))
70 | % toc
71 | % for t = 1:num_teachers;
72 | % for i = 1:num_step_SNR;
73 | % %inv_C_matrix = (1+SNR_vector(i))*(eye(P)-Xs(:,:,s,t)*inv(eye(N)/SNR_vector(i)+Xs(:,:,s,t)'*Xs(:,:,s,t))*Xs(:,:,s,t)');
74 | % inv_C_matrix = (1+SNR_vector(i))*inv(eye(P)+SNR_vector(i)*Xs(:,:,s,t)*Xs(:,:,s,t)');
75 | % log_likelihood_function(i,s,t) = 1/2*log(det(inv_C_matrix)) - 1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t);
76 | % %log_likelihood_function(i,s,t) = -1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t);
77 | % %log_likelihood_function(i,s,t) = log(det(inv_C_matrix));
78 | % end;
79 | % [max_log_likelihood(s,t) tmp_idx] = max(log_likelihood_function(:,s,t));
80 | % estimated_SNRs(s,t) = SNR_vector(tmp_idx);
81 | % end;
82 | % end;
83 |
84 |
85 | %% Plot the log_likelihood functions
86 |
87 | %close all;
88 | %
89 | % figure(2)
90 | % for s = 1:num_SNRs;
91 | % for t = 1:2;
92 | % subplot(num_SNRs,2,(s-1)*2+t)
93 | % hold on;
94 | % plot(SNR_vector,log_likelihood_function(:,s,t));
95 | % plot(empirical_SNRs(s,t),max_log_likelihood(s,t),'ob')
96 | % plot(estimated_SNRs(s,t),max_log_likelihood(s,t),'or','LineWidth',2)
97 | % end;
98 | % end;
99 | %
100 | % figure(10)
101 | % plot(empirical_SNRs(:),estimated_SNRs(:),'ok')
102 |
103 | %% Do an SVD decomposition on the data matrix.
104 |
105 |
106 | Us = NaN(P,P,num_SNRs,num_teachers);
107 | Ss = NaN(P,N,num_SNRs,num_teachers);
108 | Vs = NaN(N,N,num_SNRs,num_teachers);
109 | Lambdas = NaN(P,P,num_SNRs,num_teachers);
110 | inv_Lambdas = NaN(P,P,num_SNRs,num_teachers);
111 | parfor s = 1:num_SNRs;
112 | strcat('On SNR:',num2str(s))
113 |
114 | for t = 1:num_teachers;
115 | [Us(:,:,s,t) Ss(:,:,s,t) Vs(:,:,s,t)] = svd(Xs(:,:,s,t));
116 | Lambdas(:,:,s,t) = Ss(:,:,s,t)*Ss(:,:,s,t)';
117 | inv_Lambdas(:,:,s,t) = pinv(Lambdas(:,:,s,t));
118 | end;
119 | end;
120 |
121 | %% Evaluate the log-likelihood function for each dataset using the SVD decomposition
122 | % More numerically efficient version.
123 |
124 | tic;
125 | num_components = min(N,P);
126 |
127 |
128 | SNR_vector = 2.^(-6:0.02:6);
129 | num_step_SNR = length(SNR_vector);
130 |
131 | log_likelihood_function = NaN(num_step_SNR,num_SNRs,num_teachers);
132 | max_log_likelihood = NaN(num_SNRs,num_teachers);
133 | estimated_SNRs = NaN(num_SNRs,num_teachers);
134 | for s = 1:num_SNRs;
135 | strcat('On SNR:',num2str(s))
136 | toc
137 | for t = 1:num_teachers;
138 | y_modes = Us(:,:,s,t)'*ys(:,s,t);
139 | for i = 1:num_step_SNR;
140 | inv_C_matrix_modes = (1+SNR_vector(i))*eye(P);
141 | A1 = (P-num_components)/2*log(1+SNR_vector(i));
142 | A2 = 0;
143 | for j = 1:num_components;
144 | inv_C_matrix_modes(j,j) = (1+SNR_vector(i))/(1+SNR_vector(i)*Lambdas(j,j,s,t));
145 | A1 = A1 + 1/2*log(inv_C_matrix_modes(j,j));
146 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j);
147 | end;
148 | if gt(P,num_components)
149 | for j = 1:P;
150 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j);
151 | end;
152 | end;
153 | %inv_C_matrix_modes = (1+SNR_vector(i))*inv(eye(P)+SNR_vector(i)*Lambdas(:,:,s,t));
154 | %log_likelihood_function(i,s,t) = 1/2*log(det(inv_C_matrix_modes)) - 1/2*y_modes'*inv_C_matrix_modes*y_modes;
155 | log_likelihood_function(i,s,t) = A1 + A2;
156 | %log_likelihood_function(i,s,t) = -1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t);
157 | %log_likelihood_function(i,s,t) = log(det(inv_C_matrix));
158 | end;
159 | [max_log_likelihood(s,t) tmp_idx] = max(log_likelihood_function(:,s,t));
160 | estimated_SNRs(s,t) = SNR_vector(tmp_idx);
161 | end;
162 | end;
163 |
164 | %% Plot the log_likelihood functions
165 |
166 | %close all;
167 |
168 | % figure(3)
169 | % for s = 1:num_SNRs;
170 | % for t = 1:2;
171 | % subplot(num_SNRs,2,(s-1)*2+t)
172 | % hold on;
173 | %
174 | % plot(SNR_vector,log_likelihood_function(:,s,t));
175 | %
176 | % plot(SNR_vector(s),max_log_likelihood(s,t),'ob')
177 | % plot(estimated_SNRs(s,t),max_log_likelihood(s,t),'or','LineWidth',2)
178 | % end;
179 | % end;
180 |
181 | figure(400)
182 | plot(log2(SNRs),log2(estimated_SNRs(:,1)),'ok')
183 |
184 |
185 | set(gcf,'position',[500 500 300 300])
186 | xlim([-6,6])
187 | ylim([-6,6])
188 |
189 | figure(500)
190 | hold on
191 | plot(SNR_vector,log_likelihood_function(:,end))
192 | plot(SNRs(end),max_log_likelihood(end),'ob')
193 | plot(estimated_SNRs(end),max_log_likelihood(end),'or','LineWidth',2)
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
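204 | % Summary of the likelihood model evaluated above (it matches the dense-matrix
205 | % version commented out earlier): for y = X*w + eta with normalized variances
206 | % var(w_i) = S/(1+S) and var(eta_i) = 1/(1+S), the marginal covariance of y is
207 | % C = (I + S*X*X')/(1+S), so inv(C) = (1+S)*inv(I + S*X*X'). The SVD-based loop
208 | % accumulates log p(y|S) = 1/2*log(det(inv(C))) - 1/2*y'*inv(C)*y (up to an
209 | % additive constant) one singular mode at a time; the MLE is the S maximizing it.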
--------------------------------------------------------------------------------
/Fig_2/task/TSN.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | from torch import nn, optim
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | class TSN():
9 |
10 | def __init__(self, input_dim, output_dim, P, P_test, SNR, lr = 0.05, rand_seed = 101):
11 | self.N = input_dim
12 | self.output_dim = output_dim
13 | self.P = P
14 | self.P_test = P_test
15 | self.variance_w = snr_to_var(SNR)[0]
16 | self.variance_e = snr_to_var(SNR)[1]
17 | self.lr = lr
18 | torch.manual_seed(rand_seed)
19 | noise_train = torch.normal(0, self.variance_e**0.5, size = [self.P,1])
20 | noise_test = torch.normal(0, self.variance_e**0.5, size = [self.P_test,1])
21 | W_t = torch.normal(0, self.variance_w**0.5, size=(self.N, 1))
22 | self.train_x = torch.normal(0, (1/self.N)**0.5, size=(self.P, self.N))
23 | self.train_y = torch.matmul(self.train_x, W_t) + noise_train
24 | self.test_x = torch.normal(0, (1/self.N)**0.5, size=(self.P_test, self.N))
25 | self.test_y = torch.matmul(self.test_x, W_t) + noise_test
26 |
27 | def training_loop(self, nepoch = 1000, reg_strength = 0):
28 |
29 | model = nn.Sequential(nn.Linear(self.N,self.output_dim,bias=False))
30 | optimizer = optim.SGD(model.parameters(), lr=self.lr*(self.N/2), weight_decay = reg_strength) # to match the learnrate in the matlab implementation.
31 | optimizer.zero_grad()
32 |
33 | with torch.no_grad():
34 | list(model.parameters())[0].zero_()
35 |
36 | Et = np.zeros((nepoch,1))
37 | Eg = np.zeros((nepoch,1))
38 | print_iter = 200
39 |
40 | for i in range(nepoch):
41 | optimizer.zero_grad()
42 | train_error = criterion(self.train_y, model(self.train_x))
43 | train_error.backward()
44 | optimizer.step()
45 | Et[i] = train_error.detach().numpy()
46 | with torch.no_grad():
47 | test_error = criterion(self.test_y, model(self.test_x))
48 | Eg[i] = test_error.numpy()
49 | # if i%print_iter == 0:
50 | # print('Et '+ str(Et[i].item()),'Eg '+ str(Eg[i].item()),str(i*100/nepoch) + '% Finished')
51 |
52 | return Et, Eg
53 |
54 | class TSN_validation():
55 |
56 | def __init__(self, input_dim, output_dim, P, P_test, SNR, lr = 0.05, rand_seed = 101):
57 | self.N = input_dim
58 | self.output_dim = output_dim
59 | self.P = P
60 | self.P_test = P_test
61 | self.variance_w = snr_to_var(SNR)[0]
62 | self.variance_e = snr_to_var(SNR)[1]
63 | self.lr = 0.05
64 | torch.manual_seed(rand_seed)
65 | noise_train = torch.normal(0, self.variance_e**0.5, size = [self.P,1])
66 | noise_test = torch.normal(0, self.variance_e**0.5, size = [self.P_test,1])
67 | W_t = torch.normal(0, self.variance_w**0.5, size=(self.N, 1))
68 | self.train_x = torch.normal(0, (1/self.N)**0.5, size=(self.P, self.N))
69 | self.train_y = torch.matmul(self.train_x, W_t) + noise_train
70 | self.test_x = torch.normal(0, (1/self.N)**0.5, size=(self.P_test, self.N))
71 | self.test_y = torch.matmul(self.test_x, W_t) + noise_test
72 |
73 | def training_loop(self, nepoch = 2000, reg_strength = 1e-2):
74 |
75 | model = nn.Sequential(nn.Linear(self.N,self.output_dim,bias=False))
76 | optimizer = optim.SGD(model.parameters(), lr=self.lr*(self.N/2), weight_decay= reg_strength) # to match the learnrate in the matlab implementation.
77 | optimizer.zero_grad()
78 |
79 | with torch.no_grad():
80 | list(model.parameters())[0].zero_()
81 |
82 | Et = np.zeros((nepoch,1))
83 | Eg = np.zeros((nepoch,1))
84 | print_iter = 200
85 |
86 | for i in range(nepoch):
87 | optimizer.zero_grad()
88 | train_error = criterion(self.train_y, model(self.train_x))
89 | train_error.backward()
90 | optimizer.step()
91 | Et[i] = train_error.detach().numpy()
92 | with torch.no_grad():
93 | test_error = criterion(self.test_y, model(self.test_x))
94 | Eg[i] = test_error.numpy()
95 | # if i%print_iter == 0:
96 | # print('Et '+ str(Et[i].item()),'Eg '+ str(Eg[i].item()),str(i*100/nepoch) + '% Finished')
97 |
98 | return Et, Eg
99 |
100 | def snr_to_var(SNR):
101 | if SNR == np.inf:
102 | variance_w = 1
103 | variance_e = 0
104 | else:
105 | variance_w = SNR/(SNR + 1)
106 | variance_e = 1/(SNR + 1)
107 | return variance_w, variance_e
108 |
109 | def criterion(y, y_hat):
110 | return (y - y_hat).pow(2).mean()
111 |
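112 | # Minimal usage sketch (illustrative only; the parameter values below are
113 | # arbitrary example choices and are not part of the original training code):
114 | if __name__ == "__main__":
115 |     tsn = TSN(input_dim=100, output_dim=1, P=100, P_test=1000, SNR=20, lr=0.015)
116 |     Et, Eg = tsn.training_loop(nepoch=2000)  # per-epoch train and test errors
117 |     plt.plot(Et, label="train error")
118 |     plt.plot(Eg, label="test error")
119 |     plt.legend()
120 |     plt.show()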
--------------------------------------------------------------------------------
/Fig_2/task/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuroai/Go-CLS_v2/a5f90b73c61088f0a154414af05ea1e54122c06d/Fig_2/task/__init__.py
--------------------------------------------------------------------------------
/Fig_3/Fig_3a_to_d_f.m:
--------------------------------------------------------------------------------
1 | % code for Student-Teacher-Notebook framework, analytical curves
2 | % Weinan Sun 10-10-2021
3 | close all
4 | clear all
5 |
6 | nepoch = 2000;
7 | learnrate = 0.005;
8 | N_x_t = 100;
9 | N_y_t = 1;
10 | P=100;
11 | M = 5000; %num of units in notebook
12 |
13 | Et_big_early_stop = [];
14 | Eg_big_early_stop = [];
15 | Et_big = [];
16 | Eg_big = [];
17 |
18 | Et_lesion_remote = [];
19 | Eg_lesion_remote = [];
20 |
21 | %SNR values sampled from log2 space
22 | SNR_log_interval = -4:0.5:4;
23 | SNR_vec =2.^SNR_log_interval;
24 |
25 |
26 | Notebook_Train = (P-1)/(M-1); % Analytical solution for notebook training error, see supplementary material for derivations.
27 |
28 | for count = 1:size(SNR_vec,2)
29 | disp(count)
30 | SNR = SNR_vec(count);
31 | Eg = [];
32 | Et = [];
33 | lr = learnrate;
34 |
35 | %use normalized variances
36 | if SNR == inf
37 | variance_w = 1;
38 | variance_e = 0;
39 | else
40 | variance_w = SNR/(SNR + 1);
41 | variance_e = 1/(SNR + 1);
42 | end
43 |
44 | alpha = P/N_x_t; % number of examples divided by input dimension
45 |
46 | % Analytical curves for training and testing errors, see supplementary
47 | % material for derivations
48 |
49 | for t = 1:1:nepoch
50 |
51 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
52 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
53 |
54 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
55 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
56 |
57 | end
58 |
59 | % Early stopping curves
60 | [mm, pp] = min(Eg);
61 | Eg_early_stop = Eg;
62 | Eg_early_stop(pp+1:end) = Eg(pp);
63 |
64 | Et_early_stop = Et;
65 | Et_early_stop(pp+1:end) = Et(pp);
66 |
67 | % Pick the better training error value and convert to memory and generalization
68 | % scores
69 | better_train_no_early_stop = min(Et,ones(1,nepoch)*Notebook_Train);
70 | control_curve = (Et(1) - better_train_no_early_stop)/Et(1);
71 | lesion_curve = (Et(1) - Et)/Et(1);
72 |
73 |
74 | better_train_yes_early_stop = min(Et_early_stop,ones(1,nepoch)*Notebook_Train);
75 | control_curve_early_stop = (Et_early_stop(1) - better_train_yes_early_stop)/Et_early_stop(1);
76 | lesion_curve_early_stop = (Et_early_stop(1) - Et_early_stop)/Et_early_stop(1);
77 |
78 |
79 | control_Eg_curve = (Eg(1) - Eg)/Eg(1);
80 | control_Eg_curve_early_stop = (Eg_early_stop(1) - Eg_early_stop)/Eg_early_stop(1);
81 |
82 |
83 | Et_lesion_remote = [Et_lesion_remote lesion_curve_early_stop(end)];
84 | Eg_lesion_remote = [Eg_lesion_remote control_Eg_curve_early_stop(end)];
85 |
86 | figure(3)
87 |
88 | hold on;
89 | plot(1:1:nepoch,control_curve,'k-','LineWidth',3)
90 | plot(1:1:nepoch,lesion_curve,'c-','LineWidth',3)
91 | % plot([5 1800],[control_curve(1) control_curve(1800)],'o-','color',[0 0 1])
92 | plot([5 1800],[lesion_curve(1) lesion_curve(1800)],'o-','color',[1 0 0]) % pick an early and a late point as recent and remote memory
93 |
94 | set(gca, 'FontSize', 12)
95 | set(gca, 'FontSize', 12)
96 | xlabel('Epoch', 'FontSize',12)
97 | ylabel('Memory Score', 'FontSize',12)
98 | xlim([0 nepoch])
99 | ylim([0 1])
100 | set(gca,'linewidth',1.5)
101 |
102 |
103 | figure(4)
104 |
105 | hold on;
106 | plot(1:1:nepoch,control_curve_early_stop,'k-','LineWidth',3)
107 | plot(1:1:nepoch,lesion_curve_early_stop,'c-','LineWidth',3)
108 | % plot([5 1800],[control_curve_early_stop(1) control_curve_early_stop(1800)],'o-','color',[0 0 1])
109 | plot([5 1800],[lesion_curve_early_stop(1) lesion_curve_early_stop(1800)],'o-','color',[1 0 0])
110 | ylim([0 1])
111 |
112 | set(gca, 'FontSize', 12)
113 | set(gca, 'FontSize', 12)
114 | xlabel('Epoch', 'FontSize',12)
115 | ylabel('Memory Score', 'FontSize',12)
116 | xlim([0 nepoch])
117 | ylim([0 1])
118 | set(gca,'linewidth',1.5)
119 |
120 | figure(5)
121 |
122 | hold on;
123 | plot(1:1:nepoch,control_Eg_curve,'g-','LineWidth',3)
124 | % plot([5 1800],[control_Eg_curve(1) control_Eg_curve(1800)],'o-','color',[0 0 1])
125 | ylim([0 1])
126 |
127 | set(gca, 'FontSize', 12)
128 | set(gca, 'FontSize', 12)
129 | xlabel('Epoch', 'FontSize',12)
130 | ylabel('Generalization Score', 'FontSize',12)
131 | xlim([0 nepoch])
132 | ylim([-1 1])
133 | set(gca,'linewidth',1.5)
134 |
135 | figure(6)
136 | hold on;
137 | plot(1:1:nepoch,control_Eg_curve_early_stop,'g-','LineWidth',3)
138 | % plot([5 1800],[control_Eg_curve_early_stop(1) control_Eg_curve_early_stop(1800)],'o-','color',[0 0 1])
139 | ylim([-1 1])
140 | set(gca, 'FontSize', 12)
141 | set(gca, 'FontSize', 12)
142 | xlabel('Epoch', 'FontSize',12)
143 | ylabel('Generalization Score', 'FontSize',12)
144 | xlim([0 nepoch])
145 | ylim([-1 1])
146 | set(gca,'linewidth',1.5)
147 |
148 | end
149 |
150 | figure(7)
151 | plot(Et_lesion_remote,Eg_lesion_remote,'ko')
152 | xlabel('Memory Score (Notebook lesioned)', 'FontSize',12)
153 | ylabel('Generalization Score', 'FontSize',12)
154 |
155 | %% save figures
156 |
157 | % figure(3)
158 | % set(gcf,'position',[100,100,350,290])
159 | % saveas(gcf,strcat('Fig_3a','.pdf'));
160 | % figure(4)
161 | % set(gcf,'position',[100,100,350,290])
162 | % saveas(gcf,strcat('Fig_3b','.pdf'));
163 | % figure(5)
164 | % set(gcf,'position',[100,100,350,290])
165 | % saveas(gcf,strcat('Fig_3c','.pdf'));
166 | % figure(6)
167 | % set(gcf,'position',[100,100,350,290])
168 | % saveas(gcf,strcat('Fig_3d','.pdf'));
169 | % figure(7)
170 | % set(gcf,'position',[100,100,350,290])
171 | % saveas(gcf,strcat('Fig_3f','.pdf'));
172 |
173 |
174 |
--------------------------------------------------------------------------------
/Fig_3/Fig_3e.m:
--------------------------------------------------------------------------------
1 | % code for Student-Teacher-Notebook framework
2 | % diversity of amnesia curves
3 |
4 | close all
5 | clear all
6 |
7 | saveplot = false;
8 | nepoch = 2000;
9 | learnrate = 0.005;
10 | N_x_t = 100;
11 | N_y_t = 1;
12 | P=100;
13 | M = 5000; %num of units in notebook
14 |
15 | Et_big_early_stop = [];
16 | Eg_big_early_stop = [];
17 | Et_big = [];
18 | Eg_big = [];
19 |
20 | SNR_vec = [0.01 0.1 0.3 1 8 inf];
21 |
22 | Notebook_Train = (P-1)/(M-1); % Analytical solution for notebook training error
23 |
24 | % Simulation 1, varying SNR
25 |
26 | for count = 1:size(SNR_vec,2)
27 |
28 | SNR = SNR_vec(count);
29 | Eg = [];
30 | Et = [];
31 | lr = learnrate;
32 | %use normalized variances
33 | if SNR == inf
34 | variance_w = 1;
35 | variance_e = 0;
36 | else
37 | variance_w = SNR/(SNR + 1);
38 | variance_e = 1/(SNR + 1);
39 | end
40 |
41 | alpha = P/N_x_t; % number of examples divided by input dimension
42 |
43 | % Analytical curves for training and testing errors, see supplementary
44 | % material for derivations
45 | for t = 1:1:nepoch
46 |
47 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
48 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
49 |
50 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
51 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
52 |
53 | end
54 |
55 | % Early stopping curves
56 | [min_Eg, pos] = min(Eg);
57 | Eg_ES = Eg;
58 | Eg_ES(pos+1:end) = Eg(pos);
59 |
60 | Et_ES = Et;
61 | Et_ES(pos+1:end) = Et(pos);
62 |
63 | % Pick better training error and create memory and generalization
64 | % scores
65 | better_Et = min(Et,ones(1,nepoch)*Notebook_Train);
66 | control_memory_score = (Et(1) - better_Et)/Et(1);
67 | lesion_memory_score = (Et(1) - Et)/Et(1);
68 |
69 | %with early stopping
70 | better_Et_ES = min(Et_ES,ones(1,nepoch)*Notebook_Train);
71 | control_memory_score_ES = (Et_ES(1) - better_Et_ES)/Et_ES(1);
72 | lesion_memory_score_ES = (Et_ES(1) - Et_ES)/Et_ES(1);
73 |
74 |
75 | control_generalization_score = (Eg(1) - Eg)/Eg(1);
76 | control_generalization_score_ES = (Eg_ES(1) - Eg_ES)/Eg_ES(1);
77 |
78 | figure(1)
79 |
80 | hold on;
81 | if SNR == inf
82 | % picking an early and a late epoch as "recent" and "remote"
83 | plot([1 1800],[control_memory_score_ES(1) control_memory_score_ES(1800)],'o-','color',[0 0 0])
84 | end
85 | if SNR ~= inf
86 | plot([1 1800],[lesion_memory_score_ES(1) lesion_memory_score_ES(1800)],'o-','color',[0 0 0])
87 | end
88 | set(gca, 'FontSize', 12)
89 | set(gca, 'FontSize', 12)
90 | xlabel('Epoch', 'FontSize',12)
91 | ylabel('Memory Score', 'FontSize',12)
92 | xlim([-300 nepoch + 100])
93 | ylim([-0.1 1.15])
94 | set(gca,'linewidth',1.5)
95 |
96 | end
97 |
98 | % Simulation 2, varying prior consolidation
99 |
100 | for start_epoch = [8 20 40 60 2000] % amount of prior learning epochs
101 | SNR = 50;
102 | P = 300;
103 | Eg = [];
104 | Et = [];
105 | lr = learnrate;
106 |
107 | if SNR == inf
108 | variance_w = 1;
109 | variance_e = 0;
110 | else
111 | variance_w = SNR/(SNR + 1);
112 | variance_e = 1/(SNR + 1);
113 | end
114 |
115 | alpha = P/N_x_t; % number of examples divided by input dimension
116 |
117 |
118 | % Analytical curves for training and testing errors
119 | for t = 1:1:(nepoch + start_epoch)
120 |
121 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
122 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
123 |
124 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
125 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
126 |
127 | end
128 |
129 | % Early stopping curves
130 | [min_Eg, pos] = min(Eg);
131 | Eg_ES = Eg;
132 | Eg_ES(pos+1:end) = Eg(pos);
133 |
134 | Et_ES = Et;
135 | Et_ES(pos+1:end) = Et(pos);
136 |
137 | % Pick better training error value and create memory and generalization
138 | % scores
139 | better_Et = min(Et,ones(1,nepoch + start_epoch)*Notebook_Train);
140 | control_memory_score = (Et(1) - better_Et)/Et(1);
141 | lesion_memory_score = (Et(1) - Et)/Et(1);
142 |
143 |
144 | better_Et_ES = min(Et_ES,ones(1,nepoch + start_epoch)*Notebook_Train);
145 | control_memory_score_ES = (Et_ES(1) - better_Et_ES)/Et_ES(1);
146 | lesion_memory_score_ES = (Et_ES(1) - Et_ES)/Et_ES(1);
147 |
148 |
149 | control_generalization_score = (Eg(1) - Eg)/Eg(1);
150 | control_generalization_score_ES = (Eg_ES(1) - Eg_ES)/Eg_ES(1); % this assumes only using student for generalization
151 | lesion_Eg_curve_early_stop = (Eg_ES(1) - Eg_ES)/Eg_ES(1);
152 |
153 |
154 | figure(1)
155 |
156 | hold on;
157 |
158 | %"Recent" memory performance with prior rule-consistent consolidation can
159 | %be captured by the generalization error. That is, given a certain amount
160 | %of prior learning, how well can the network predict new examples?
161 | %"Remote" memory performance is the training error at convergence.
162 |
163 | plot([1 1800],[lesion_Eg_curve_early_stop(start_epoch) lesion_memory_score_ES(end)],'o--','color',[0 0 0])
164 |
165 | set(gca, 'FontSize', 12)
166 | set(gca, 'FontSize', 12)
167 | xlabel('Epoch', 'FontSize',12)
168 | ylabel('Memory Score', 'FontSize',12)
169 | xlim([-300 nepoch + 100])
170 | ylim([-0.1 1.15])
171 | set(gca,'linewidth',1.5)
172 |
173 | end
174 |
175 |
176 | f = gca;
177 | f.XTick = [1 1800];
178 | f.XTickLabel = [{'Recent'} {'Remote'}];
179 | set(gcf,'position',[100,100,400,600])
180 |
181 |
182 |
183 |
--------------------------------------------------------------------------------
/Fig_4/Fig_4d_to_g.m:
--------------------------------------------------------------------------------
1 | %Generalization and memorization performance as a
2 | %function of alpha and SNR, for the student-only, notebook-only, and combined
3 | %systems
4 |
5 | close all; clear all; clc
6 |
7 | %% Batch
8 |
9 | nepoch = 10000;
10 | lr = 0.01;
11 | SNR_log_interval = -2:0.1:3;
12 | SNR_vec =10.^SNR_log_interval;
13 | alpha_vec= 0.1:0.1:5;
14 |
15 | Et_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
16 | Eg_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
17 | Eg_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
18 | Et_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
19 |
20 | for n = 1:length(SNR_vec)
21 | for m = 1:length(alpha_vec)
22 | SNR = SNR_vec(n);
23 | alpha = alpha_vec(m);
24 |
25 | if SNR == inf
26 | variance_w = 1;
27 | variance_e = 0;
28 | else
29 | variance_w = SNR/(SNR + 1);
30 | variance_e = 1/(SNR + 1);
31 | end
32 |
33 |
34 | parfor t = 1:nepoch
35 |
36 |
37 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
38 | Et_no_ES(n,m,t) = (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e;
39 |
40 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
41 | Eg_no_ES(n,m,t) = variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR);
42 |
43 |
44 | end
45 |
46 | [min_Eg, ES_time] = min(Eg_no_ES(n,m,:));
47 |
48 | Eg_ES(n,m,:) = Eg_no_ES(n,m,:);
49 | Eg_ES(n,m,ES_time+1:end) = Eg_no_ES(n,m,ES_time);
50 |
51 | Et_ES(n,m,:) = Et_no_ES(n,m,:);
52 | Et_ES(n,m,ES_time+1:end) = Et_no_ES(n,m,ES_time);
53 | end
54 |
55 | end
56 |
57 | %% Optimal online student
58 |
59 | Eg_opt_online = zeros(length(SNR_vec),length(alpha_vec));
60 |
61 |
62 |
63 | parfor n = 1:length(SNR_vec)
64 | SNR = SNR_vec(n);
65 | variance_e = 1/(SNR + 1);
66 | [alpha, Eg] = ode45(@(alpha,Eg)eg_prime(alpha,Eg,variance_e),(0.1:0.1:5),1);
67 | Eg_new = interp1(alpha,Eg,alpha_vec);
68 | Eg_opt_online(n,:) = Eg_new;
69 |
70 | end
71 |
72 | figure (1)
73 | hold on
74 | plot(alpha_vec,Eg_opt_online(end,:),'b-')
75 | plot(alpha_vec,Eg_ES(end,:,end),'r-')
76 | plot(alpha_vec,ones(1,length(alpha_vec))*2,'m')
77 | ylim([-0.05 2.1])
78 | set(gcf,'position',[500 500 420 420])
79 | ax = gca;
80 | ax.XTick = [0 1 2 3 4 5];
81 | ax.YTick = [0 0.5 1 1.5 2];
82 | %saveas(gcf,'Fig_4d.pdf');
83 |
84 | figure (2)
85 | imagesc(flip(Eg_opt_online(:,:) - Eg_ES(:,:,end)))
86 | colormap(redblue)
87 | caxis([-0.3 0.3])
88 | colorbar
89 | set(gcf,'position',[500 500 420 325])
90 | %saveas(gcf,'Fig_4e.pdf');
91 |
92 | figure (3)
93 | hold on
94 | plot(alpha_vec,Eg_no_ES(25,:,end),'b-')
95 | plot(alpha_vec,Eg_ES(25,:,end),'r-')
96 | ylim([0 2.5])
97 | set(gcf,'position',[500 500 420 420])
98 | %saveas(gcf,'Fig_4f.pdf');
99 |
100 | figure (4)
101 | imagesc(flip(Eg_ES(:,:,end)-Eg_no_ES(:,:,end)))
102 | colormap(redblue)
103 | caxis([-1 1])
104 | colorbar
105 | set(gcf,'position',[500 500 420 325])
106 | %saveas(gcf,'Fig_4g.pdf');
107 |
108 | function dEg = eg_prime(alpha,Eg,variance_e)
109 | dEg = 2*variance_e - Eg - variance_e^2/Eg;
110 | end
111 |
112 |
113 |
114 |
115 |
116 |
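117 | % Note: eg_prime can be rewritten as dEg = -(Eg - variance_e)^2/Eg, so for
118 | % Eg > 0 the optimal online generalization error decreases monotonically and
119 | % asymptotes to the noise variance variance_e as alpha grows.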
--------------------------------------------------------------------------------
/Fig_5/Fig_5_e_f.m:
--------------------------------------------------------------------------------
1 | %Fig. 5e and f: noisy, complex, and partially observable teachers
2 | tic
3 | close all
4 | clear all
5 |
6 | r_n = 1000; % number of repeats
7 | nepoch = 500;
8 | learnrate = 0.1;
9 | N_x_t = 100;
10 | N_y_t = 1;
11 | P=100;
12 | p_test = 1000;
13 |
14 | %Set student dimensions
15 | N_x_s = N_x_t;
16 | N_y_s = N_y_t;
17 |
18 | % N_x_t_wide = 117 gives a partial-observability level that matches the sine
19 | % complex teacher (SNR = 5.74); use this value for Fig. 5e. Set N_x_t_wide = 100
20 | % for Fig. 5f. See supplementary material, complex teacher section.
21 | N_x_t_wide = N_x_t + 17;
22 |
23 |
24 | Error_vector_big_noisy = zeros(r_n,nepoch);
25 | gError_vector_big_noisy = zeros(r_n,nepoch);
26 |
27 | Error_vector_big_com = zeros(r_n,nepoch);
28 | gError_vector_big_com = zeros(r_n,nepoch);
29 |
30 | Error_vector_big_partial = zeros(r_n,nepoch);
31 | gError_vector_big_partial = zeros(r_n,nepoch);
32 |
33 |
34 |
35 |
36 | parfor r = 1:r_n
37 |
38 | rng(r)
39 |
40 | Error_vector_noisy = zeros(nepoch,1);
41 | gError_vector_noisy = zeros(nepoch,1);
42 |
43 |
44 | Error_vector_com = zeros(nepoch,1);
45 | gError_vector_com = zeros(nepoch,1);
46 |
47 | Error_vector_partial = zeros(nepoch,1);
48 | gError_vector_partial = zeros(nepoch,1);
49 |
50 |
51 |
52 | %% Teacher Network
53 |
54 | w_t = normrnd(0,1^0.5,[N_x_t,N_y_t]);
55 | w_t_wide = normrnd(0,1^0.5,[N_x_t_wide,N_y_t]);
56 |
57 | %training data
58 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); %input patterns for noisy and complex teachers
59 | x_t_input_wide = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t_wide]); %input patterns for partially observable teacher
60 | y_t_output_wide = x_t_input_wide*w_t_wide; %output of the partially observable teacher
61 | observable_mask = randperm(N_x_t_wide,100); % define the region of teacher input observable to the student
62 | x_t_input_observable = x_t_input_wide(:,observable_mask);
63 | y_t_output_complex = sin(x_t_input*w_t); % complex teacher output; use the sine function for Fig. 5e and remove it for Fig. 5f.
64 |
65 | % Below are terms for calculating equivalent SNR of a complex teacher.
66 | % See supplementary material, complex teacher section for details.
67 | x_t_input_big = normrnd(0,(1/N_x_t)^0.5,[P*1000,N_x_t]);
68 | y_t_output_complex_big = sin(x_t_input_big*w_t);
69 | % Variance of the optimal linear weight for fitting the training data by a complex
70 | % teacher.
71 | variance_w_opt = var(inv(x_t_input_big'*x_t_input_big)*(x_t_input_big'*y_t_output_complex_big));
72 | % variance of the residual after linear fitting
73 | variance_c = mean((y_t_output_complex - x_t_input*inv(x_t_input_big'*x_t_input_big)*(x_t_input_big'*y_t_output_complex_big)).^2);
74 |
75 |
76 | w_t_opt = normrnd(0,variance_w_opt^0.5,[N_x_t,N_y_t]); % set weight variance for the noisy teacher
77 | noise = normrnd(0,variance_c^0.5,[P,N_y_t]); % set training data noise variance for the noisy teacher
78 | noise1 = normrnd(0,variance_c^0.5,[p_test,N_y_t]);% set testing data noise variance for the noisy teacher
79 | y_t_output =x_t_input*w_t_opt + noise; %output of noisy teacher
80 |
81 | %test sets
82 | x_t_input_new = normrnd(0,(1/N_x_t)^0.5,[p_test,N_x_t]);
83 | y_t_output_new = x_t_input_new*w_t_opt + noise1;
84 | y_t_output_new_complex = sin(x_t_input_new*w_t);
85 |
86 | x_t_input_wide_new = normrnd(0,(1/N_x_t)^0.5,[p_test,N_x_t_wide]);
87 | x_t_input_wide_new_observable = x_t_input_wide_new(:,observable_mask);
88 | y_t_output_wide_new = x_t_input_wide_new*w_t_wide;
89 |
90 |
91 | %% Student Network
92 |
93 | w_s_noisy = normrnd(0,0^0.5,[N_x_s,N_y_s]);
94 | w_s_complex = normrnd(0,0^0.5,[N_x_s,N_y_s]);
95 | w_s_partial = normrnd(0,0^0.5,[N_x_s,N_y_s]);
96 |
97 |
98 |
99 |
100 | for m = 1:nepoch
101 |
102 |
103 |
104 | y_s_output_noisy = x_t_input*w_s_noisy;
105 | y_s_output_new_noisy = x_t_input_new*w_s_noisy;
106 |
107 | y_s_output_batch_com = x_t_input*w_s_complex;
108 | y_s_output_new_batch_com = x_t_input_new*w_s_complex;
109 |
110 | y_s_output_batch_partial = x_t_input_observable*w_s_partial;
111 | y_s_output_new_batch_partial = x_t_input_wide_new_observable*w_s_partial;
112 |
113 |
114 | %noisy teacher train error
115 | Error_noisy = y_t_output - y_s_output_noisy;
116 | MSE_noisy = sum(Error_noisy.^2)/P;
117 | Error_vector_noisy(m) = MSE_noisy;
118 |
119 | %noisy teacher generalization error
120 | gError_noisy = y_t_output_new - y_s_output_new_noisy;
121 | gMSE_noisy = sum(gError_noisy.^2)/p_test;
122 | gError_vector_noisy(m) = gMSE_noisy;
123 |
124 |
125 | %complex teacher train error
126 | Error_com = y_t_output_complex - y_s_output_batch_com;
127 | MSE_com = sum(Error_com.^2)/P;
128 | Error_vector_com(m) = MSE_com;
129 |
130 | %complex teacher generalization error
131 | gError_com = y_t_output_new_complex - y_s_output_new_batch_com;
132 | gMSE_com = sum(gError_com.^2)/p_test;
133 | gError_vector_com(m) = gMSE_com;
134 |
135 | %partially observable teacher train error
136 | Error_partial = y_t_output_wide - y_s_output_batch_partial;
137 | Cost_partial = sum(Error_partial.^2)/P;
138 | Error_vector_partial(m) = Cost_partial;
139 |
140 | %partially observable teacher generalization error
141 | gError_partial = y_t_output_wide_new - y_s_output_new_batch_partial;
142 | gCost_partial = sum(gError_partial.^2)/p_test;
143 | gError_vector_partial(m) = gCost_partial;
144 |
145 | % Weight updates through gradient descent for each teacher
146 | w_delta_noisy = (x_t_input'*y_t_output - x_t_input'*x_t_input*w_s_noisy);
147 | w_s_noisy = w_s_noisy + learnrate*w_delta_noisy;
148 |
149 | w_delta_com = (x_t_input'*y_t_output_complex - x_t_input'*x_t_input*w_s_complex);
150 | w_s_complex = w_s_complex + learnrate*w_delta_com;
151 |
152 | w_delta_partial = (x_t_input_observable'*y_t_output_wide - x_t_input_observable'* x_t_input_observable*w_s_partial);
153 | w_s_partial = w_s_partial + learnrate*w_delta_partial;
154 | end
155 |
156 | Error_vector_big_noisy(r,:) = Error_vector_noisy;
157 | gError_vector_big_noisy(r,:) = gError_vector_noisy;
158 |
159 | Error_vector_big_com(r,:) = Error_vector_com;
160 | gError_vector_big_com(r,:) = gError_vector_com;
161 |
162 | Error_vector_big_partial(r,:) = Error_vector_partial;
163 | gError_vector_big_partial(r,:) = gError_vector_partial;
164 |
165 | end
166 |
167 | toc
168 |
169 | color_scheme = [137 152 193; 245 143 136]/255;
170 | line_w = 1;
171 | font_s = 12;
172 |
173 | figure(1)
174 | hold on
175 |
176 |
177 | plot(1:nepoch,mean(Error_vector_big_noisy)/mean(Error_vector_big_noisy(:,1)),'r-')
178 | plot(1:nepoch,mean(gError_vector_big_noisy)/mean(gError_vector_big_noisy(:,1)),'r-')
179 |
180 |
181 | plot(1:nepoch,mean(Error_vector_big_com)/mean(Error_vector_big_com(:,1)),'k-')
182 | plot(1:nepoch,mean(gError_vector_big_com)/mean(gError_vector_big_com(:,1)),'k-')
183 |
184 | plot(1:nepoch,mean(Error_vector_big_partial)/mean(Error_vector_big_partial(:,1)),'g-')
185 | plot(1:nepoch,mean(gError_vector_big_partial)/mean(gError_vector_big_partial(:,1)),'g-')
186 |
187 |
188 |
189 | xt = get(gca, 'XTick');
190 | set(gca, 'FontSize', font_s)
191 | yt = get(gca, 'YTick');
192 | set(gca, 'FontSize', font_s)
193 | xlabel('Epoch','Color','k')
194 | ylabel('Error','Color','k')
195 | set(gcf,'position',[100,100,360,225])
196 | xlim([0 nepoch])
197 |
198 | % print(gcf,'complex_teacher_low_SNR.png','-dpng','-r600');
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Organizing memories for generalization in complementary learning systems
2 | ---
3 | Code for reproducing figures in:
4 |
5 | Organizing memories for generalization in complementary learning systems
6 | Weinan Sun, Madhu Advani, Nelson Spruston, Andrew Saxe, James E Fitzgerald
7 | bioRxiv 2021.10.13.463791; doi: https://doi.org/10.1101/2021.10.13.463791
8 |
9 | Matlab function required:
10 | Redblue colormap
11 | https://www.mathworks.com/matlabcentral/fileexchange/25536-red-blue-colormap
12 |
--------------------------------------------------------------------------------
/Supplements/Fig_S1_KNN.m:
--------------------------------------------------------------------------------
1 | %Generalization error of KNN regression
2 |
3 | close all;
4 | clear all;
5 |
6 |
7 | R = 10; % number of repeats
8 | range_N = [3 10 30 100 300 1000 3000 10000]; % # of dimensions
9 |
10 | mse_train = zeros(R,length(range_N));
11 | mse_test = zeros(R,length(range_N));
12 |
13 | N_y_t = 1; %output dimension
14 | P_test = 1000;
15 |
16 | SNR = inf;
17 |
18 | if SNR == inf
19 | variance_w = 1;
20 | variance_e = 0;
21 | else
22 | variance_w = SNR/(SNR + 1);
23 | variance_e = 1/(SNR + 1);
24 | end
25 |
26 |
27 | K=1; % number of nearest neighbors
28 |
29 | parfor repeat = 1:R
30 |
31 | mse_train_row = zeros(1,length(range_N));
32 | mse_test_row = zeros(1,length(range_N));
33 |
34 | for counter = 1:length(range_N)
35 | N_x_t = range_N(counter);
36 | P = N_x_t; % keeping alpha = 1
37 | train_error = zeros (1,P);
38 | test_error = zeros(1,P_test);
39 |
40 | w_t = normrnd(0,variance_w^0.5,[N_x_t,N_y_t]);
41 |
42 | %Generate patterns
43 |
44 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]);
45 | noise = normrnd(0,variance_e^0.5,[P,N_y_t]);
46 | y_t_output = x_t_input*w_t + noise;
47 |
48 |
49 | x_t_input_new = normrnd(0,(1/N_x_t)^0.5,[P_test,N_x_t]);
50 | noise1 = normrnd(0,variance_e^0.5,[P_test,N_y_t]);
51 | y_t_output_new = x_t_input_new*w_t+ noise1;
52 |
53 |
54 | %Distance matrix
55 |
56 | D_train = pdist2(x_t_input,x_t_input);
57 | D_test = pdist2(x_t_input_new,x_t_input);
58 |
59 | %Train error
60 | for n = 1:P
61 | [B,I] = mink(D_train(n,:),K);
62 | retrieved_x = mean(x_t_input(I,:));
63 | retrieved_y = mean(y_t_output(I));
64 | train_error(n) = (y_t_output(n) - retrieved_y)^2;
65 | end
66 |
67 | %Test error
68 | for n = 1:P_test
69 | [B,I] = mink(D_test(n,:),K);
70 | retrieved_x = mean(x_t_input(I,:));
71 | retrieved_y = mean(y_t_output(I));
72 | test_error(n) = (y_t_output_new(n) - retrieved_y)^2;
73 | end
74 |
75 | mse_train_row(counter) = mean(train_error);
76 | mse_test_row(counter) = mean(test_error);
77 |
78 | end
79 | mse_train(repeat,:) = mse_train_row;
80 | mse_test(repeat,:) = mse_test_row;
81 | end
82 |
83 |
84 | figure(1)
85 |
86 | plot(log10(range_N),mean(mse_test,1),'ro-','LineWidth',3)
87 |
88 | xt = get(gca, 'XTick');
89 | set(gca, 'FontSize', 20)
90 | yt = get(gca, 'YTick');
91 | set(gca, 'FontSize', 20)
92 | ylim([0 2.5])
--------------------------------------------------------------------------------
/Supplements/Fig_S2_SNR_alpha.m:
--------------------------------------------------------------------------------
1 | %Generalization and memorization performance as a
2 | %function of alpha and SNR.
3 |
4 | close all; clear all; clc
5 |
6 | nepoch = 10000;
7 | lr = 0.01;
8 | SNR_log_interval = -2:0.1:3;
9 | SNR_vec =10.^SNR_log_interval;
10 | alpha_vec= 0.1:0.1:5;
11 |
12 | Et_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
13 | Eg_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
14 | Eg_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
15 | Et_no_ES = zeros(length(SNR_vec),length(alpha_vec),nepoch);
16 |
17 | for n = 1:length(SNR_vec)
18 | for m = 1:length(alpha_vec)
19 | SNR = SNR_vec(n);
20 | alpha = alpha_vec(m);
21 |
22 | if SNR == inf
23 | variance_w = 1;
24 | variance_e = 0;
25 | else
26 | variance_w = SNR/(SNR + 1);
27 | variance_e = 1/(SNR + 1);
28 | end
29 |
30 |
31 | parfor t = 1:nepoch
32 |
33 |
34 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
35 | Et_no_ES(n,m,t) = (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e;
36 |
37 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
38 | Eg_no_ES(n,m,t) = variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR);
39 |
40 |
41 | end
42 |
43 | [min_Eg, ES_time] = min(Eg_no_ES(n,m,:));
44 |
45 | Eg_ES(n,m,:) = Eg_no_ES(n,m,:);
46 | Eg_ES(n,m,ES_time+1:end) = Eg_no_ES(n,m,ES_time);
47 |
48 | Et_ES(n,m,:) = Et_no_ES(n,m,:);
49 | Et_ES(n,m,ES_time+1:end) = Et_no_ES(n,m,ES_time);
50 | end
51 |
52 | end
53 |
54 | figure(1)
55 | imagesc(flip((Et_no_ES(:,:,1)-Et_no_ES(:,:,end))./Et_no_ES(:,:,1)))
56 | colormap(redblue)
57 | caxis([-1 1])
58 | colorbar
59 | set(gcf,'position',[500 500 440 325])
60 |
61 | figure(2)
62 | imagesc(flip((Eg_no_ES(:,:,1)-Eg_no_ES(:,:,end))./Eg_no_ES(:,:,1)))
63 | colormap(redblue)
64 | caxis([-1 1])
65 | colorbar
66 | set(gcf,'position',[500 500 440 325])
67 |
68 | figure(3)
69 | imagesc(flip((Et_ES(:,:,1)-Et_ES(:,:,end))./Et_ES(:,:,1)))
70 | colormap(redblue)
71 | caxis([-1 1])
72 | colorbar
73 | set(gcf,'position',[500 500 440 325])
74 |
75 | figure(4)
76 | imagesc(flip((Eg_ES(:,:,1)-Eg_ES(:,:,end))./Eg_ES(:,:,1)))
77 | colormap(redblue)
78 | caxis([-1 1])
79 | colorbar
80 | set(gcf,'position',[500 500 440 325])
81 |
82 |
83 |
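84 | % Note (added): redblue is not a built-in MATLAB colormap; it is typically the
85 | % diverging red-white-blue map from the File Exchange. If it is unavailable, a
86 | % minimal substitute can be constructed as below and passed to colormap instead.
87 | % m = 256; t = linspace(0,1,m)';
88 | % rb = [[t; ones(m,1)], [t; flipud(t)], [ones(m,1); flipud(t)]]; % blue -> white -> red
89 | % colormap(rb)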
--------------------------------------------------------------------------------
/Supplements/Fig_S4/Fig_MLE_Speed_VS_GroundT.m:
--------------------------------------------------------------------------------
1 | %% Clear workspace and initialize random number generator
2 |
3 | clear all;
4 | close all;
5 | RandStream.setGlobalStream(RandStream('mt19937ar','seed',sum(100*clock)));
6 |
7 | %% Construct an ensemble of teacher-generated student data.
8 |
9 | num_teachers = 10;
10 | N = 100;
11 | P = 100;
12 | std_X = sqrt(1/N);
13 |
14 | SNRs = [0.05 4 100];
15 |
16 | num_SNRs = length(SNRs);
17 |
18 | wbars = NaN(N,num_SNRs,num_teachers);
19 | Xs = NaN(P,N,num_SNRs,num_teachers);
20 | etas = NaN(P,num_SNRs,num_teachers);
21 | ys = NaN(P,num_SNRs,num_teachers);
22 | for s = 1:num_SNRs;
23 | S = SNRs(s);
24 | std_noise = sqrt(1/(1+S));
25 | std_weights = sqrt(S/(1+S));
26 | for t = 1:num_teachers;
27 | wbars(:,s,t) = std_weights*randn(N,1);
28 |
29 | Xs(:,:,s,t) = std_X*randn(P,N);
30 | etas(:,s,t) = std_noise*randn(P,1);
31 | ys(:,s,t) = Xs(:,:,s,t)*wbars(:,s,t)+etas(:,s,t);
32 | end;
33 | end;
34 |
35 | var_ys = squeeze(var(ys,0,1));
36 | var_ws = squeeze(var(wbars,0,1));
37 | var_noises = squeeze(var(etas,0,1));
38 | empirical_SNRs = (var_ys-var_noises)./var_noises;
39 |
40 | mean_empirical_SNRs = mean(empirical_SNRs,2);
41 | std_empirical_SNRs = std(empirical_SNRs,0,2);
42 |
43 | %% Perform an SVD on each data matrix.
44 |
45 |
46 | Us = NaN(P,P,num_SNRs,num_teachers);
47 | Ss = NaN(P,N,num_SNRs,num_teachers);
48 | Vs = NaN(N,N,num_SNRs,num_teachers);
49 | Lambdas = NaN(P,P,num_SNRs,num_teachers);
50 | inv_Lambdas = NaN(P,P,num_SNRs,num_teachers);
51 | for s = 1:num_SNRs;
52 | strcat('On SNR:',num2str(s))
53 |
54 | for t = 1:num_teachers;
55 | [Us(:,:,s,t) Ss(:,:,s,t) Vs(:,:,s,t)] = svd(Xs(:,:,s,t));
56 | Lambdas(:,:,s,t) = Ss(:,:,s,t)*Ss(:,:,s,t)';
57 | inv_Lambdas(:,:,s,t) = pinv(Lambdas(:,:,s,t));
58 | end;
59 | end;
60 |
61 | %% Evaluate the log-likelihood function for each dataset using the SVD decomposition
62 | % More numerically efficient version.
63 |
64 | tic;
65 | num_components = min(N,P);
66 |
67 |
68 | SNR_vector = 2.^(-6:0.02:7);
69 | num_step_SNR = length(SNR_vector);
70 |
71 | log_likelihood_function = NaN(num_step_SNR,num_SNRs,num_teachers);
72 | max_log_likelihood = NaN(num_SNRs,num_teachers);
73 | estimated_SNRs = NaN(num_SNRs,num_teachers);
74 | for s = 1:num_SNRs;
75 | strcat('On SNR:',num2str(s))
76 | toc
77 | for t = 1:num_teachers;
78 | y_modes = Us(:,:,s,t)'*ys(:,s,t);
79 | for i = 1:num_step_SNR;
80 | inv_C_matrix_modes = (1+SNR_vector(i))*eye(P);
81 | A1 = (P-num_components)/2*log(1+SNR_vector(i));
82 | A2 = 0;
83 | for j = 1:num_components;
84 | inv_C_matrix_modes(j,j) = (1+SNR_vector(i))/(1+SNR_vector(i)*Lambdas(j,j,s,t));
85 | A1 = A1 + 1/2*log(inv_C_matrix_modes(j,j));
86 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j);
87 | end;
88 | if gt(P,num_components)
89 | for j = num_components+1:P; % remaining modes (only relevant when P > N)
90 | A2 = A2 - 1/2*y_modes(j)^2*inv_C_matrix_modes(j,j);
91 | end;
92 | end;
93 | %inv_C_matrix_modes = (1+SNR_vector(i))*inv(eye(P)+SNR_vector(i)*Lambdas(:,:,s,t));
94 | %log_likelihood_function(i,s,t) = 1/2*log(det(inv_C_matrix_modes)) - 1/2*y_modes'*inv_C_matrix_modes*y_modes;
95 | log_likelihood_function(i,s,t) = A1 + A2;
96 | %log_likelihood_function(i,s,t) = -1/2*ys(:,s,t)'*inv_C_matrix*ys(:,s,t);
97 | %log_likelihood_function(i,s,t) = log(det(inv_C_matrix));
98 | end;
99 | [max_log_likelihood(s,t) tmp_idx] = max(log_likelihood_function(:,s,t));
100 | estimated_SNRs(s,t) = SNR_vector(tmp_idx);
101 | end;
102 | end;
103 |
104 |
105 | nepoch = 2000;
106 | learnrate = 0.005;
107 | N_x_t = 100;
108 | N_y_t = 1;
109 | P=100;
110 |
111 | for count = 1:size(SNRs,2)
112 | disp(count)
113 | SNR = SNRs(count);
114 | Eg = [];
115 | Et = [];
116 | lr = learnrate;
117 |
118 | %use normalized variances
119 | if SNR == inf
120 | variance_w = 1;
121 | variance_e = 0;
122 | else
123 | variance_w = SNR/(SNR + 1);
124 | variance_e = 1/(SNR + 1);
125 | end
126 |
127 | alpha = P/N_x_t; % number of examples divided by input dimension
128 |
129 | % Analytical curves for training and testing errors, see supplementary
130 | % material for derivations
131 |
132 | for t = 1:1:nepoch
133 |
134 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
135 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
136 |
137 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
138 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
139 |
140 | end
141 | [mm, pp] = min(Eg);
142 | Eg_early_stop = Eg;
143 | Eg_early_stop(pp+1:end) = Eg(pp);
144 |
145 | Et_early_stop = Et;
146 | Et_early_stop(pp+1:end) = Et(pp);
147 |
148 | ES_time = [];
149 | for ind = 1:length(estimated_SNRs(count,:))
150 | Eg_MLE = [];
151 | Et_MLE = [];
152 |
153 | SNR = estimated_SNRs(count,ind);
154 |
155 | if SNR == inf
156 | variance_w = 1;
157 | variance_e = 0;
158 | else
159 | variance_w = SNR/(SNR + 1);
160 | variance_e = 1/(SNR + 1);
161 | end
162 |
163 |
164 | for t = 1:1:nepoch
165 |
166 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
167 | Et_MLE = [Et_MLE (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
168 |
169 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
170 | Eg_MLE = [Eg_MLE variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
171 |
172 | end
173 | [mm, pp] = min(Eg_MLE);
174 | ES_time = [ES_time pp];
175 |
176 |
177 | end
178 | % Early stopping curves
179 |
180 | figure(count)
181 |
182 | hold on;
183 |
184 | for i =1:length(ES_time)
185 | min_p = ES_time(i);
186 | Eg_early_stop_MLE = Eg;
187 | Eg_early_stop_MLE(min_p+1:end) = Eg(min_p);
188 | Et_early_stop_MLE = Et;
189 | Et_early_stop_MLE(min_p+1:end) = Et(min_p);
190 |
191 | plot(1:1:nepoch,Et_early_stop_MLE,'-','color',[0 1 1 0.5],'LineWidth',1)
192 | plot(1:1:nepoch,Eg_early_stop_MLE,'-','color',[1 0 1 0.5],'LineWidth',1)
193 | end
194 | plot(1:1:nepoch,Et_early_stop,'-','color',[0 0 1],'LineWidth',3)
195 | plot(1:1:nepoch,Eg_early_stop,'-','color',[1 0 0],'LineWidth',3)
196 |
197 |
198 | set(gca, 'FontSize',12)
199 | set(gca, 'FontSize', 12)
200 | xlabel('Epoch')
201 | ylabel('Error')
202 | set(gca,'linewidth',1.5)
203 | ylim([0 2])
204 | set(gcf,'position',[100,100,350,290])
205 | saveas(gcf,strcat('MLE SNR=',num2str(SNRs(count)),'.pdf'));
206 |
207 |
208 | end
209 |
210 | % Load estimated SNR values from the learning-speed approach (obtained by running the associated code for main Fig. 2m,n)
211 | file = load('N_10000.mat');
212 | field_names = fieldnames(file);
213 | first_field_name = field_names{1};
214 | Est_SNRs = file.(first_field_name);
215 |
216 | for count = 1:size(SNRs,2)
217 | disp(count)
218 | SNR = SNRs(count);
219 | Eg = [];
220 | Et = [];
221 | lr = learnrate;
222 |
223 | %use normalized variances
224 | if SNR == inf
225 | variance_w = 1;
226 | variance_e = 0;
227 | else
228 | variance_w = SNR/(SNR + 1);
229 | variance_e = 1/(SNR + 1);
230 | end
231 |
232 | alpha = P/N_x_t; % number of examples divided by input dimension
233 |
234 | % Analytical curves for training and testing errors, see supplementary
235 | % material for derivations
236 |
237 | for t = 1:1:nepoch
238 |
239 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
240 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
241 |
242 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
243 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
244 |
245 | end
246 | [mm, pp] = min(Eg);
247 | Eg_early_stop = Eg;
248 | Eg_early_stop(pp+1:end) = Eg(pp);
249 |
250 | Et_early_stop = Et;
251 | Et_early_stop(pp+1:end) = Et(pp);
252 |
253 | ES_time = [];
254 | for ind = 1:10
255 | Eg_LS = [];
256 | Et_LS = [];
257 | SNR_est = Est_SNRs(1+(count-1)*10:10+(count-1)*10);
258 | SNR = SNR_est(ind);
259 |
260 | if SNR == inf
261 | variance_w = 1;
262 | variance_e = 0;
263 | else
264 | variance_w = SNR/(SNR + 1);
265 | variance_e = 1/(SNR + 1);
266 | end
267 |
268 |
269 | for t = 1:1:nepoch
270 |
271 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
272 | Et_LS = [Et_LS (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
273 |
274 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
275 | Eg_LS = [Eg_LS variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
276 |
277 | end
278 | [mm, pp] = min(Eg_LS);
279 | ES_time = [ES_time pp];
280 |
281 |
282 | end
283 | % Early stopping curves
284 |
285 | figure(count+3)
286 |
287 | hold on;
288 |
289 | for i =1:length(ES_time)
290 | min_p = ES_time(i);
291 | Eg_early_stop_MLE = Eg;
292 | Eg_early_stop_MLE(min_p+1:end) = Eg(min_p);
293 |
294 | Et_early_stop_MLE = Et;
295 | Et_early_stop_MLE(min_p+1:end) = Et(min_p);
296 |
297 | plot(1:1:nepoch,Et_early_stop_MLE,'-','color',[0 1 1 0.5],'LineWidth',1)
298 | plot(1:1:nepoch,Eg_early_stop_MLE,'-','color',[1 0 1 0.5],'LineWidth',1)
299 | end
300 | plot(1:1:nepoch,Et_early_stop,'-','color',[0 0 1],'LineWidth',3)
301 | plot(1:1:nepoch,Eg_early_stop,'-','color',[1 0 0],'LineWidth',3)
302 |
303 |
304 | set(gca, 'FontSize',12)
305 | set(gca, 'FontSize', 12)
306 | xlabel('Epoch')
307 | ylabel('Error')
308 | set(gca,'linewidth',1.5)
309 | ylim([0 2])
310 | set(gcf,'position',[100,100,350,290])
311 | saveas(gcf,strcat('Learning_speed SNR=',num2str(SNRs(count)),'_',first_field_name,'.pdf'));
312 |
313 |
314 | end
315 |
316 | %% save figures
317 | % figure(1)
318 | % set(gcf,'position',[100,100,350,290])
319 | % saveas(gcf,strcat('Fig_2g','.pdf'));
320 | % figure(2)
321 | % set(gcf,'position',[100,100,350,290])
322 | % saveas(gcf,strcat('Fig_2h','.pdf'));
323 |
324 |
325 |
326 |
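327 | % Optional check (not in the original script): compare the maximum-likelihood
328 | % SNR estimates against the ground-truth SNRs used to generate each dataset.
329 | % for s = 1:num_SNRs
330 | %     fprintf('Ground truth SNR = %.2f, MLE estimate = %.2f +/- %.2f\n', ...
331 | %         SNRs(s), mean(estimated_SNRs(s,:)), std(estimated_SNRs(s,:)));
332 | % end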
--------------------------------------------------------------------------------
/Supplements/Fig_S4/N_100.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuroai/Go-CLS_v2/a5f90b73c61088f0a154414af05ea1e54122c06d/Supplements/Fig_S4/N_100.mat
--------------------------------------------------------------------------------
/Supplements/Fig_S4/N_10000.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuroai/Go-CLS_v2/a5f90b73c61088f0a154414af05ea1e54122c06d/Supplements/Fig_S4/N_10000.mat
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .dpn import *
3 | from .lenet import *
4 | from .senet import *
5 | from .pnasnet import *
6 | from .densenet import *
7 | from .googlenet import *
8 | from .shufflenet import *
9 | from .shufflenetv2 import *
10 | from .resnet import *
11 | from .resnext import *
12 | from .preact_resnet import *
13 | from .mobilenet import *
14 | from .mobilenetv2 import *
15 | from .efficientnet import *
16 | from .regnet import *
17 | from .dla_simple import *
18 | from .dla import *
19 |
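20 | # Usage sketch (added; assumes the importing script or notebook sits in the
21 | # directory containing this models/ package, as the Fig_S5 notebooks do):
22 | #   from models import ResNet18
23 | #   net = ResNet18()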
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/densenet.py:
--------------------------------------------------------------------------------
1 | '''DenseNet in PyTorch.'''
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | def __init__(self, in_planes, growth_rate):
11 | super(Bottleneck, self).__init__()
12 | self.bn1 = nn.BatchNorm2d(in_planes)
13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(4*growth_rate)
15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
16 |
17 | def forward(self, x):
18 | out = self.conv1(F.relu(self.bn1(x)))
19 | out = self.conv2(F.relu(self.bn2(out)))
20 | out = torch.cat([out,x], 1)
21 | return out
22 |
23 |
24 | class Transition(nn.Module):
25 | def __init__(self, in_planes, out_planes):
26 | super(Transition, self).__init__()
27 | self.bn = nn.BatchNorm2d(in_planes)
28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
29 |
30 | def forward(self, x):
31 | out = self.conv(F.relu(self.bn(x)))
32 | out = F.avg_pool2d(out, 2)
33 | return out
34 |
35 |
36 | class DenseNet(nn.Module):
37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
38 | super(DenseNet, self).__init__()
39 | self.growth_rate = growth_rate
40 |
41 | num_planes = 2*growth_rate
42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
43 |
44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
45 | num_planes += nblocks[0]*growth_rate
46 | out_planes = int(math.floor(num_planes*reduction))
47 | self.trans1 = Transition(num_planes, out_planes)
48 | num_planes = out_planes
49 |
50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
51 | num_planes += nblocks[1]*growth_rate
52 | out_planes = int(math.floor(num_planes*reduction))
53 | self.trans2 = Transition(num_planes, out_planes)
54 | num_planes = out_planes
55 |
56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
57 | num_planes += nblocks[2]*growth_rate
58 | out_planes = int(math.floor(num_planes*reduction))
59 | self.trans3 = Transition(num_planes, out_planes)
60 | num_planes = out_planes
61 |
62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
63 | num_planes += nblocks[3]*growth_rate
64 |
65 | self.bn = nn.BatchNorm2d(num_planes)
66 | self.linear = nn.Linear(num_planes, num_classes)
67 |
68 | def _make_dense_layers(self, block, in_planes, nblock):
69 | layers = []
70 | for i in range(nblock):
71 | layers.append(block(in_planes, self.growth_rate))
72 | in_planes += self.growth_rate
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | out = self.conv1(x)
77 | out = self.trans1(self.dense1(out))
78 | out = self.trans2(self.dense2(out))
79 | out = self.trans3(self.dense3(out))
80 | out = self.dense4(out)
81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4)
82 | out = out.view(out.size(0), -1)
83 | out = self.linear(out)
84 | return out
85 |
86 | def DenseNet121():
87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
88 |
89 | def DenseNet169():
90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
91 |
92 | def DenseNet201():
93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
94 |
95 | def DenseNet161():
96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
97 |
98 | def densenet_cifar():
99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
100 |
101 | def test():
102 | net = densenet_cifar()
103 | x = torch.randn(1,3,32,32)
104 | y = net(x)
105 | print(y)
106 |
107 | # test()
108 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/dla.py:
--------------------------------------------------------------------------------
1 | '''DLA in PyTorch.
2 |
3 | Reference:
4 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, in_planes, planes, stride=1):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = nn.Conv2d(
17 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
20 | stride=1, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_planes != self.expansion*planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, self.expansion*planes,
27 | kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | out = F.relu(out)
36 | return out
37 |
38 |
39 | class Root(nn.Module):
40 | def __init__(self, in_channels, out_channels, kernel_size=1):
41 | super(Root, self).__init__()
42 | self.conv = nn.Conv2d(
43 | in_channels, out_channels, kernel_size,
44 | stride=1, padding=(kernel_size - 1) // 2, bias=False)
45 | self.bn = nn.BatchNorm2d(out_channels)
46 |
47 | def forward(self, xs):
48 | x = torch.cat(xs, 1)
49 | out = F.relu(self.bn(self.conv(x)))
50 | return out
51 |
52 |
53 | class Tree(nn.Module):
54 | def __init__(self, block, in_channels, out_channels, level=1, stride=1):
55 | super(Tree, self).__init__()
56 | self.level = level
57 | if level == 1:
58 | self.root = Root(2*out_channels, out_channels)
59 | self.left_node = block(in_channels, out_channels, stride=stride)
60 | self.right_node = block(out_channels, out_channels, stride=1)
61 | else:
62 | self.root = Root((level+2)*out_channels, out_channels)
63 | for i in reversed(range(1, level)):
64 | subtree = Tree(block, in_channels, out_channels,
65 | level=i, stride=stride)
66 | self.__setattr__('level_%d' % i, subtree)
67 | self.prev_root = block(in_channels, out_channels, stride=stride)
68 | self.left_node = block(out_channels, out_channels, stride=1)
69 | self.right_node = block(out_channels, out_channels, stride=1)
70 |
71 | def forward(self, x):
72 | xs = [self.prev_root(x)] if self.level > 1 else []
73 | for i in reversed(range(1, self.level)):
74 | level_i = self.__getattr__('level_%d' % i)
75 | x = level_i(x)
76 | xs.append(x)
77 | x = self.left_node(x)
78 | xs.append(x)
79 | x = self.right_node(x)
80 | xs.append(x)
81 | out = self.root(xs)
82 | return out
83 |
84 |
85 | class DLA(nn.Module):
86 | def __init__(self, block=BasicBlock, num_classes=10):
87 | super(DLA, self).__init__()
88 | self.base = nn.Sequential(
89 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
90 | nn.BatchNorm2d(16),
91 | nn.ReLU(True)
92 | )
93 |
94 | self.layer1 = nn.Sequential(
95 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
96 | nn.BatchNorm2d(16),
97 | nn.ReLU(True)
98 | )
99 |
100 | self.layer2 = nn.Sequential(
101 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
102 | nn.BatchNorm2d(32),
103 | nn.ReLU(True)
104 | )
105 |
106 | self.layer3 = Tree(block, 32, 64, level=1, stride=1)
107 | self.layer4 = Tree(block, 64, 128, level=2, stride=2)
108 | self.layer5 = Tree(block, 128, 256, level=2, stride=2)
109 | self.layer6 = Tree(block, 256, 512, level=1, stride=2)
110 | self.linear = nn.Linear(512, num_classes)
111 |
112 | def forward(self, x):
113 | out = self.base(x)
114 | out = self.layer1(out)
115 | out = self.layer2(out)
116 | out = self.layer3(out)
117 | out = self.layer4(out)
118 | out = self.layer5(out)
119 | out = self.layer6(out)
120 | out = F.avg_pool2d(out, 4)
121 | out = out.view(out.size(0), -1)
122 | out = self.linear(out)
123 | return out
124 |
125 |
126 | def test():
127 | net = DLA()
128 | print(net)
129 | x = torch.randn(1, 3, 32, 32)
130 | y = net(x)
131 | print(y.size())
132 |
133 |
134 | if __name__ == '__main__':
135 | test()
136 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/dla_simple.py:
--------------------------------------------------------------------------------
1 | '''Simplified version of DLA in PyTorch.
2 |
3 | Note this implementation is not identical to the original paper version.
4 | But it seems to work fine.
5 |
6 | See dla.py for the original paper version.
7 |
8 | Reference:
9 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
10 | '''
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, in_planes, planes, stride=1):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = nn.Conv2d(
22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
25 | stride=1, padding=1, bias=False)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 |
28 | self.shortcut = nn.Sequential()
29 | if stride != 1 or in_planes != self.expansion*planes:
30 | self.shortcut = nn.Sequential(
31 | nn.Conv2d(in_planes, self.expansion*planes,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(self.expansion*planes)
34 | )
35 |
36 | def forward(self, x):
37 | out = F.relu(self.bn1(self.conv1(x)))
38 | out = self.bn2(self.conv2(out))
39 | out += self.shortcut(x)
40 | out = F.relu(out)
41 | return out
42 |
43 |
44 | class Root(nn.Module):
45 | def __init__(self, in_channels, out_channels, kernel_size=1):
46 | super(Root, self).__init__()
47 | self.conv = nn.Conv2d(
48 | in_channels, out_channels, kernel_size,
49 | stride=1, padding=(kernel_size - 1) // 2, bias=False)
50 | self.bn = nn.BatchNorm2d(out_channels)
51 |
52 | def forward(self, xs):
53 | x = torch.cat(xs, 1)
54 | out = F.relu(self.bn(self.conv(x)))
55 | return out
56 |
57 |
58 | class Tree(nn.Module):
59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1):
60 | super(Tree, self).__init__()
61 | self.root = Root(2*out_channels, out_channels)
62 | if level == 1:
63 | self.left_tree = block(in_channels, out_channels, stride=stride)
64 | self.right_tree = block(out_channels, out_channels, stride=1)
65 | else:
66 | self.left_tree = Tree(block, in_channels,
67 | out_channels, level=level-1, stride=stride)
68 | self.right_tree = Tree(block, out_channels,
69 | out_channels, level=level-1, stride=1)
70 |
71 | def forward(self, x):
72 | out1 = self.left_tree(x)
73 | out2 = self.right_tree(out1)
74 | out = self.root([out1, out2])
75 | return out
76 |
77 |
78 | class SimpleDLA(nn.Module):
79 | def __init__(self, block=BasicBlock, num_classes=10):
80 | super(SimpleDLA, self).__init__()
81 | self.base = nn.Sequential(
82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
83 | nn.BatchNorm2d(16),
84 | nn.ReLU(True)
85 | )
86 |
87 | self.layer1 = nn.Sequential(
88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
89 | nn.BatchNorm2d(16),
90 | nn.ReLU(True)
91 | )
92 |
93 | self.layer2 = nn.Sequential(
94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
95 | nn.BatchNorm2d(32),
96 | nn.ReLU(True)
97 | )
98 |
99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1)
100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2)
101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2)
102 | self.layer6 = Tree(block, 256, 512, level=1, stride=2)
103 | self.linear = nn.Linear(512, num_classes)
104 |
105 | def forward(self, x):
106 | out = self.base(x)
107 | out = self.layer1(out)
108 | out = self.layer2(out)
109 | out = self.layer3(out)
110 | out = self.layer4(out)
111 | out = self.layer5(out)
112 | out = self.layer6(out)
113 | out = F.avg_pool2d(out, 4)
114 | out = out.view(out.size(0), -1)
115 | out = self.linear(out)
116 | return out
117 |
118 |
119 | def test():
120 | net = SimpleDLA()
121 | print(net)
122 | x = torch.randn(1, 3, 32, 32)
123 | y = net(x)
124 | print(y.size())
125 |
126 |
127 | if __name__ == '__main__':
128 | test()
129 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/dpn.py:
--------------------------------------------------------------------------------
1 | '''Dual Path Networks in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Bottleneck(nn.Module):
8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
9 | super(Bottleneck, self).__init__()
10 | self.out_planes = out_planes
11 | self.dense_depth = dense_depth
12 |
13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(in_planes)
15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
16 | self.bn2 = nn.BatchNorm2d(in_planes)
17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)
19 |
20 | self.shortcut = nn.Sequential()
21 | if first_layer:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(out_planes+dense_depth)
25 | )
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(self.conv1(x)))
29 | out = F.relu(self.bn2(self.conv2(out)))
30 | out = self.bn3(self.conv3(out))
31 | x = self.shortcut(x)
32 | d = self.out_planes
33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
34 | out = F.relu(out)
35 | return out
36 |
37 |
38 | class DPN(nn.Module):
39 | def __init__(self, cfg):
40 | super(DPN, self).__init__()
41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
43 |
44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
45 | self.bn1 = nn.BatchNorm2d(64)
46 | self.last_planes = 64
47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10)
52 |
53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
54 | strides = [stride] + [1]*(num_blocks-1)
55 | layers = []
56 | for i,stride in enumerate(strides):
57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
58 | self.last_planes = out_planes + (i+2) * dense_depth
59 | return nn.Sequential(*layers)
60 |
61 | def forward(self, x):
62 | out = F.relu(self.bn1(self.conv1(x)))
63 | out = self.layer1(out)
64 | out = self.layer2(out)
65 | out = self.layer3(out)
66 | out = self.layer4(out)
67 | out = F.avg_pool2d(out, 4)
68 | out = out.view(out.size(0), -1)
69 | out = self.linear(out)
70 | return out
71 |
72 |
73 | def DPN26():
74 | cfg = {
75 | 'in_planes': (96,192,384,768),
76 | 'out_planes': (256,512,1024,2048),
77 | 'num_blocks': (2,2,2,2),
78 | 'dense_depth': (16,32,24,128)
79 | }
80 | return DPN(cfg)
81 |
82 | def DPN92():
83 | cfg = {
84 | 'in_planes': (96,192,384,768),
85 | 'out_planes': (256,512,1024,2048),
86 | 'num_blocks': (3,4,20,3),
87 | 'dense_depth': (16,32,24,128)
88 | }
89 | return DPN(cfg)
90 |
91 |
92 | def test():
93 | net = DPN92()
94 | x = torch.randn(1,3,32,32)
95 | y = net(x)
96 | print(y)
97 |
98 | # test()
99 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/efficientnet.py:
--------------------------------------------------------------------------------
1 | '''EfficientNet in PyTorch.
2 |
3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks".
4 |
5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | def swish(x):
13 | return x * x.sigmoid()
14 |
15 |
16 | def drop_connect(x, drop_ratio):
17 | keep_ratio = 1.0 - drop_ratio
18 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
19 | mask.bernoulli_(keep_ratio)
20 | x.div_(keep_ratio)
21 | x.mul_(mask)
22 | return x
23 |
24 |
25 | class SE(nn.Module):
26 | '''Squeeze-and-Excitation block with Swish.'''
27 |
28 | def __init__(self, in_channels, se_channels):
29 | super(SE, self).__init__()
30 | self.se1 = nn.Conv2d(in_channels, se_channels,
31 | kernel_size=1, bias=True)
32 | self.se2 = nn.Conv2d(se_channels, in_channels,
33 | kernel_size=1, bias=True)
34 |
35 | def forward(self, x):
36 | out = F.adaptive_avg_pool2d(x, (1, 1))
37 | out = swish(self.se1(out))
38 | out = self.se2(out).sigmoid()
39 | out = x * out
40 | return out
41 |
42 |
43 | class Block(nn.Module):
44 | '''expansion + depthwise + pointwise + squeeze-excitation'''
45 |
46 | def __init__(self,
47 | in_channels,
48 | out_channels,
49 | kernel_size,
50 | stride,
51 | expand_ratio=1,
52 | se_ratio=0.,
53 | drop_rate=0.):
54 | super(Block, self).__init__()
55 | self.stride = stride
56 | self.drop_rate = drop_rate
57 | self.expand_ratio = expand_ratio
58 |
59 | # Expansion
60 | channels = expand_ratio * in_channels
61 | self.conv1 = nn.Conv2d(in_channels,
62 | channels,
63 | kernel_size=1,
64 | stride=1,
65 | padding=0,
66 | bias=False)
67 | self.bn1 = nn.BatchNorm2d(channels)
68 |
69 | # Depthwise conv
70 | self.conv2 = nn.Conv2d(channels,
71 | channels,
72 | kernel_size=kernel_size,
73 | stride=stride,
74 | padding=(1 if kernel_size == 3 else 2),
75 | groups=channels,
76 | bias=False)
77 | self.bn2 = nn.BatchNorm2d(channels)
78 |
79 | # SE layers
80 | se_channels = int(in_channels * se_ratio)
81 | self.se = SE(channels, se_channels)
82 |
83 | # Output
84 | self.conv3 = nn.Conv2d(channels,
85 | out_channels,
86 | kernel_size=1,
87 | stride=1,
88 | padding=0,
89 | bias=False)
90 | self.bn3 = nn.BatchNorm2d(out_channels)
91 |
92 | # Skip connection if in and out shapes are the same (MV-V2 style)
93 | self.has_skip = (stride == 1) and (in_channels == out_channels)
94 |
95 | def forward(self, x):
96 | out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x)))
97 | out = swish(self.bn2(self.conv2(out)))
98 | out = self.se(out)
99 | out = self.bn3(self.conv3(out))
100 | if self.has_skip:
101 | if self.training and self.drop_rate > 0:
102 | out = drop_connect(out, self.drop_rate)
103 | out = out + x
104 | return out
105 |
106 |
107 | class EfficientNet(nn.Module):
108 | def __init__(self, cfg, num_classes=10):
109 | super(EfficientNet, self).__init__()
110 | self.cfg = cfg
111 | self.conv1 = nn.Conv2d(3,
112 | 32,
113 | kernel_size=3,
114 | stride=1,
115 | padding=1,
116 | bias=False)
117 | self.bn1 = nn.BatchNorm2d(32)
118 | self.layers = self._make_layers(in_channels=32)
119 | self.linear = nn.Linear(cfg['out_channels'][-1], num_classes)
120 |
121 | def _make_layers(self, in_channels):
122 | layers = []
123 | cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size',
124 | 'stride']]
125 | b = 0
126 | blocks = sum(self.cfg['num_blocks'])
127 | for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg):
128 | strides = [stride] + [1] * (num_blocks - 1)
129 | for stride in strides:
130 | drop_rate = self.cfg['drop_connect_rate'] * b / blocks
131 | layers.append(
132 | Block(in_channels,
133 | out_channels,
134 | kernel_size,
135 | stride,
136 | expansion,
137 | se_ratio=0.25,
138 | drop_rate=drop_rate))
139 | in_channels = out_channels
140 | return nn.Sequential(*layers)
141 |
142 | def forward(self, x):
143 | out = swish(self.bn1(self.conv1(x)))
144 | out = self.layers(out)
145 | out = F.adaptive_avg_pool2d(out, 1)
146 | out = out.view(out.size(0), -1)
147 | dropout_rate = self.cfg['dropout_rate']
148 | if self.training and dropout_rate > 0:
149 | out = F.dropout(out, p=dropout_rate)
150 | out = self.linear(out)
151 | return out
152 |
153 |
154 | def EfficientNetB0():
155 | cfg = {
156 | 'num_blocks': [1, 2, 2, 3, 3, 4, 1],
157 | 'expansion': [1, 6, 6, 6, 6, 6, 6],
158 | 'out_channels': [16, 24, 40, 80, 112, 192, 320],
159 | 'kernel_size': [3, 3, 5, 3, 5, 5, 3],
160 | 'stride': [1, 2, 2, 2, 1, 2, 1],
161 | 'dropout_rate': 0.2,
162 | 'drop_connect_rate': 0.2,
163 | }
164 | return EfficientNet(cfg)
165 |
166 |
167 | def test():
168 | net = EfficientNetB0()
169 | x = torch.randn(2, 3, 32, 32)
170 | y = net(x)
171 | print(y.shape)
172 |
173 |
174 | if __name__ == '__main__':
175 | test()
176 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/googlenet.py:
--------------------------------------------------------------------------------
1 | '''GoogLeNet with PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Inception(nn.Module):
8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
9 | super(Inception, self).__init__()
10 | # 1x1 conv branch
11 | self.b1 = nn.Sequential(
12 | nn.Conv2d(in_planes, n1x1, kernel_size=1),
13 | nn.BatchNorm2d(n1x1),
14 | nn.ReLU(True),
15 | )
16 |
17 | # 1x1 conv -> 3x3 conv branch
18 | self.b2 = nn.Sequential(
19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1),
20 | nn.BatchNorm2d(n3x3red),
21 | nn.ReLU(True),
22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
23 | nn.BatchNorm2d(n3x3),
24 | nn.ReLU(True),
25 | )
26 |
27 | # 1x1 conv -> 5x5 conv branch
28 | self.b3 = nn.Sequential(
29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1),
30 | nn.BatchNorm2d(n5x5red),
31 | nn.ReLU(True),
32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
33 | nn.BatchNorm2d(n5x5),
34 | nn.ReLU(True),
35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(n5x5),
37 | nn.ReLU(True),
38 | )
39 |
40 | # 3x3 pool -> 1x1 conv branch
41 | self.b4 = nn.Sequential(
42 | nn.MaxPool2d(3, stride=1, padding=1),
43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1),
44 | nn.BatchNorm2d(pool_planes),
45 | nn.ReLU(True),
46 | )
47 |
48 | def forward(self, x):
49 | y1 = self.b1(x)
50 | y2 = self.b2(x)
51 | y3 = self.b3(x)
52 | y4 = self.b4(x)
53 | return torch.cat([y1,y2,y3,y4], 1)
54 |
55 |
56 | class GoogLeNet(nn.Module):
57 | def __init__(self):
58 | super(GoogLeNet, self).__init__()
59 | self.pre_layers = nn.Sequential(
60 | nn.Conv2d(3, 192, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(192),
62 | nn.ReLU(True),
63 | )
64 |
65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
67 |
68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
69 |
70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
75 |
76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
78 |
79 | self.avgpool = nn.AvgPool2d(8, stride=1)
80 | self.linear = nn.Linear(1024, 10)
81 |
82 | def forward(self, x):
83 | out = self.pre_layers(x)
84 | out = self.a3(out)
85 | out = self.b3(out)
86 | out = self.maxpool(out)
87 | out = self.a4(out)
88 | out = self.b4(out)
89 | out = self.c4(out)
90 | out = self.d4(out)
91 | out = self.e4(out)
92 | out = self.maxpool(out)
93 | out = self.a5(out)
94 | out = self.b5(out)
95 | out = self.avgpool(out)
96 | out = out.view(out.size(0), -1)
97 | out = self.linear(out)
98 | return out
99 |
100 |
101 | def test():
102 | net = GoogLeNet()
103 | x = torch.randn(1,3,32,32)
104 | y = net(x)
105 | print(y.size())
106 |
107 | # test()
108 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/lenet.py:
--------------------------------------------------------------------------------
1 | '''LeNet in PyTorch.'''
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class LeNet(nn.Module):
6 | def __init__(self):
7 | super(LeNet, self).__init__()
8 | self.conv1 = nn.Conv2d(3, 6, 5)
9 | self.conv2 = nn.Conv2d(6, 16, 5)
10 | self.fc1 = nn.Linear(16*5*5, 120)
11 | self.fc2 = nn.Linear(120, 84)
12 | self.fc3 = nn.Linear(84, 10)
13 |
14 | def forward(self, x):
15 | out = F.relu(self.conv1(x))
16 | out = F.max_pool2d(out, 2)
17 | out = F.relu(self.conv2(out))
18 | out = F.max_pool2d(out, 2)
19 | out = out.view(out.size(0), -1)
20 | out = F.relu(self.fc1(out))
21 | out = F.relu(self.fc2(out))
22 | out = self.fc3(out)
23 | return out
24 |
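25 | 
26 | # Minimal smoke test (added sketch; the other model files define a similar
27 | # helper, but lenet.py does not). Note that torch itself must be imported
28 | # here, since this file only imports torch.nn and torch.nn.functional.
29 | # import torch
30 | # def test():
31 | #     net = LeNet()
32 | #     y = net(torch.randn(1, 3, 32, 32))
33 | #     print(y.size())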
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | '''MobileNet in PyTorch.
2 |
3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
4 | for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class Block(nn.Module):
12 | '''Depthwise conv + Pointwise conv'''
13 | def __init__(self, in_planes, out_planes, stride=1):
14 | super(Block, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
18 | self.bn2 = nn.BatchNorm2d(out_planes)
19 |
20 | def forward(self, x):
21 | out = F.relu(self.bn1(self.conv1(x)))
22 | out = F.relu(self.bn2(self.conv2(out)))
23 | return out
24 |
25 |
26 | class MobileNet(nn.Module):
27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1
28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024]
29 |
30 | def __init__(self, num_classes=10):
31 | super(MobileNet, self).__init__()
32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
33 | self.bn1 = nn.BatchNorm2d(32)
34 | self.layers = self._make_layers(in_planes=32)
35 | self.linear = nn.Linear(1024, num_classes)
36 |
37 | def _make_layers(self, in_planes):
38 | layers = []
39 | for x in self.cfg:
40 | out_planes = x if isinstance(x, int) else x[0]
41 | stride = 1 if isinstance(x, int) else x[1]
42 | layers.append(Block(in_planes, out_planes, stride))
43 | in_planes = out_planes
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | out = F.relu(self.bn1(self.conv1(x)))
48 | out = self.layers(out)
49 | out = F.avg_pool2d(out, 2)
50 | out = out.view(out.size(0), -1)
51 | out = self.linear(out)
52 | return out
53 |
54 |
55 | def test():
56 | net = MobileNet()
57 | x = torch.randn(1,3,32,32)
58 | y = net(x)
59 | print(y.size())
60 |
61 | # test()
62 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | '''MobileNetV2 in PyTorch.
2 |
3 | See the paper "Inverted Residuals and Linear Bottlenecks:
4 | Mobile Networks for Classification, Detection and Segmentation" for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class Block(nn.Module):
12 | '''expand + depthwise + pointwise'''
13 | def __init__(self, in_planes, out_planes, expansion, stride):
14 | super(Block, self).__init__()
15 | self.stride = stride
16 |
17 | planes = expansion * in_planes
18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
23 | self.bn3 = nn.BatchNorm2d(out_planes)
24 |
25 | self.shortcut = nn.Sequential()
26 | if stride == 1 and in_planes != out_planes:
27 | self.shortcut = nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
29 | nn.BatchNorm2d(out_planes),
30 | )
31 |
32 | def forward(self, x):
33 | out = F.relu(self.bn1(self.conv1(x)))
34 | out = F.relu(self.bn2(self.conv2(out)))
35 | out = self.bn3(self.conv3(out))
36 | out = out + self.shortcut(x) if self.stride==1 else out
37 | return out
38 |
39 |
40 | class MobileNetV2(nn.Module):
41 | # (expansion, out_planes, num_blocks, stride)
42 | cfg = [(1, 16, 1, 1),
43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
44 | (6, 32, 3, 2),
45 | (6, 64, 4, 2),
46 | (6, 96, 3, 1),
47 | (6, 160, 3, 2),
48 | (6, 320, 1, 1)]
49 |
50 | def __init__(self, num_classes=10):
51 | super(MobileNetV2, self).__init__()
52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
54 | self.bn1 = nn.BatchNorm2d(32)
55 | self.layers = self._make_layers(in_planes=32)
56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
57 | self.bn2 = nn.BatchNorm2d(1280)
58 | self.linear = nn.Linear(1280, num_classes)
59 |
60 | def _make_layers(self, in_planes):
61 | layers = []
62 | for expansion, out_planes, num_blocks, stride in self.cfg:
63 | strides = [stride] + [1]*(num_blocks-1)
64 | for stride in strides:
65 | layers.append(Block(in_planes, out_planes, expansion, stride))
66 | in_planes = out_planes
67 | return nn.Sequential(*layers)
68 |
69 | def forward(self, x):
70 | out = F.relu(self.bn1(self.conv1(x)))
71 | out = self.layers(out)
72 | out = F.relu(self.bn2(self.conv2(out)))
73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
74 | out = F.avg_pool2d(out, 4)
75 | out = out.view(out.size(0), -1)
76 | out = self.linear(out)
77 | return out
78 |
79 |
80 | def test():
81 | net = MobileNetV2()
82 | x = torch.randn(2,3,32,32)
83 | y = net(x)
84 | print(y.size())
85 |
86 | # test()
87 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/pnasnet.py:
--------------------------------------------------------------------------------
1 | '''PNASNet in PyTorch.
2 |
3 | Paper: Progressive Neural Architecture Search
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class SepConv(nn.Module):
11 | '''Separable Convolution.'''
12 | def __init__(self, in_planes, out_planes, kernel_size, stride):
13 | super(SepConv, self).__init__()
14 | self.conv1 = nn.Conv2d(in_planes, out_planes,
15 | kernel_size, stride,
16 | padding=(kernel_size-1)//2,
17 | bias=False, groups=in_planes)
18 | self.bn1 = nn.BatchNorm2d(out_planes)
19 |
20 | def forward(self, x):
21 | return self.bn1(self.conv1(x))
22 |
23 |
24 | class CellA(nn.Module):
25 | def __init__(self, in_planes, out_planes, stride=1):
26 | super(CellA, self).__init__()
27 | self.stride = stride
28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
29 | if stride==2:
30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
31 | self.bn1 = nn.BatchNorm2d(out_planes)
32 |
33 | def forward(self, x):
34 | y1 = self.sep_conv1(x)
35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
36 | if self.stride==2:
37 | y2 = self.bn1(self.conv1(y2))
38 | return F.relu(y1+y2)
39 |
40 | class CellB(nn.Module):
41 | def __init__(self, in_planes, out_planes, stride=1):
42 | super(CellB, self).__init__()
43 | self.stride = stride
44 | # Left branch
45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
47 | # Right branch
48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
49 | if stride==2:
50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
51 | self.bn1 = nn.BatchNorm2d(out_planes)
52 | # Reduce channels
53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
54 | self.bn2 = nn.BatchNorm2d(out_planes)
55 |
56 | def forward(self, x):
57 | # Left branch
58 | y1 = self.sep_conv1(x)
59 | y2 = self.sep_conv2(x)
60 | # Right branch
61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
62 | if self.stride==2:
63 | y3 = self.bn1(self.conv1(y3))
64 | y4 = self.sep_conv3(x)
65 | # Concat & reduce channels
66 | b1 = F.relu(y1+y2)
67 | b2 = F.relu(y3+y4)
68 | y = torch.cat([b1,b2], 1)
69 | return F.relu(self.bn2(self.conv2(y)))
70 |
71 | class PNASNet(nn.Module):
72 | def __init__(self, cell_type, num_cells, num_planes):
73 | super(PNASNet, self).__init__()
74 | self.in_planes = num_planes
75 | self.cell_type = cell_type
76 |
77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
78 | self.bn1 = nn.BatchNorm2d(num_planes)
79 |
80 | self.layer1 = self._make_layer(num_planes, num_cells=6)
81 | self.layer2 = self._downsample(num_planes*2)
82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6)
83 | self.layer4 = self._downsample(num_planes*4)
84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6)
85 |
86 | self.linear = nn.Linear(num_planes*4, 10)
87 |
88 | def _make_layer(self, planes, num_cells):
89 | layers = []
90 | for _ in range(num_cells):
91 | layers.append(self.cell_type(self.in_planes, planes, stride=1))
92 | self.in_planes = planes
93 | return nn.Sequential(*layers)
94 |
95 | def _downsample(self, planes):
96 | layer = self.cell_type(self.in_planes, planes, stride=2)
97 | self.in_planes = planes
98 | return layer
99 |
100 | def forward(self, x):
101 | out = F.relu(self.bn1(self.conv1(x)))
102 | out = self.layer1(out)
103 | out = self.layer2(out)
104 | out = self.layer3(out)
105 | out = self.layer4(out)
106 | out = self.layer5(out)
107 | out = F.avg_pool2d(out, 8)
108 | out = self.linear(out.view(out.size(0), -1))
109 | return out
110 |
111 |
112 | def PNASNetA():
113 | return PNASNet(CellA, num_cells=6, num_planes=44)
114 |
115 | def PNASNetB():
116 | return PNASNet(CellB, num_cells=6, num_planes=32)
117 |
118 |
119 | def test():
120 | net = PNASNetB()
121 | x = torch.randn(1,3,32,32)
122 | y = net(x)
123 | print(y)
124 |
125 | # test()
126 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/preact_resnet.py:
--------------------------------------------------------------------------------
1 | '''Pre-activation ResNet in PyTorch.
2 |
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class PreActBlock(nn.Module):
13 | '''Pre-activation version of the BasicBlock.'''
14 | expansion = 1
15 |
16 | def __init__(self, in_planes, planes, stride=1):
17 | super(PreActBlock, self).__init__()
18 | self.bn1 = nn.BatchNorm2d(in_planes)
19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22 |
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
26 | )
27 |
28 | def forward(self, x):
29 | out = F.relu(self.bn1(x))
30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
31 | out = self.conv1(out)
32 | out = self.conv2(F.relu(self.bn2(out)))
33 | out += shortcut
34 | return out
35 |
36 |
37 | class PreActBottleneck(nn.Module):
38 | '''Pre-activation version of the original Bottleneck module.'''
39 | expansion = 4
40 |
41 | def __init__(self, in_planes, planes, stride=1):
42 | super(PreActBottleneck, self).__init__()
43 | self.bn1 = nn.BatchNorm2d(in_planes)
44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(planes)
48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
49 |
50 | if stride != 1 or in_planes != self.expansion*planes:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
53 | )
54 |
55 | def forward(self, x):
56 | out = F.relu(self.bn1(x))
57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
58 | out = self.conv1(out)
59 | out = self.conv2(F.relu(self.bn2(out)))
60 | out = self.conv3(F.relu(self.bn3(out)))
61 | out += shortcut
62 | return out
63 |
64 |
65 | class PreActResNet(nn.Module):
66 | def __init__(self, block, num_blocks, num_classes=10):
67 | super(PreActResNet, self).__init__()
68 | self.in_planes = 64
69 |
70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
75 | self.linear = nn.Linear(512*block.expansion, num_classes)
76 |
77 | def _make_layer(self, block, planes, num_blocks, stride):
78 | strides = [stride] + [1]*(num_blocks-1)
79 | layers = []
80 | for stride in strides:
81 | layers.append(block(self.in_planes, planes, stride))
82 | self.in_planes = planes * block.expansion
83 | return nn.Sequential(*layers)
84 |
85 | def forward(self, x):
86 | out = self.conv1(x)
87 | out = self.layer1(out)
88 | out = self.layer2(out)
89 | out = self.layer3(out)
90 | out = self.layer4(out)
91 | out = F.avg_pool2d(out, 4)
92 | out = out.view(out.size(0), -1)
93 | out = self.linear(out)
94 | return out
95 |
96 |
97 | def PreActResNet18():
98 | return PreActResNet(PreActBlock, [2,2,2,2])
99 |
100 | def PreActResNet34():
101 | return PreActResNet(PreActBlock, [3,4,6,3])
102 |
103 | def PreActResNet50():
104 | return PreActResNet(PreActBottleneck, [3,4,6,3])
105 |
106 | def PreActResNet101():
107 | return PreActResNet(PreActBottleneck, [3,4,23,3])
108 |
109 | def PreActResNet152():
110 | return PreActResNet(PreActBottleneck, [3,8,36,3])
111 |
112 |
113 | def test():
114 | net = PreActResNet18()
115 | y = net((torch.randn(1,3,32,32)))
116 | print(y.size())
117 |
118 | # test()
119 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/regnet.py:
--------------------------------------------------------------------------------
1 | '''RegNet in PyTorch.
2 |
3 | Paper: "Designing Network Design Spaces".
4 |
5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class SE(nn.Module):
13 | '''Squeeze-and-Excitation block.'''
14 |
15 | def __init__(self, in_planes, se_planes):
16 | super(SE, self).__init__()
17 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True)
18 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True)
19 |
20 | def forward(self, x):
21 | out = F.adaptive_avg_pool2d(x, (1, 1))
22 | out = F.relu(self.se1(out))
23 | out = self.se2(out).sigmoid()
24 | out = x * out
25 | return out
26 |
27 |
28 | class Block(nn.Module):
29 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio):
30 | super(Block, self).__init__()
31 | # 1x1
32 | w_b = int(round(w_out * bottleneck_ratio))
33 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False)
34 | self.bn1 = nn.BatchNorm2d(w_b)
35 | # 3x3
36 | num_groups = w_b // group_width
37 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3,
38 | stride=stride, padding=1, groups=num_groups, bias=False)
39 | self.bn2 = nn.BatchNorm2d(w_b)
40 | # se
41 | self.with_se = se_ratio > 0
42 | if self.with_se:
43 | w_se = int(round(w_in * se_ratio))
44 | self.se = SE(w_b, w_se)
45 | # 1x1
46 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(w_out)
48 |
49 | self.shortcut = nn.Sequential()
50 | if stride != 1 or w_in != w_out:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(w_in, w_out,
53 | kernel_size=1, stride=stride, bias=False),
54 | nn.BatchNorm2d(w_out)
55 | )
56 |
57 | def forward(self, x):
58 | out = F.relu(self.bn1(self.conv1(x)))
59 | out = F.relu(self.bn2(self.conv2(out)))
60 | if self.with_se:
61 | out = self.se(out)
62 | out = self.bn3(self.conv3(out))
63 | out += self.shortcut(x)
64 | out = F.relu(out)
65 | return out
66 |
67 |
68 | class RegNet(nn.Module):
69 | def __init__(self, cfg, num_classes=10):
70 | super(RegNet, self).__init__()
71 | self.cfg = cfg
72 | self.in_planes = 64
73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
74 | stride=1, padding=1, bias=False)
75 | self.bn1 = nn.BatchNorm2d(64)
76 | self.layer1 = self._make_layer(0)
77 | self.layer2 = self._make_layer(1)
78 | self.layer3 = self._make_layer(2)
79 | self.layer4 = self._make_layer(3)
80 | self.linear = nn.Linear(self.cfg['widths'][-1], num_classes)
81 |
82 | def _make_layer(self, idx):
83 | depth = self.cfg['depths'][idx]
84 | width = self.cfg['widths'][idx]
85 | stride = self.cfg['strides'][idx]
86 | group_width = self.cfg['group_width']
87 | bottleneck_ratio = self.cfg['bottleneck_ratio']
88 | se_ratio = self.cfg['se_ratio']
89 |
90 | layers = []
91 | for i in range(depth):
92 | s = stride if i == 0 else 1
93 | layers.append(Block(self.in_planes, width,
94 | s, group_width, bottleneck_ratio, se_ratio))
95 | self.in_planes = width
96 | return nn.Sequential(*layers)
97 |
98 | def forward(self, x):
99 | out = F.relu(self.bn1(self.conv1(x)))
100 | out = self.layer1(out)
101 | out = self.layer2(out)
102 | out = self.layer3(out)
103 | out = self.layer4(out)
104 | out = F.adaptive_avg_pool2d(out, (1, 1))
105 | out = out.view(out.size(0), -1)
106 | out = self.linear(out)
107 | return out
108 |
109 |
110 | def RegNetX_200MF():
111 | cfg = {
112 | 'depths': [1, 1, 4, 7],
113 | 'widths': [24, 56, 152, 368],
114 | 'strides': [1, 1, 2, 2],
115 | 'group_width': 8,
116 | 'bottleneck_ratio': 1,
117 | 'se_ratio': 0,
118 | }
119 | return RegNet(cfg)
120 |
121 |
122 | def RegNetX_400MF():
123 | cfg = {
124 | 'depths': [1, 2, 7, 12],
125 | 'widths': [32, 64, 160, 384],
126 | 'strides': [1, 1, 2, 2],
127 | 'group_width': 16,
128 | 'bottleneck_ratio': 1,
129 | 'se_ratio': 0,
130 | }
131 | return RegNet(cfg)
132 |
133 |
134 | def RegNetY_400MF():
135 | cfg = {
136 | 'depths': [1, 2, 7, 12],
137 | 'widths': [32, 64, 160, 384],
138 | 'strides': [1, 1, 2, 2],
139 | 'group_width': 16,
140 | 'bottleneck_ratio': 1,
141 | 'se_ratio': 0.25,
142 | }
143 | return RegNet(cfg)
144 |
145 |
146 | def test():
147 | net = RegNetX_200MF()
148 | print(net)
149 | x = torch.randn(2, 3, 32, 32)
150 | y = net(x)
151 | print(y.shape)
152 |
153 |
154 | if __name__ == '__main__':
155 | test()
156 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, in_planes, planes, stride=1):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = nn.Conv2d(
20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
23 | stride=1, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 |
26 | self.shortcut = nn.Sequential()
27 | if stride != 1 or in_planes != self.expansion*planes:
28 | self.shortcut = nn.Sequential(
29 | nn.Conv2d(in_planes, self.expansion*planes,
30 | kernel_size=1, stride=stride, bias=False),
31 | nn.BatchNorm2d(self.expansion*planes)
32 | )
33 |
34 | def forward(self, x):
35 | out = F.relu(self.bn1(self.conv1(x)))
36 | out = self.bn2(self.conv2(out))
37 | out += self.shortcut(x)
38 | out = F.relu(out)
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1):
46 | super(Bottleneck, self).__init__()
47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
48 | self.bn1 = nn.BatchNorm2d(planes)
49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
50 | stride=stride, padding=1, bias=False)
51 | self.bn2 = nn.BatchNorm2d(planes)
52 | self.conv3 = nn.Conv2d(planes, self.expansion *
53 | planes, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
55 |
56 | self.shortcut = nn.Sequential()
57 | if stride != 1 or in_planes != self.expansion*planes:
58 | self.shortcut = nn.Sequential(
59 | nn.Conv2d(in_planes, self.expansion*planes,
60 | kernel_size=1, stride=stride, bias=False),
61 | nn.BatchNorm2d(self.expansion*planes)
62 | )
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(self.conv1(x)))
66 | out = F.relu(self.bn2(self.conv2(out)))
67 | out = self.bn3(self.conv3(out))
68 | out += self.shortcut(x)
69 | out = F.relu(out)
70 | return out
71 |
72 |
73 | class ResNet(nn.Module):
74 | def __init__(self, block, num_blocks, num_classes=200):
75 | super(ResNet, self).__init__()
76 | self.in_planes = 64
77 |
78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
79 | stride=1, padding=1, bias=False)
80 | self.bn1 = nn.BatchNorm2d(64)
81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
85 | self.linear = nn.Linear(512*block.expansion, num_classes)
86 |
87 | def _make_layer(self, block, planes, num_blocks, stride):
88 | strides = [stride] + [1]*(num_blocks-1)
89 | layers = []
90 | for stride in strides:
91 | layers.append(block(self.in_planes, planes, stride))
92 | self.in_planes = planes * block.expansion
93 | return nn.Sequential(*layers)
94 |
95 | def forward(self, x):
96 | out = F.relu(self.bn1(self.conv1(x)))
97 | out = self.layer1(out)
98 | out = self.layer2(out)
99 | out = self.layer3(out)
100 | out = self.layer4(out)
101 |         out = F.adaptive_avg_pool2d(out, (1, 1))
102 | out = out.view(out.size(0), -1)
103 | out = self.linear(out)
104 | return out
105 |
106 |
107 | def ResNet18():
108 | return ResNet(BasicBlock, [2, 2, 2, 2])
109 |
110 |
111 | def ResNet34():
112 | return ResNet(BasicBlock, [3, 4, 6, 3])
113 |
114 |
115 | def ResNet50():
116 | return ResNet(Bottleneck, [3, 4, 6, 3])
117 |
118 |
119 | def ResNet101():
120 | return ResNet(Bottleneck, [3, 4, 23, 3])
121 |
122 |
123 | def ResNet152():
124 | return ResNet(Bottleneck, [3, 8, 36, 3])
125 |
126 |
127 | def test():
128 | net = ResNet18()
129 | y = net(torch.randn(1, 3, 32, 32))
130 | print(y.size())
131 |
132 | # test()
133 |
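Note that num_classes defaults to 200 here, matching Tiny ImageNet's 200 classes; the CIFAR-10 and MNIST notebooks presumably pass a different value when they build the network. A minimal sketch of that pattern (the import path is assumed, not taken from the notebooks):

```python
import torch
from models.resnet import ResNet, BasicBlock, Bottleneck  # assumed import path within Fig_S5

# ResNet-18 with a 10-way head for CIFAR-10 instead of the 200-class default.
net_cifar = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
print(net_cifar(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 10])

# ResNet-50 with the Tiny ImageNet default; the adaptive pooling also handles 64x64 inputs.
net_tiny = ResNet(Bottleneck, [3, 4, 6, 3])
print(net_tiny(torch.randn(1, 3, 64, 64)).shape)    # torch.Size([1, 200])
```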
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/resnext.py:
--------------------------------------------------------------------------------
1 | '''ResNeXt in PyTorch.
2 |
3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class Block(nn.Module):
11 | '''Grouped convolution block.'''
12 | expansion = 2
13 |
14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
15 | super(Block, self).__init__()
16 | group_width = cardinality * bottleneck_width
17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(group_width)
19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
20 | self.bn2 = nn.BatchNorm2d(group_width)
21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width)
23 |
24 | self.shortcut = nn.Sequential()
25 | if stride != 1 or in_planes != self.expansion*group_width:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*group_width)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = F.relu(self.bn2(self.conv2(out)))
34 | out = self.bn3(self.conv3(out))
35 | out += self.shortcut(x)
36 | out = F.relu(out)
37 | return out
38 |
39 |
40 | class ResNeXt(nn.Module):
41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
42 | super(ResNeXt, self).__init__()
43 | self.cardinality = cardinality
44 | self.bottleneck_width = bottleneck_width
45 | self.in_planes = 64
46 |
47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
48 | self.bn1 = nn.BatchNorm2d(64)
49 | self.layer1 = self._make_layer(num_blocks[0], 1)
50 | self.layer2 = self._make_layer(num_blocks[1], 2)
51 | self.layer3 = self._make_layer(num_blocks[2], 2)
52 | # self.layer4 = self._make_layer(num_blocks[3], 2)
53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)
54 |
55 | def _make_layer(self, num_blocks, stride):
56 | strides = [stride] + [1]*(num_blocks-1)
57 | layers = []
58 | for stride in strides:
59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
61 |         # Double bottleneck_width after each stage.
62 | self.bottleneck_width *= 2
63 | return nn.Sequential(*layers)
64 |
65 | def forward(self, x):
66 | out = F.relu(self.bn1(self.conv1(x)))
67 | out = self.layer1(out)
68 | out = self.layer2(out)
69 | out = self.layer3(out)
70 | # out = self.layer4(out)
71 | out = F.avg_pool2d(out, 8)
72 | out = out.view(out.size(0), -1)
73 | out = self.linear(out)
74 | return out
75 |
76 |
77 | def ResNeXt29_2x64d():
78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)
79 |
80 | def ResNeXt29_4x64d():
81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)
82 |
83 | def ResNeXt29_8x64d():
84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)
85 |
86 | def ResNeXt29_32x4d():
87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)
88 |
89 | def test_resnext():
90 | net = ResNeXt29_2x64d()
91 | x = torch.randn(1,3,32,32)
92 | y = net(x)
93 | print(y.size())
94 |
95 | # test_resnext()
96 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/senet.py:
--------------------------------------------------------------------------------
1 | '''SENet in PyTorch.
2 |
3 | SENet is the winner of ImageNet-2017. See "Squeeze-and-Excitation Networks" (arXiv:1709.01507) for details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | def __init__(self, in_planes, planes, stride=1):
12 | super(BasicBlock, self).__init__()
13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(planes)
15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
16 | self.bn2 = nn.BatchNorm2d(planes)
17 |
18 | self.shortcut = nn.Sequential()
19 | if stride != 1 or in_planes != planes:
20 | self.shortcut = nn.Sequential(
21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
22 | nn.BatchNorm2d(planes)
23 | )
24 |
25 | # SE layers
26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear
27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 |
33 | # Squeeze
34 | w = F.avg_pool2d(out, out.size(2))
35 | w = F.relu(self.fc1(w))
36 |         w = torch.sigmoid(self.fc2(w))  # F.sigmoid is deprecated
37 | # Excitation
38 | out = out * w # New broadcasting feature from v0.2!
39 |
40 | out += self.shortcut(x)
41 | out = F.relu(out)
42 | return out
43 |
44 |
45 | class PreActBlock(nn.Module):
46 | def __init__(self, in_planes, planes, stride=1):
47 | super(PreActBlock, self).__init__()
48 | self.bn1 = nn.BatchNorm2d(in_planes)
49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
50 | self.bn2 = nn.BatchNorm2d(planes)
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
52 |
53 | if stride != 1 or in_planes != planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
56 | )
57 |
58 | # SE layers
59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
61 |
62 | def forward(self, x):
63 | out = F.relu(self.bn1(x))
64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
65 | out = self.conv1(out)
66 | out = self.conv2(F.relu(self.bn2(out)))
67 |
68 | # Squeeze
69 | w = F.avg_pool2d(out, out.size(2))
70 | w = F.relu(self.fc1(w))
71 |         w = torch.sigmoid(self.fc2(w))  # F.sigmoid is deprecated
72 | # Excitation
73 | out = out * w
74 |
75 | out += shortcut
76 | return out
77 |
78 |
79 | class SENet(nn.Module):
80 | def __init__(self, block, num_blocks, num_classes=10):
81 | super(SENet, self).__init__()
82 | self.in_planes = 64
83 |
84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
85 | self.bn1 = nn.BatchNorm2d(64)
86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
90 | self.linear = nn.Linear(512, num_classes)
91 |
92 | def _make_layer(self, block, planes, num_blocks, stride):
93 | strides = [stride] + [1]*(num_blocks-1)
94 | layers = []
95 | for stride in strides:
96 | layers.append(block(self.in_planes, planes, stride))
97 | self.in_planes = planes
98 | return nn.Sequential(*layers)
99 |
100 | def forward(self, x):
101 | out = F.relu(self.bn1(self.conv1(x)))
102 | out = self.layer1(out)
103 | out = self.layer2(out)
104 | out = self.layer3(out)
105 | out = self.layer4(out)
106 | out = F.avg_pool2d(out, 4)
107 | out = out.view(out.size(0), -1)
108 | out = self.linear(out)
109 | return out
110 |
111 |
112 | def SENet18():
113 | return SENet(PreActBlock, [2,2,2,2])
114 |
115 |
116 | def test():
117 | net = SENet18()
118 | y = net(torch.randn(1,3,32,32))
119 | print(y.size())
120 |
121 | # test()
122 |
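Both blocks share the squeeze-and-excitation recipe flagged by the # Squeeze / # Excitation comments: global-average-pool the feature map down to one value per channel, push that through a small two-layer bottleneck ending in a sigmoid, and rescale the original channels by the resulting gates. A standalone sketch of just that path, with made-up tensor sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

planes = 64
fc1 = nn.Conv2d(planes, planes // 16, kernel_size=1)  # squeeze bottleneck: 64 -> 4 channels
fc2 = nn.Conv2d(planes // 16, planes, kernel_size=1)  # back up to one gate per channel

out = torch.randn(2, planes, 8, 8)            # pretend residual-branch output
w = F.avg_pool2d(out, out.size(2))            # squeeze: (2, 64, 1, 1)
w = torch.sigmoid(fc2(F.relu(fc1(w))))        # excitation: per-channel gates in (0, 1)
out = out * w                                 # reweight channels via broadcasting
print(out.shape)                              # torch.Size([2, 64, 8, 8])
```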
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/shufflenet.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 |
3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ShuffleBlock(nn.Module):
11 | def __init__(self, groups):
12 | super(ShuffleBlock, self).__init__()
13 | self.groups = groups
14 |
15 | def forward(self, x):
16 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 | N,C,H,W = x.size()
18 | g = self.groups
19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W)
20 |
21 |
22 | class Bottleneck(nn.Module):
23 | def __init__(self, in_planes, out_planes, stride, groups):
24 | super(Bottleneck, self).__init__()
25 | self.stride = stride
26 |
27 |         mid_planes = out_planes // 4  # integer division so channel counts stay ints under Python 3
28 | g = 1 if in_planes==24 else groups
29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
30 | self.bn1 = nn.BatchNorm2d(mid_planes)
31 | self.shuffle1 = ShuffleBlock(groups=g)
32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
33 | self.bn2 = nn.BatchNorm2d(mid_planes)
34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
35 | self.bn3 = nn.BatchNorm2d(out_planes)
36 |
37 | self.shortcut = nn.Sequential()
38 | if stride == 2:
39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
40 |
41 | def forward(self, x):
42 | out = F.relu(self.bn1(self.conv1(x)))
43 | out = self.shuffle1(out)
44 | out = F.relu(self.bn2(self.conv2(out)))
45 | out = self.bn3(self.conv3(out))
46 | res = self.shortcut(x)
47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
48 | return out
49 |
50 |
51 | class ShuffleNet(nn.Module):
52 | def __init__(self, cfg):
53 | super(ShuffleNet, self).__init__()
54 | out_planes = cfg['out_planes']
55 | num_blocks = cfg['num_blocks']
56 | groups = cfg['groups']
57 |
58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
59 | self.bn1 = nn.BatchNorm2d(24)
60 | self.in_planes = 24
61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
64 | self.linear = nn.Linear(out_planes[2], 10)
65 |
66 | def _make_layer(self, out_planes, num_blocks, groups):
67 | layers = []
68 | for i in range(num_blocks):
69 | stride = 2 if i == 0 else 1
70 | cat_planes = self.in_planes if i == 0 else 0
71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
72 | self.in_planes = out_planes
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | out = F.relu(self.bn1(self.conv1(x)))
77 | out = self.layer1(out)
78 | out = self.layer2(out)
79 | out = self.layer3(out)
80 | out = F.avg_pool2d(out, 4)
81 | out = out.view(out.size(0), -1)
82 | out = self.linear(out)
83 | return out
84 |
85 |
86 | def ShuffleNetG2():
87 | cfg = {
88 | 'out_planes': [200,400,800],
89 | 'num_blocks': [4,8,4],
90 | 'groups': 2
91 | }
92 | return ShuffleNet(cfg)
93 |
94 | def ShuffleNetG3():
95 | cfg = {
96 | 'out_planes': [240,480,960],
97 | 'num_blocks': [4,8,4],
98 | 'groups': 3
99 | }
100 | return ShuffleNet(cfg)
101 |
102 |
103 | def test():
104 | net = ShuffleNetG2()
105 | x = torch.randn(1,3,32,32)
106 | y = net(x)
107 | print(y)
108 |
109 | # test()
110 |
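ShuffleBlock's docstring compresses the whole trick into one line: reshape the C channels into g groups, swap the group axis with the within-group axis, and flatten back, so channels produced by different groups end up interleaved before the next grouped convolution. A tiny numeric check of that behaviour:

```python
import torch

N, C, H, W, g = 1, 6, 1, 1, 2
x = torch.arange(C, dtype=torch.float32).view(N, C, H, W)   # channels 0..5 in order
y = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
print(y.view(-1).tolist())   # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0] -- the two groups are interleaved
```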
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 |
3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ShuffleBlock(nn.Module):
11 | def __init__(self, groups=2):
12 | super(ShuffleBlock, self).__init__()
13 | self.groups = groups
14 |
15 | def forward(self, x):
16 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 | N, C, H, W = x.size()
18 | g = self.groups
19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
20 |
21 |
22 | class SplitBlock(nn.Module):
23 | def __init__(self, ratio):
24 | super(SplitBlock, self).__init__()
25 | self.ratio = ratio
26 |
27 | def forward(self, x):
28 | c = int(x.size(1) * self.ratio)
29 | return x[:, :c, :, :], x[:, c:, :, :]
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | def __init__(self, in_channels, split_ratio=0.5):
34 | super(BasicBlock, self).__init__()
35 | self.split = SplitBlock(split_ratio)
36 | in_channels = int(in_channels * split_ratio)
37 | self.conv1 = nn.Conv2d(in_channels, in_channels,
38 | kernel_size=1, bias=False)
39 | self.bn1 = nn.BatchNorm2d(in_channels)
40 | self.conv2 = nn.Conv2d(in_channels, in_channels,
41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
42 | self.bn2 = nn.BatchNorm2d(in_channels)
43 | self.conv3 = nn.Conv2d(in_channels, in_channels,
44 | kernel_size=1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(in_channels)
46 | self.shuffle = ShuffleBlock()
47 |
48 | def forward(self, x):
49 | x1, x2 = self.split(x)
50 | out = F.relu(self.bn1(self.conv1(x2)))
51 | out = self.bn2(self.conv2(out))
52 | out = F.relu(self.bn3(self.conv3(out)))
53 | out = torch.cat([x1, out], 1)
54 | out = self.shuffle(out)
55 | return out
56 |
57 |
58 | class DownBlock(nn.Module):
59 | def __init__(self, in_channels, out_channels):
60 | super(DownBlock, self).__init__()
61 | mid_channels = out_channels // 2
62 | # left
63 | self.conv1 = nn.Conv2d(in_channels, in_channels,
64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
65 | self.bn1 = nn.BatchNorm2d(in_channels)
66 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
67 | kernel_size=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(mid_channels)
69 | # right
70 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
71 | kernel_size=1, bias=False)
72 | self.bn3 = nn.BatchNorm2d(mid_channels)
73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
75 | self.bn4 = nn.BatchNorm2d(mid_channels)
76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
77 | kernel_size=1, bias=False)
78 | self.bn5 = nn.BatchNorm2d(mid_channels)
79 |
80 | self.shuffle = ShuffleBlock()
81 |
82 | def forward(self, x):
83 | # left
84 | out1 = self.bn1(self.conv1(x))
85 | out1 = F.relu(self.bn2(self.conv2(out1)))
86 | # right
87 | out2 = F.relu(self.bn3(self.conv3(x)))
88 | out2 = self.bn4(self.conv4(out2))
89 | out2 = F.relu(self.bn5(self.conv5(out2)))
90 | # concat
91 | out = torch.cat([out1, out2], 1)
92 | out = self.shuffle(out)
93 | return out
94 |
95 |
96 | class ShuffleNetV2(nn.Module):
97 | def __init__(self, net_size):
98 | super(ShuffleNetV2, self).__init__()
99 | out_channels = configs[net_size]['out_channels']
100 | num_blocks = configs[net_size]['num_blocks']
101 |
102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
103 | stride=1, padding=1, bias=False)
104 | self.bn1 = nn.BatchNorm2d(24)
105 | self.in_channels = 24
106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
110 | kernel_size=1, stride=1, padding=0, bias=False)
111 | self.bn2 = nn.BatchNorm2d(out_channels[3])
112 | self.linear = nn.Linear(out_channels[3], 10)
113 |
114 | def _make_layer(self, out_channels, num_blocks):
115 | layers = [DownBlock(self.in_channels, out_channels)]
116 | for i in range(num_blocks):
117 | layers.append(BasicBlock(out_channels))
118 | self.in_channels = out_channels
119 | return nn.Sequential(*layers)
120 |
121 | def forward(self, x):
122 | out = F.relu(self.bn1(self.conv1(x)))
123 | # out = F.max_pool2d(out, 3, stride=2, padding=1)
124 | out = self.layer1(out)
125 | out = self.layer2(out)
126 | out = self.layer3(out)
127 | out = F.relu(self.bn2(self.conv2(out)))
128 | out = F.avg_pool2d(out, 4)
129 | out = out.view(out.size(0), -1)
130 | out = self.linear(out)
131 | return out
132 |
133 |
134 | configs = {
135 | 0.5: {
136 | 'out_channels': (48, 96, 192, 1024),
137 | 'num_blocks': (3, 7, 3)
138 | },
139 |
140 | 1: {
141 | 'out_channels': (116, 232, 464, 1024),
142 | 'num_blocks': (3, 7, 3)
143 | },
144 | 1.5: {
145 | 'out_channels': (176, 352, 704, 1024),
146 | 'num_blocks': (3, 7, 3)
147 | },
148 | 2: {
149 | 'out_channels': (224, 488, 976, 2048),
150 | 'num_blocks': (3, 7, 3)
151 | }
152 | }
153 |
154 |
155 | def test():
156 | net = ShuffleNetV2(net_size=0.5)
157 | x = torch.randn(3, 3, 32, 32)
158 | y = net(x)
159 | print(y.shape)
160 |
161 |
162 | # test()
163 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG11/13/16/19 in Pytorch.'''
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | cfg = {
7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
11 | }
12 |
13 |
14 | class VGG(nn.Module):
15 | def __init__(self, vgg_name):
16 | super(VGG, self).__init__()
17 | self.features = self._make_layers(cfg[vgg_name])
18 | self.classifier = nn.Linear(512, 10)
19 |
20 | def forward(self, x):
21 | out = self.features(x)
22 | out = out.view(out.size(0), -1)
23 | out = self.classifier(out)
24 | return out
25 |
26 | def _make_layers(self, cfg):
27 | layers = []
28 | in_channels = 3
29 | for x in cfg:
30 | if x == 'M':
31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
32 | else:
33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
34 | nn.BatchNorm2d(x),
35 | nn.ReLU(inplace=True)]
36 | in_channels = x
37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
38 | return nn.Sequential(*layers)
39 |
40 |
41 | def test():
42 | net = VGG('VGG11')
43 | x = torch.randn(2,3,32,32)
44 | y = net(x)
45 | print(y.size())
46 |
47 | # test()
48 |
--------------------------------------------------------------------------------
/Supplements/Fig_S5/pytorchtools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | class EarlyStopping:
5 |     """Stops training early if the validation loss doesn't improve after a given patience."""
6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 | """
8 | Args:
9 | patience (int): How long to wait after last time validation loss improved.
10 | Default: 7
11 | verbose (bool): If True, prints a message for each validation loss improvement.
12 | Default: False
13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 | Default: 0
15 | path (str): Path for the checkpoint to be saved to.
16 | Default: 'checkpoint.pt'
17 | trace_func (function): trace print function.
18 | Default: print
19 | """
20 | self.patience = patience
21 | self.verbose = verbose
22 | self.counter = 0
23 | self.best_score = None
24 | self.early_stop = False
25 |         self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
26 | self.delta = delta
27 | self.path = path
28 | self.trace_func = trace_func
29 | def __call__(self, val_loss, model):
30 |
31 | score = -val_loss
32 |
33 | if self.best_score is None:
34 | self.best_score = score
35 | self.save_checkpoint(val_loss, model)
36 | elif score < self.best_score + self.delta:
37 | self.counter += 1
38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 | if self.counter >= self.patience:
40 | self.early_stop = True
41 | else:
42 | self.best_score = score
43 | self.save_checkpoint(val_loss, model)
44 | self.counter = 0
45 |
46 | def save_checkpoint(self, val_loss, model):
47 |         '''Saves model when validation loss decreases.'''
48 | if self.verbose:
49 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
50 | torch.save(model.state_dict(), self.path)
51 | self.val_loss_min = val_loss
52 |
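A minimal sketch of how this class is typically driven from a training loop: call it once per epoch with the validation loss, break out when early_stop flips to True, then reload the checkpoint it saved at the best epoch. The toy model and data below are placeholders, not taken from the Fig_S5 notebooks.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from pytorchtools import EarlyStopping  # assumed to be importable from the Fig_S5 folder

# Toy regression problem, purely to exercise the stopping logic.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
criterion = nn.MSELoss()
train_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=16)
val_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=16)

early_stopping = EarlyStopping(patience=7, verbose=True, path='checkpoint.pt')

for epoch in range(200):
    model.train()
    for xb, yb in train_loader:
        optimizer.zero_grad()
        criterion(model(xb), yb).backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():   # mean validation loss for this epoch
        val_loss = sum(criterion(model(xb), yb).item() for xb, yb in val_loader) / len(val_loader)

    early_stopping(val_loss, model)       # checkpoints the model whenever val_loss improves
    if early_stopping.early_stop:
        print('Early stopping at epoch', epoch)
        break

model.load_state_dict(torch.load('checkpoint.pt'))  # restore the best-epoch weights
```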
--------------------------------------------------------------------------------
/Supplements/Fig_S6i_j.m:
--------------------------------------------------------------------------------
1 | close all
2 | clear all
3 |
4 | nepoch = 2000;
5 | learnrate = 0.005;
6 | N_x_t = 100;
7 | N_y_t = 1;
8 | P=100;
9 | M = 5000; %num of units in notebook
10 |
11 | Et_big_early_stop = [];
12 | Eg_big_early_stop = [];
13 | Et_big = [];
14 | Eg_big = [];
15 |
16 | Fig3j_data = zeros(2,2);
17 |
18 | Fig3k_data = zeros(2,2);
19 |
20 |
21 | SNR_vec = [0.6 1000];
22 |
23 |
24 | Notebook_Train = (P-1)/(M-1); % Analytical solution for notebook training error
25 |
26 | for count = 1:size(SNR_vec,2)
27 | SNR = SNR_vec(count);
28 | %Analytical solution
29 | Eg = [];
30 | Et = [];
31 | lr = learnrate;
32 |
33 | %use normalized variances
34 | if SNR == inf
35 | variance_w = 1;
36 | variance_e = 0;
37 | else
38 | variance_w = SNR/(SNR + 1);
39 | variance_e = 1/(SNR + 1);
40 | end
41 |
42 | alpha = P/N_x_t; % number of examples divided by input dimension
43 |
44 | % Analytical curves for training and testing errors
45 | for t = 1:1:nepoch
46 |
47 |     % Integrands below give the analytical training (Et) and test (Eg) errors at epoch t
48 | train = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).* ( lam.*variance_w + variance_e ).*exp(-2*lam.*t./(1./lr)) ;
49 | Et = [Et (1/alpha)*(integral(train,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)*(1 - alpha)* variance_e ) + (1-1/alpha)*variance_e];
50 |
51 | test = @(lam) ( ( ( ((alpha^0.5+1).^2 - lam) .* (lam - (alpha^0.5-1).^2) ).^0.5) ./ (lam*2*pi) ).*(exp(-2*lam*t/(1/lr)) + ((1-exp(-lam*t/(1/lr))).^2)./(lam*SNR));
52 | Eg = [Eg variance_w*(integral(test,(alpha^0.5-1)^2,(alpha^0.5+1)^2) + (alpha<1)* (1 - alpha) + 1/SNR)];
53 |
54 | end
55 |
56 | % Early stopping curves
57 | [mm, pp] = min(Eg);
58 | Eg_early_stop = Eg;
59 | Eg_early_stop(pp+1:end) = Eg(pp);
60 |
61 | Et_early_stop = Et;
62 | Et_early_stop(pp+1:end) = Et(pp);
63 |
64 |     % Use the better (lower) of the student and notebook training errors, then
65 |     % compute memory and generalization scores
66 | better_train_no_early_stop = min(Et,ones(1,nepoch)*Notebook_Train);
67 | control_curve = (Et(1) - better_train_no_early_stop)/Et(1);
68 | lesion_curve = (Et(1) - Et)/Et(1);
69 |
70 |
71 | better_train_yes_early_stop = min(Et_early_stop,ones(1,nepoch)*Notebook_Train);
72 | control_curve_early_stop = (Et_early_stop(1) - better_train_yes_early_stop)/Et_early_stop(1);
73 | lesion_curve_early_stop = (Et_early_stop(1) - Et_early_stop)/Et_early_stop(1);
74 |
75 |
76 | control_Eg_curve = (Eg(1) - Eg)/Eg(1);
77 | control_Eg_curve_early_stop = (Eg_early_stop(1) - Eg_early_stop)/Eg_early_stop(1);
78 |
79 | Fig3j_data(count,1) = control_curve_early_stop(end);
80 | Fig3j_data(count,2) = control_Eg_curve_early_stop(end);
81 |
82 | Fig3k_data(count,1) = control_curve_early_stop(end);
83 | Fig3k_data(count,2) = lesion_curve_early_stop(end);
84 |
85 | end
86 |
87 |
88 | figure(1)
89 |
90 | x = [[0.7,1.3];...
91 | [2.7, 3.3]];
92 |
93 | data = Fig3j_data;
94 |
95 |
96 | f=bar(x,data*100);
97 | f(1).BarWidth = 3.2;
98 | xlim([0 4])
99 | ylim([0 120])
100 | hold on
101 |
102 | ax = gca;
103 | ax.XTick = [1 3];
104 | ax.YTick = [0 20 40 60 80 100];
105 | xax = ax.XAxis;
106 | set(xax,'TickDirection','out')
107 | set(gca,'box','off')
108 | set(gcf,'position',[600,400,340,210])
109 | set(gca, 'FontSize', 16)
110 |
111 |
112 | ax.XTickLabel=[{'Discriminator'} {'Generalizer'}];
113 |
114 |
115 | figure(2)
116 |
117 | x = [[0.7,1.3];...
118 | [2.7, 3.3]];
119 |
120 | data = Fig3k_data;
121 |
122 |
123 | f=bar(x,data*100);
124 | f(1).BarWidth = 3.2;
125 | xlim([0 4])
126 | ylim([0 120])
127 | hold on
128 |
129 | ax = gca;
130 | ax.XTick = [1 3];
131 | ax.YTick = [0 20 40 60 80 100];
132 |
133 | xax = ax.XAxis;
134 | set(xax,'TickDirection','out')
135 | set(gca,'box','off')
136 | set(gcf,'position',[600,100,340,210])
137 | set(gca, 'FontSize', 16)
138 |
139 |
140 | ax.XTickLabel=[{'Discriminator'} {'Generalizer'}];
141 |
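The train and test function handles in the loop above combine a square-root eigenvalue density term, sqrt(((sqrt(alpha)+1)^2 - lam)*(lam - (sqrt(alpha)-1)^2))/(2*pi*lam), with epoch-dependent decay factors, and the integrals give the analytical student errors Et and Eg at epoch t. Below is a hedged numpy/scipy port of the same two curves as a cross-check sketch; scipy is not used anywhere in the original code, and the function names (density, Et, Eg) are mine. Parameter values mirror the MATLAB script with SNR = 0.6.

```python
import numpy as np
from scipy.integrate import quad

# Parameter values mirrored from the MATLAB script above.
P, N_x_t, learnrate = 100, 100, 0.005
alpha = P / N_x_t
SNR = 0.6
variance_w = SNR / (SNR + 1)
variance_e = 1 / (SNR + 1)

lo, hi = (np.sqrt(alpha) - 1) ** 2, (np.sqrt(alpha) + 1) ** 2

def density(lam):
    # Square-root eigenvalue density term shared by both integrands.
    return np.sqrt((hi - lam) * (lam - lo)) / (2 * np.pi * lam)

def Et(t):
    # Analytical student training error at epoch t (port of the MATLAB `train` integrand).
    f = lambda lam: density(lam) * (lam * variance_w + variance_e) * np.exp(-2 * lam * t * learnrate)
    val, _ = quad(f, lo, hi)
    return (1 / alpha) * (val + (alpha < 1) * (1 - alpha) * variance_e) + (1 - 1 / alpha) * variance_e

def Eg(t):
    # Analytical student generalization error at epoch t (port of the MATLAB `test` integrand).
    f = lambda lam: density(lam) * (np.exp(-2 * lam * t * learnrate)
                                    + (1 - np.exp(-lam * t * learnrate)) ** 2 / (lam * SNR))
    val, _ = quad(f, lo, hi)
    return variance_w * (val + (alpha < 1) * (1 - alpha) + 1 / SNR)

print(Et(1), Eg(1), Et(2000), Eg(2000))
```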
--------------------------------------------------------------------------------
/Supplements/Fig_S6l.m:
--------------------------------------------------------------------------------
1 | % Code for the Student-Teacher-Notebook framework, simulating Sweegers
2 | % et al., 2014
3 | tic
4 | close all
5 | clear all
6 |
7 | r_n = 3; % number of repeats
8 | nepoch = 2000;
9 | learnrate = 0.015;
10 | N_x_t = 100; % teacher input dimension
11 | N_y_t = 1; % teacher output dimension
12 | P=100; % number of training examples
13 | P_test = 1000; % number of testing examples
14 |
15 | variance_s = 0.5; % student weight variance at initialization
16 |
17 | % According to SNR, set variances for teacher's weights (variance_w) and output noise (variance_e) that sum to 1
18 | SNR_vec = [0.05 2]; % set values to 0.05 and 2 to reproduce fig. 3n
19 | W_s_norm = zeros(size(SNR_vec,2),r_n,nepoch);
20 |
21 | for SNR_counter=1:size(SNR_vec,2)
22 | SNR = SNR_vec(SNR_counter);
23 |
24 | if SNR == inf
25 | variance_w = 1;
26 | variance_e = 0;
27 | else
28 | variance_w = SNR/(SNR + 1);
29 | variance_e = 1/(SNR + 1);
30 | end
31 |
32 | % Student and teacher share the same dimensions
33 | N_x_s = N_x_t;
34 | N_y_s = N_y_t;
35 |
36 | % Notebook parameters
37 | % see Buhmann, Divko, and Schulten, 1989 for details regarding gamma and U terms
38 |
39 | M = 2000; % num of units in notebook
40 | a = 0.05; % notebook sparseness
41 |     gamma = 0.6; % inhibition parameter
42 | U = -0.15; % threshold for unit activation
43 | ncycle = 9; % number of recurrent cycles
44 |
45 |
46 | % Matrices for storing train error, test error, reactivation error (driven by notebook)
47 | % Without early stopping
48 | train_error_all = zeros(r_n,nepoch);
49 | test_error_all = zeros(r_n,nepoch);
50 | N_train_error_all = zeros(r_n,nepoch);
51 | N_test_error_all = zeros(r_n,nepoch);
52 |
53 | % With early stopping
54 | train_error_early_stop_all = zeros(r_n,nepoch);
55 | test_error_early_stop_all = zeros(r_n,nepoch);
56 |
57 |
58 | %Run simulation for r_n times
59 | for r = 1:r_n
60 | disp(r)
61 | rng(r); %set random seed for reproducibility
62 |
63 | %Errors
64 | error_train_vector = zeros(nepoch,1);
65 | error_test_vector = zeros(nepoch,1);
66 | error_react_vector = zeros(nepoch,1);
67 |
68 | %% Teacher Network
69 | W_t = normrnd(0,variance_w^0.5,[N_x_t,N_y_t]);% set teacher's weights
70 | noise_train = normrnd(0,variance_e^0.5,[P,N_y_t]);
71 | % Training data
72 | x_t_input = normrnd(0,(1/N_x_t)^0.5,[P,N_x_t]); % inputs
73 | y_t_output = x_t_input*W_t + noise_train; % outputs
74 |
75 | % Testing data
76 | noise_test = normrnd(0,variance_e^0.5,[P_test,N_y_t]);
77 | x_t_input_test = normrnd(0,(1/N_x_t)^0.5,[P_test,N_x_t]);
78 | y_t_output_test = x_t_input_test*W_t + noise_test;
79 |
80 | %% Notebook Network
81 | % Generate P random binary indices with sparseness a
82 | N_patterns = zeros(P,M);
83 | for n=1:P
84 | N_patterns(n,randperm(M,M*a))=1;
85 | end
86 |
87 | %Hebbian learning for notebook recurrent weights
88 | W_N = (N_patterns - a)'*(N_patterns - a)/(M*a*(1-a));
89 |         W_N = W_N - gamma/(a*M); % add global inhibition term, see Buhmann, Divko, and Schulten, 1989
90 | W_N = W_N.*~eye(size(W_N)); % set diagonal weights to zero
91 |
92 | % Hebbian learning for Notebook-Student weights (bidirectional)
93 | % Notebook to student weights
94 | W_N_S_Lin = (N_patterns-a)'*x_t_input/(M*a*(1-a));
95 | W_N_S_Lout = (N_patterns-a)'*y_t_output/(M*a*(1-a));
96 | % Student to notebook weights
97 | W_S_N_Lin = x_t_input'*(N_patterns-a);
98 | W_S_N_Lout = y_t_output'*(N_patterns-a);
99 |
100 | %% Student Network
101 | W_s = normrnd(0,variance_s^0.5,[N_x_s,N_y_s]); % set student's initial weights
102 |
103 | %% Generate offline training data from notebook reactivations
104 | N_patterns_reactivated = zeros(P,M,nepoch,'logical'); % array for storing retrieved notebook patterns, pre-calculating all epochs for speed considerations
105 |
106 | parfor m = 1:nepoch
107 | %for m = 1:nepoch % change to regular for-loop without multiple cores
108 |
109 |         %% Notebook pattern completion through recurrent dynamics
110 |         % The code below simulates hippocampal offline spontaneous
111 |         % reactivations by seeding the initial notebook state with a random
112 |         % binary index; the notebook then goes through a two-step retrieval
113 |         % process: (1) retrieve a pattern using a dynamic threshold to
114 |         % ensure a pattern with sparseness a is retrieved; (2) use the
115 |         % retrieved pattern from (1) to seed a second round of pattern
116 |         % completion with a fixed threshold (along with a global
117 |         % inhibition term during encoding), so the retrieved patterns are
118 |         % not forced to have a fixed sparseness; in addition, there is a
119 |         % "silent state" attractor when the seeding pattern lies far away
120 |         % from any of the encoded patterns.
121 |
122 | % Start recurrent cycles with dynamic threshold
123 | Activity_dyn_t = zeros(P, M);
124 |
125 |         % First round of pattern completion through recurrent activation cycles given
126 | % random initial input.
127 | for cycle = 1:ncycle
128 | if cycle <=1
129 | clamp = 1;
130 | else
131 | clamp = 0;
132 | end
133 | rand_patt = (rand(P,M)<=a);
134 | % Seeding notebook with random patterns
135 | M_input = Activity_dyn_t + (rand_patt*clamp);
136 | % Seeding notebook with original patterns
137 | %M_input = Activity_dyn_t + (N_patterns*clamp);
138 | M_current = M_input*W_N;
139 | % scale currents between 0 and 1
140 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2));
141 | M_current = (M_current - min(M_current,[],2)) .* scale;
142 | % find threshold based on desired sparseness
143 | sorted_M_current = sort(M_current,2,'descend');
144 | t_ind = floor(size(Activity_dyn_t,2) * a);
145 | t_ind(t_ind<1) = 1;
146 | t = sorted_M_current(:,t_ind); % threshold for unit activations
147 | Activity_dyn_t = (M_current >=t);
148 | end
149 |
150 |         % Second round of pattern completion, with fixed threshold
151 | Activity_fix_t = zeros(P, M);
152 | for cycle = 1:ncycle
153 | if cycle <=1
154 | clamp = 1;
155 | else
156 | clamp = 0;
157 | end
158 | M_input = Activity_fix_t + Activity_dyn_t*clamp;
159 | M_current = M_input*W_N;
160 | Activity_fix_t = (M_current >= U); % U is the fixed threshold
161 | end
162 | N_patterns_reactivated(:,:,m)=Activity_fix_t;
163 | end
164 |
165 |         %% Calculating training error mediated by the notebook: seed the
166 |         % notebook with the student's training input via the Student-to-Notebook
167 |         % weights; once pattern completion finishes, use the retrieved pattern
168 |         % to activate the student's output unit via the Notebook-to-Student
169 |         % output weights.
170 |
171 | Activity_notebook_train = zeros(P, M);
172 | for cycle = 1:ncycle
173 | if cycle <=1
174 | clamp = 1;
175 | else
176 | clamp = 0;
177 | end
178 | seed_patt = x_t_input*W_S_N_Lin;
179 | M_input = Activity_notebook_train + (seed_patt*clamp);
180 | M_current = M_input*W_N;
181 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2));
182 | M_current = (M_current - min(M_current,[],2)) .* scale;
183 | sorted_M_current = sort(M_current,2,'descend');
184 | t_ind = floor(size(Activity_notebook_train,2) * a);
185 | t_ind(t_ind<1) = 1;
186 | t = sorted_M_current(:,t_ind);
187 | Activity_notebook_train = (M_current >=t);
188 | end
189 | N_S_output_train = Activity_notebook_train*W_N_S_Lout;
190 | % Notebook training error
191 | delta_N_train = y_t_output - N_S_output_train;
192 | error_N_train = sum(delta_N_train.^2)/P;
193 | % Since notebook errors stay constant throughout training,
194 | % populating each epoch with the same value
195 | error_N_train_vector = ones(nepoch,1)*error_N_train;
196 | N_train_error_all(r,:) = error_N_train_vector;
197 |
198 | % Notebook generalization error
199 | Activity_notebook_test = zeros(P_test, M);
200 | for cycle = 1:ncycle
201 | if cycle <=1
202 | clamp = 1;
203 | else
204 | clamp = 0;
205 | end
206 | seed_patt = x_t_input_test*W_S_N_Lin;
207 | M_input = Activity_notebook_test + (seed_patt*clamp);
208 | M_current = M_input*W_N;
209 | scale = 1.0 ./ (max(M_current,[],2) - min(M_current,[],2));
210 | M_current = (M_current - min(M_current,[],2)) .* scale;
211 | sorted_M_current = sort(M_current,2,'descend');
212 | t_ind = floor(size(Activity_notebook_test,2) * a);
213 | t_ind(t_ind<1) = 1;
214 | t = sorted_M_current(:,t_ind);
215 | Activity_notebook_test = (M_current >=t);
216 | end
217 | N_S_output_test = Activity_notebook_test*W_N_S_Lout;
218 | % Notebook test error
219 | delta_N_test = y_t_output_test - N_S_output_test;
220 | error_N_test = sum(delta_N_test.^2)/P_test;
221 | % Since notebook errors stay constant throughout training,
222 | % populating each epoch with the same value
223 | error_N_test_vector = ones(nepoch,1)*error_N_test;
224 | N_test_error_all(r,:) = error_N_test_vector;
225 |
226 |
227 | N_patterns_reactivated_test = zeros(P_test,M,'logical');
228 | %% Student training
229 | for m = 1:nepoch %batch training starts
230 |
231 | W_s_norm(SNR_counter,r,m) = norm(W_s,2);
232 |
233 | N_S_input = N_patterns_reactivated(:,:,m)*W_N_S_Lin; % notebook reactivated student input activity
234 | N_S_output = N_patterns_reactivated(:,:,m)*W_N_S_Lout; % notebook reactivated student output activity
235 | N_S_prediction = N_S_input*W_s; % student output prediction calculated by notebook reactivated input and student weights
236 | S_prediction = x_t_input*W_s; % student output prediction calculated by true training inputs and student weights
237 | S_prediction_test = x_t_input_test*W_s; % student output prediction calculated by true testing inputs and student weights
238 |
239 | % Train error
240 | delta_train = y_t_output - S_prediction;
241 | error_train = sum(delta_train.^2)/P;
242 | error_train_vector(m) = error_train;
243 |
244 | % Generalization error
245 | delta_test = y_t_output_test - S_prediction_test;
246 | error_test = sum(delta_test.^2)/P_test;
247 | error_test_vector(m) = error_test;
248 |
249 | % Gradient descent
250 | w_delta = N_S_input'*N_S_output - N_S_input'*N_S_input*W_s;
251 | W_s = W_s + learnrate*w_delta;
252 | end
253 |
254 | train_error_all(r,:) = error_train_vector;
255 | test_error_all(r,:) = error_test_vector;
256 |
257 | % Early stopping
258 | [min_v, min_p] = min(error_test_vector);
259 | train_error_early_stop = error_train_vector;
260 | train_error_early_stop (min_p+1:end) = error_train_vector (min_p);
261 | test_error_early_stop = error_test_vector;
262 | test_error_early_stop (min_p+1:end) = error_test_vector (min_p);
263 | train_error_early_stop_all(r,:) = train_error_early_stop;
264 | test_error_early_stop_all(r,:) = test_error_early_stop;
265 | W_s_norm(SNR_counter,r,min_p+1:end) = W_s_norm(SNR_counter,r,min_p);
266 | end
267 | end
268 | toc
269 |
270 |
271 |
272 |
273 | %Weight norm with early stopping
274 | figure(3)
275 | x = [[0.95,2];...
276 | [1.05, 2]];
277 |
278 | data = [squeeze(mean(W_s_norm(1,:,[1 2000]),2)/mean(W_s_norm(1,:,1),2))';...
279 | squeeze(mean(W_s_norm(2,:,[1 2000]),2)/mean(W_s_norm(1,:,1),2))'];
280 |
281 |
282 | % errlow = data - err;
283 | % errhigh = data + err;
284 |
285 | f=plot(x',data','o-');
286 |
287 | xlim([0.5 2.4])
288 | % ylim([-0.2 1.6])
289 | hold on
290 |
291 |
292 | ax = gca;
293 | ax.XTick = [1 2];
294 | %ax.YTick = [0 0.4 0.8 1.2 1.6];
295 |
296 | xax = ax.XAxis;
297 | set(xax,'TickDirection','out')
298 | set(gca,'box','off')
299 | set(gcf,'position',[600,100, 270,300])
300 | set(gca, 'FontSize', 16)
301 | ax.XTickLabel=[{'Recent'} {'Remote'}];
302 | ylabel('Connectivity strength')
303 |
304 |
305 |
306 |
307 |
308 |
--------------------------------------------------------------------------------