├── .gitignore
├── .idea
│   ├── PyTorch-BayesianCNN.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── Mixtures
│   ├── config_mixtures.py
│   ├── gmm.py
│   ├── main.py
│   ├── mixture_experiment.py
│   ├── temp_gmm.py
│   ├── train_splitted.py
│   └── utils_mixture.py
├── README.md
├── __init__.py
├── config_bayesian.py
├── config_frequentist.py
├── data
│   ├── __init__.py
│   └── data.py
├── experiments
│   └── figures
│       ├── BayesCNNwithdist.png
│       ├── CNNwithdist_git.png
│       ├── fc3-node_0-both-distplot.gif
│       ├── fc3-node_0-mean-distplot.gif
│       ├── fc3-node_0-mean-lineplot.jpg
│       ├── fc3-node_0-std-distplot.gif
│       └── fc3-node_0-std-lineplot.jpg
├── layers
│   ├── BBB
│   │   ├── BBBConv.py
│   │   ├── BBBLinear.py
│   │   └── __init__.py
│   ├── BBB_LRT
│   │   ├── BBBConv.py
│   │   ├── BBBLinear.py
│   │   └── __init__.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── BBBConv.cpython-37.pyc
│   │   ├── BBBLinear.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   └── misc.cpython-37.pyc
│   └── misc.py
├── main_bayesian.py
├── main_frequentist.py
├── metrics.py
├── models
│   ├── BayesianModels
│   │   ├── Bayesian3Conv3FC.py
│   │   ├── BayesianAlexNet.py
│   │   ├── BayesianLeNet.py
│   │   └── __pycache__
│   │       ├── Bayesian3Conv3FC.cpython-37.pyc
│   │       ├── BayesianAlexNet.cpython-37.pyc
│   │       └── BayesianLeNet.cpython-37.pyc
│   └── NonBayesianModels
│       ├── AlexNet.py
│       ├── LeNet.py
│       ├── ThreeConvThreeFC.py
│       └── __pycache__
│           ├── AlexNet.cpython-37.pyc
│           ├── LeNet.cpython-37.pyc
│           └── ThreeConvThreeFC.cpython-37.pyc
├── tests
│   └── test_models.py
├── uncertainty_estimation.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoint/
2 | data/
3 | checkpoint_/
4 | *.pyc
5 |
6 |
--------------------------------------------------------------------------------
/.idea/ (PyTorch-BayesianCNN.iml, misc.xml, modules.xml, vcs.xml, workspace.xml):
--------------------------------------------------------------------------------
JetBrains IDE project configuration files; their XML markup was stripped during extraction and nothing meaningful remains, so their contents are omitted.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Kumar Shridhar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Mixtures/config_mixtures.py:
--------------------------------------------------------------------------------
1 | ############### Configuration file for Training of SplitMNIST and Mixtures ###############
2 | n_epochs = 5
3 | lr_start = 0.001
4 | num_workers = 4
5 | valid_size = 0.2
6 | batch_size = 256
7 | train_ens = 15
8 | valid_ens = 10
9 |
--------------------------------------------------------------------------------
/Mixtures/gmm.py:
--------------------------------------------------------------------------------
1 | # Courtesy of https://github.com/ldeecke/gmm-torch
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from math import pi
7 |
8 |
9 | class GaussianMixture(torch.nn.Module):
10 | """
11 | Fits a mixture of k=1,..,K Gaussians to the input data. Input tensors are expected to be flat with dimensions (n: number of samples, d: number of features).
12 | The model then extends them to (n, k: number of components, d).
13 | The model parametrization (mu, sigma) is stored as (1, k, d), and probabilities are shaped (n, k, 1) if they relate to an individual sample, or (1, k, 1) if they assign membership probabilities to one of the mixture components.
14 | """
15 | def __init__(self, n_components, n_features, mu_init=None, var_init=None, eps=1.e-6):
16 | """
17 | Initializes the model and brings all tensors into their required shape. The class expects data to be fed as a flat tensor in (n, d). The class owns:
18 | x: torch.Tensor (n, k, d)
19 | mu: torch.Tensor (1, k, d)
20 | var: torch.Tensor (1, k, d)
21 | pi: torch.Tensor (1, k, 1)
22 | eps: float
23 | n_components: int
24 | n_features: int
25 | score: float
26 | args:
27 | n_components: int
28 | n_features: int
29 | mu_init: torch.Tensor (1, k, d)
30 | var_init: torch.Tensor (1, k, d)
31 | eps: float
32 | """
33 |
34 | super(GaussianMixture, self).__init__()
35 |
36 | self.eps = eps
37 | self.n_components = n_components
38 | self.n_features = n_features
39 | self.log_likelihood = -np.inf
40 |
41 | self.mu_init = mu_init
42 | self.var_init = var_init
43 |
44 | self._init_params()
45 |
46 |
47 | def _init_params(self):
48 | if self.mu_init is not None:
49 | assert self.mu_init.size() == (1, self.n_components, self.n_features), "Input mu_init does not have required tensor dimensions (1, %i, %i)" % (self.n_components, self.n_features)
50 | # (1, k, d)
51 | self.mu = torch.nn.Parameter(self.mu_init, requires_grad=False)
52 | else:
53 | self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False)
54 |
55 | if self.var_init is not None:
56 | assert self.var_init.size() == (1, self.n_components, self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (self.n_components, self.n_features)
57 | # (1, k, d)
58 | self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
59 | else:
60 | self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False)
61 |
62 | # (1, k, 1)
 63 | self.pi = torch.nn.Parameter(torch.full((1, self.n_components, 1), 1. / self.n_components), requires_grad=False)
64 |
65 | self.params_fitted = False
66 |
67 |
68 | def bic(self, x):
69 | """
70 | Bayesian information criterion for samples x.
71 | args:
72 | x: torch.Tensor (n, d) or (n, k, d)
73 | returns:
74 | bic: float
75 | """
76 | n = x.shape[0]
77 |
78 | if len(x.size()) == 2:
79 | # (n, d) --> (n, k, d)
80 | x = x.unsqueeze(1).expand(n, self.n_components, x.size(1))
81 |
82 | bic = -2. * self.__score(self.pi, self.__p_k(x, self.mu, self.var), sum_data=True) * n + self.n_components * np.log(n)
83 |
84 | return bic
85 |
86 |
87 | def fit(self, x, warm_start=False, delta=1e-8, n_iter=1000):
88 | """
89 | Public method that fits data to the model.
90 | args:
91 | n_iter: int
92 | delta: float
93 | """
94 |
95 | if not warm_start and self.params_fitted:
96 | self._init_params()
97 |
98 | if len(x.size()) == 2:
99 | # (n, d) --> (n, k, d)
100 | x = x.unsqueeze(1).expand(x.size(0), self.n_components, x.size(1))
101 |
102 | i = 0
103 | j = np.inf
104 |
105 | while (i <= n_iter) and (j >= delta):
106 |
107 | log_likelihood_old = self.log_likelihood
108 | mu_old = self.mu
109 | var_old = self.var
110 |
111 | self.__em(x)
112 | self.log_likelihood = self.__score(self.pi, self.__p_k(x, self.mu, self.var))
113 |
114 | if torch.isinf(self.log_likelihood) or torch.isnan(self.log_likelihood):
115 | # when the log-likelihood degenerates to inf or NaN, reinitialize the model (NaN never compares equal to itself, so use torch.isnan)
116 | self.__init__(self.n_components, self.n_features)
117 |
118 | i += 1
119 | j = self.log_likelihood - log_likelihood_old
120 | if j <= delta:
121 | # when the score decreases, revert to old parameters
122 | self.__update_mu(mu_old)
123 | self.__update_var(var_old)
124 |
125 | self.params_fitted = True
126 |
127 |
128 | def predict(self, x, probs=False):
129 | """
130 | Assigns input data to one of the mixture components by evaluating the likelihood under each. If probs=True returns normalized probabilities of class membership instead.
131 | args:
132 | x: torch.Tensor (n, d) or (n, k, d)
133 | probs: bool
134 | returns:
135 | y: torch.LongTensor (n)
136 | """
137 |
138 | if len(x.size()) == 2:
139 | # (n, d) --> (n, k, d)
140 | x = x.unsqueeze(1).expand(x.size(0), self.n_components, x.size(1))
141 |
142 | p_k = self.__p_k(x, self.mu, self.var)
143 | if probs:
144 | return p_k / (p_k.sum(1, keepdim=True) + self.eps)
145 | else:
146 | _, predictions = torch.max(p_k, 1)
147 | return torch.squeeze(predictions).type(torch.LongTensor)
148 |
149 | def predict_proba(self, x):
150 | """
151 | Returns normalized probabilities of class membership.
152 | args:
153 | x: torch.Tensor (n, d) or (n, k, d)
154 | returns:
155 | y: torch.Tensor (n, k, 1)
156 | """
157 | return self.predict(x, probs=True)
158 |
159 |
160 | def score_samples(self, x):
161 | """
162 | Computes log-likelihood of data (x) under the current model.
163 | args:
164 | x: torch.Tensor (n, d) or (n, k, d)
165 | returns:
166 | score: torch.Tensor (n, 1)
167 | """
168 | if len(x.size()) == 2:
169 | # (n, d) --> (n, k, d)
170 | x = x.unsqueeze(1).expand(x.size(0), self.n_components, x.size(1))
171 |
172 | score = self.__score(self.pi, self.__p_k(x, self.mu, self.var), sum_data=False)
173 | return score
174 |
175 |
176 | def __p_k(self, x, mu, var):
177 | """
178 | Returns a tensor with dimensions (n, k, 1) indicating the likelihood of data belonging to the k-th Gaussian.
179 | args:
180 | x: torch.Tensor (n, k, d)
181 | mu: torch.Tensor (1, k, d)
182 | var: torch.Tensor (1, k, d)
183 | returns:
184 | p_k: torch.Tensor (n, k, 1)
185 | """
186 |
187 | # (1, k, d) --> (n, k, d)
188 | mu = mu.expand(x.size(0), self.n_components, self.n_features)
189 | var = var.expand(x.size(0), self.n_components, self.n_features)
190 |
191 | # (n, k, d) --> (n, k, 1)
192 | exponent = torch.exp(-.5 * torch.sum((x - mu) * (x - mu) / var, 2, keepdim=True))
193 | # (n, k, d) --> (n, k, 1)
194 | prefactor = torch.rsqrt(((2. * pi) ** self.n_features) * torch.prod(var, dim=2, keepdim=True) + self.eps)
195 |
196 | return prefactor * exponent
197 |
198 |
199 | def __e_step(self, pi, p_k):
200 | """
201 | Computes weights that indicate the probabilistic belief that a data point was generated by one of the k mixture components. This is the so-called expectation step of the EM-algorithm.
202 | args:
203 | pi: torch.Tensor (1, k, 1)
204 | p_k: torch.Tensor (n, k, 1)
205 | returns:
206 | weights: torch.Tensor (n, k, 1)
207 | """
208 |
209 | weights = pi * p_k
210 | return torch.div(weights, torch.sum(weights, 1, keepdim=True) + self.eps)
211 |
212 |
213 | def __m_step(self, x, weights):
214 | """
215 | Updates the model's parameters. This is the maximization step of the EM-algorithm.
216 | args:
217 | x: torch.Tensor (n, k, d)
218 | weights: torch.Tensor (n, k, 1)
219 | returns:
220 | pi_new: torch.Tensor (1, k, 1)
221 | mu_new: torch.Tensor (1, k, d)
222 | var_new: torch.Tensor (1, k, d)
223 | """
224 |
225 | # (n, k, 1) --> (1, k, 1)
226 | n_k = torch.sum(weights, 0, keepdim=True)
227 | pi_new = torch.div(n_k, torch.sum(n_k, 1, keepdim=True) + self.eps)
228 | # (n, k, d) --> (1, k, d)
229 | mu_new = torch.div(torch.sum(weights * x, 0, keepdim=True), n_k + self.eps)
230 | # (n, k, d) --> (1, k, d)
231 | var_new = torch.div(torch.sum(weights * (x - mu_new) * (x - mu_new), 0, keepdim=True), n_k + self.eps)
232 |
233 | return pi_new, mu_new, var_new
234 |
235 |
236 | def __em(self, x):
237 | """
238 | Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.
239 | args:
240 | x: torch.Tensor (n, k, d)
241 | """
242 |
243 | weights = self.__e_step(self.pi, self.__p_k(x, self.mu, self.var))
244 | pi_new, mu_new, var_new = self.__m_step(x, weights)
245 |
246 | self.__update_pi(pi_new)
247 | self.__update_mu(mu_new)
248 | self.__update_var(var_new)
249 |
250 |
251 | def __score(self, pi, p_k, sum_data=True):
252 | """
253 | Computes the log-likelihood of the data under the model.
254 | args:
255 | pi: torch.Tensor (1, k, 1)
256 | p_k: torch.Tensor (n, k, 1)
257 | """
258 |
259 | weights = pi * p_k
260 | if sum_data:
261 | return torch.sum(torch.log(torch.sum(weights, 1) + self.eps))
262 | else:
263 | return torch.log(torch.sum(weights, 1) + self.eps)
264 |
265 |
266 | def __update_mu(self, mu):
267 | """
268 | Updates mean to the provided value.
269 | args:
270 | mu: torch.FloatTensor
271 | """
272 |
273 | assert mu.size() in [(self.n_components, self.n_features), (1, self.n_components, self.n_features)], "Input mu does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (self.n_components, self.n_features, self.n_components, self.n_features)
274 |
275 | if mu.size() == (self.n_components, self.n_features):
276 | self.mu.data = mu.unsqueeze(0)
277 | elif mu.size() == (1, self.n_components, self.n_features):
278 | self.mu.data = mu
279 |
280 |
281 | def __update_var(self, var):
282 | """
283 | Updates variance to the provided value.
284 | args:
285 | var: torch.FloatTensor
286 | """
287 |
288 | assert var.size() in [(self.n_components, self.n_features), (1, self.n_components, self.n_features)], "Input var does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (self.n_components, self.n_features, self.n_components, self.n_features)
289 |
290 | if var.size() == (self.n_components, self.n_features):
291 | self.var.data = var.unsqueeze(0)
292 | elif var.size() == (1, self.n_components, self.n_features):
293 | self.var.data = var
294 |
295 |
296 | def __update_pi(self, pi):
297 | """
298 | Updates pi to the provided value.
299 | args:
300 | pi: torch.FloatTensor
301 | """
302 |
303 | assert pi.size() in [(1, self.n_components, 1)], "Input pi does not have required tensor dimensions (%i, %i, %i)" % (1, self.n_components, 1)
304 |
305 | self.pi.data = pi
--------------------------------------------------------------------------------
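A minimal usage sketch of the `GaussianMixture` class above (the two cluster centers and the sample counts are arbitrary illustration values; shapes follow the class docstring):

```python
import torch
import gmm  # the module defined above

# 400 two-dimensional samples drawn around two well-separated centers
x = torch.cat([
    torch.randn(200, 2) + torch.tensor([4., 4.]),
    torch.randn(200, 2) - torch.tensor([4., 4.]),
], dim=0)

model = gmm.GaussianMixture(n_components=2, n_features=2)
model.fit(x)                    # runs EM until convergence or n_iter
labels = model.predict(x)       # hard assignments, shape (n,)
probs = model.predict_proba(x)  # membership probabilities, shape (n, k, 1)
```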
/Mixtures/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
  4 | import os
  5 | import torch
  6 | import numpy as np
  7 | import torch.nn.functional as F
  8 | from collections import OrderedDict
  9 |
 10 | import gmm
 11 | import utils
 12 | import metrics
 13 | import utils_mixture
 14 | import config_bayesian as cfg
15 | def feedforward_and_save_mean_var(net, dataloader, task_no, num_ens=1):
16 | cfg.mean_var_dir = "Mixtures/mean_vars/task-{}/".format(task_no)
17 | if not os.path.exists(cfg.mean_var_dir):
18 | os.makedirs(cfg.mean_var_dir, exist_ok=True)
19 | cfg.record_mean_var = True
20 | cfg.record_layers = None # All layers
21 | cfg.record_now = True
22 | cfg.curr_batch_no = 0 # Not required
23 | cfg.curr_epoch_no = 0 # Not required
 24 | device = next(net.parameters()).device  # infer device from the model
25 | net.train() # To get distribution of mean and var
26 | accs = []
27 |
28 | for i, (inputs, labels) in enumerate(dataloader):
29 | inputs, labels = inputs.to(device), labels.to(device)
30 | outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
31 | kl = 0.0
32 | for j in range(num_ens):
33 | net_out, _kl = net(inputs)
34 | kl += _kl
35 | outputs[:, :, j] = F.log_softmax(net_out, dim=1).data
36 |
37 | log_outputs = utils.logmeanexp(outputs, dim=2)
38 | accs.append(metrics.acc(log_outputs, labels))
39 | return np.mean(accs)
40 |
41 |
42 | def _get_ordered_layer_name(mean_var_path):
43 | # Order files according to creation time
44 | files = os.listdir(mean_var_path)
45 | files = [os.path.join(mean_var_path, f) for f in files]
46 | files.sort(key=os.path.getctime)
47 | layer_names = [f.split('/')[-1].split('.')[0] for f in files]
48 | return layer_names
49 |
50 |
51 | def _get_layer_wise_mean_var_per_task(mean_var_path):
52 | # Order files according to creation time
53 | # To get the correct model architecture
54 | files = os.listdir(mean_var_path)
55 | files = [os.path.join(mean_var_path, f) for f in files]
56 | files.sort(key=os.path.getctime)
57 | layer_names = [f.split('/')[-1].split('.')[0] for f in files]
58 |
59 | mean_var = OrderedDict()
60 | for i in range(len(files)):
61 | data = {}
62 | mean, var = utils.load_mean_std_from_file(files[i])
63 | mean, var = np.vstack(mean), np.vstack(var) # shape is (len(trainset), output shape)
64 | data['mean'] = mean
65 | data['var'] = var
66 | data['mean.mu'] = np.mean(mean, axis=0)
67 | data['var.mu'] = np.mean(var, axis=0)
68 | data['mean.var'] = np.var(mean, axis=0)
69 | data['var.var'] = np.var(var, axis=0)
70 | mean_var[layer_names[i]] = data
71 | return mean_var
72 |
73 |
74 | def get_mean_vars_for_all_tasks(mean_var_dir):
75 | all_tasks = {}
76 | for task in os.listdir(mean_var_dir):
77 | path_to_task = os.path.join(mean_var_dir, task)
78 | mean_var_per_task = _get_layer_wise_mean_var_per_task(path_to_task)
79 | all_tasks[task] = mean_var_per_task
80 | return all_tasks
81 |
82 |
83 | def fit_to_gmm(num_tasks, layer_name, data_type, all_tasks):
84 | data = np.vstack([all_tasks[f'task-{i}'][layer_name][data_type] for i in range(1, num_tasks+1)])
85 | data = torch.tensor(data).float()
86 | #data_mu = torch.cat([torch.tensor(all_tasks[f'task-{i}'][layer_name][data_type+'.mu']).unsqueeze(0) for i in range(1, num_tasks+1)], dim=0).float().unsqueeze(0)
87 | #data_var = torch.cat([torch.tensor(all_tasks[f'task-{i}'][layer_name][data_type+'.var']).unsqueeze(0) for i in range(1, num_tasks+1)], dim=0).float().unsqueeze(0)
88 | model = gmm.GaussianMixture(n_components=num_tasks, n_features=np.prod(data.shape[1:]))#, mu_init=data_mu, var_init=data_var)
89 | data = data[torch.randperm(data.size()[0])] # Shuffling of data
90 | model.fit(data)
91 | return model.predict(data)
92 |
93 |
94 | def main():
95 | num_tasks = 2
96 | weights_dir = "checkpoints/MNIST/bayesian/splitted/2-tasks/"
97 |
98 | loader_task1, loader_task2 = utils_mixture.get_splitmnist_dataloaders(num_tasks)
99 | train_loader_task1 = loader_task1[0]
100 | train_loader_task2 = loader_task2[0]
101 |
102 | net_task1, net_task2 = utils_mixture.get_splitmnist_models(num_tasks, True, True, weights_dir)
103 | net_task1.cuda()
104 | net_task2.cuda()
105 |
106 | print("Task-1 Accuracy:", feedforward_and_save_mean_var(net_task1, train_loader_task1, task_no=1))
107 | print("Task-2 Accuracy:", feedforward_and_save_mean_var(net_task2, train_loader_task2, task_no=2))
108 |
109 | mean_vars_all_tasks = get_mean_vars_for_all_tasks("Mixtures/mean_vars/")
110 |
111 |
112 |
113 | if __name__ == '__main__':
114 | all_tasks = get_mean_vars_for_all_tasks("Mixtures/mean_vars/")
115 | y = fit_to_gmm(2, 'fc3', 'mean', all_tasks)
116 | print("Cluster0", (1-y).sum())
117 | print("Cluster1", y.sum())
118 |
--------------------------------------------------------------------------------
/Mixtures/mixture_experiment.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | import os
5 | import datetime
6 | import torch
7 | import contextlib
8 |
9 | from utils_mixture import *
 10 | from layers.BBB.BBBLinear import BBBLinear  # BBBLinear now lives in the layers/BBB (or layers/BBB_LRT) subpackage
11 |
12 |
13 | @contextlib.contextmanager
14 | def print_to_logfile(file):
15 | # capture all outputs to a log file while still printing it
16 | class Logger:
17 | def __init__(self, file):
18 | self.terminal = sys.stdout
19 | self.log = file
20 |
21 | def write(self, message):
22 | self.terminal.write(message)
23 | self.log.write(message)
24 |
25 | def __getattr__(self, attr):
26 | return getattr(self.terminal, attr)
27 |
28 | logger = Logger(file)
29 |
30 | _stdout = sys.stdout
31 | sys.stdout = logger
32 | try:
33 | yield logger.log
34 | finally:
35 | sys.stdout = _stdout
36 |
37 |
38 | def initiate_experiment(experiment):
39 |
40 | def decorator(*args, **kwargs):
41 | log_file_dir = "experiments/mixtures/"
42 | log_file = log_file_dir + experiment.__name__ + ".txt"
43 | if not os.path.exists(log_file):
44 | os.makedirs(log_file_dir, exist_ok=True)
45 | with print_to_logfile(open(log_file, 'a')):
46 | print("Performing experiment:", experiment.__name__)
47 | print("Date-Time:", datetime.datetime.now())
48 | print("\n", end="")
49 | print("Args:", args)
50 | print("Kwargs:", kwargs)
51 | print("\n", end="")
52 | experiment(*args, **kwargs)
53 | print("\n\n", end="")
54 | return decorator
55 |
56 |
57 | @initiate_experiment
58 | def experiment_regular_prediction_bayesian(weights_dir=None, num_ens=10):
59 | num_tasks = 2
60 | weights_dir = "checkpoints/MNIST/bayesian/splitted/2-tasks/" if weights_dir is None else weights_dir
61 |
62 | loaders1, loaders2 = get_splitmnist_dataloaders(num_tasks)
63 | net1, net2 = get_splitmnist_models(num_tasks, bayesian=True, pretrained=True, weights_dir=weights_dir)
64 | net1.cuda()
65 | net2.cuda()
66 |
67 | print("Model-1, Task-1-Dataset=> Accuracy:", predict_regular(net1, loaders1[1], bayesian=True, num_ens=num_ens))
68 | print("Model-2, Task-2-Dataset=> Accuracy:", predict_regular(net2, loaders2[1], bayesian=True, num_ens=num_ens))
69 |
70 |
71 | @initiate_experiment
72 | def experiment_regular_prediction_frequentist(weights_dir=None):
73 | num_tasks = 2
74 | weights_dir = "checkpoints/MNIST/frequentist/splitted/2-tasks/" if weights_dir is None else weights_dir
75 |
76 | loaders1, loaders2 = get_splitmnist_dataloaders(num_tasks)
77 | net1, net2 = get_splitmnist_models(num_tasks, bayesian=False, pretrained=True, weights_dir=weights_dir)
78 | net1.cuda()
79 | net2.cuda()
80 |
81 | print("Model-1, Task-1-Dataset=> Accuracy:", predict_regular(net1, loaders1[1], bayesian=False))
82 | print("Model-2, Task-2-Dataset=> Accuracy:", predict_regular(net2, loaders2[1], bayesian=False))
83 |
84 |
85 | @initiate_experiment
86 | def experiment_simultaneous_without_mixture_model_with_uncertainty(uncertainty_type="epistemic_softmax", T=25, weights_dir=None):
87 | num_tasks = 2
88 | weights_dir = "checkpoints/MNIST/bayesian/splitted/2-tasks/" if weights_dir is None else weights_dir
89 |
90 | loaders1, loaders2 = get_splitmnist_dataloaders(num_tasks)
91 | net1, net2 = get_splitmnist_models(num_tasks, True, True, weights_dir)
92 | net1.cuda()
93 | net2.cuda()
94 |
95 | print("Both Models, Task-1-Dataset=> Accuracy: {:.3}\tModel-1-Preferred: {:.3}\tModel-2-Preferred: {:.3}\t" \
96 | "Task-1-Dataset-Uncertainty: {:.3}\tTask-2-Dataset-Uncertainty: {:.3}".format(
97 | *predict_using_uncertainty_separate_models(net1, net2, loaders1[1], uncertainty_type=uncertainty_type, T=T)))
98 | print("Both Models, Task-2-Dataset=> Accuracy: {:.3}\tModel-1-Preferred: {:.3}\tModel-2-Preferred: {:.3}\t" \
99 | "Task-1-Dataset-Uncertainty: {:.3}\tTask-2-Dataset-Uncertainty: {:.3}".format(
100 | *predict_using_uncertainty_separate_models(net1, net2, loaders2[1], uncertainty_type=uncertainty_type, T=T)))
101 |
102 |
103 | @initiate_experiment
104 | def experiment_simultaneous_without_mixture_model_with_confidence(weights_dir=None):
105 | num_tasks = 2
106 | weights_dir = "checkpoints/MNIST/frequentist/splitted/2-tasks/" if weights_dir is None else weights_dir
107 |
108 | loaders1, loaders2 = get_splitmnist_dataloaders(num_tasks)
109 | net1, net2 = get_splitmnist_models(num_tasks, False, True, weights_dir)
110 | net1.cuda()
111 | net2.cuda()
112 |
113 | print("Both Models, Task-1-Dataset=> Accuracy: {:.3}\tModel-1-Preferred: {:.3}\tModel-2-Preferred: {:.3}".format(
114 | *predict_using_confidence_separate_models(net1, net2, loaders1[1])))
115 | print("Both Models, Task-2-Dataset=> Accuracy: {:.3}\tModel-1-Preferred: {:.3}\tModel-2-Preferred: {:.3}".format(
116 | *predict_using_confidence_separate_models(net1, net2, loaders2[1])))
117 |
118 |
119 | @initiate_experiment
120 | def wip_experiment_average_weights_mixture_model():
121 | num_tasks = 2
122 | weights_dir = "checkpoints/MNIST/bayesian/splitted/2-tasks/"
123 |
124 | loaders1, loaders2 = get_splitmnist_dataloaders(num_tasks)
125 | net1, net2 = get_splitmnist_models(num_tasks, True, True, weights_dir)
126 | net1.cuda()
127 | net2.cuda()
128 | net_mix = get_mixture_model(num_tasks, weights_dir, include_last_layer=True)
129 | net_mix.cuda()
130 |
131 | print("Model-1, Loader-1:", calculate_accuracy(net1, loaders1[1]))
132 | print("Model-2, Loader-2:", calculate_accuracy(net2, loaders2[1]))
133 | print("Model-1, Loader-2:", calculate_accuracy(net1, loaders2[1]))
134 | print("Model-2, Loader-1:", calculate_accuracy(net2, loaders1[1]))
135 | print("Model-Mix, Loader-1:", calculate_accuracy(net_mix, loaders1[1]))
136 | print("Model-Mix, Loader-2:", calculate_accuracy(net_mix, loaders2[1]))
137 |
138 |
139 | @initiate_experiment
140 | def wip_experiment_simultaneous_average_weights_mixture_model_with_uncertainty():
141 | num_tasks = 2
142 | weights_dir = "checkpoints/MNIST/bayesian/splitted/2-tasks/"
143 |
144 | loaders1, loaders2 = get_splitmnist_dataloaders(num_tasks)
145 | net1, net2 = get_splitmnist_models(num_tasks, True, True, weights_dir)
146 | net1.cuda()
147 | net2.cuda()
148 | net_mix = get_mixture_model(num_tasks, weights_dir, include_last_layer=False)
149 | net_mix.cuda()
150 |
151 | # Creating 2 sets of last layer
152 | fc3_1 = BBBLinear(84, 5, name='fc3_1') # hardcoded for lenet
153 | weights_1 = torch.load(weights_dir + "model_lenet_2.1.pt")
154 | fc3_1.W = torch.nn.Parameter(weights_1['fc3.W'])
155 | fc3_1.log_alpha = torch.nn.Parameter(weights_1['fc3.log_alpha'])
156 |
157 | fc3_2 = BBBLinear(84, 5, name='fc3_2') # hardcoded for lenet
158 | weights_2 = torch.load(weights_dir + "model_lenet_2.2.pt")
159 | fc3_2.W = torch.nn.Parameter(weights_2['fc3.W'])
160 | fc3_2.log_alpha = torch.nn.Parameter(weights_2['fc3.log_alpha'])
161 |
162 | fc3_1, fc3_2 = fc3_1.cuda(), fc3_2.cuda()
163 |
164 | print("Model-1, Loader-1:", calculate_accuracy(net1, loaders1[1]))
165 | print("Model-2, Loader-2:", calculate_accuracy(net2, loaders2[1]))
166 | print("Model-Mix, Loader-1:", predict_using_epistemic_uncertainty_with_mixture_model(net_mix, fc3_1, fc3_2, loaders1[1]))
167 | print("Model-Mix, Loader-2:", predict_using_epistemic_uncertainty_with_mixture_model(net_mix, fc3_1, fc3_2, loaders2[1]))
168 |
169 |
170 | if __name__ == '__main__':
171 | experiment_simultaneous_without_mixture_model_with_confidence()
172 |
--------------------------------------------------------------------------------
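For illustration, the `print_to_logfile` context manager above can also be used on its own, outside the `initiate_experiment` decorator; this hypothetical snippet tees everything printed inside the block to `demo.txt` while still echoing it to the terminal:

```python
with print_to_logfile(open("demo.txt", "a")):
    print("this line goes to stdout and is appended to demo.txt")
```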
/Mixtures/temp_gmm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import seaborn as sns
3 | import matplotlib.pyplot as plt
4 | import torch
5 |
6 | import gmm
7 |
8 | def create_synthetic_data(num_gaussians, num_features, num_samples, means, vars):
9 | assert len(means[0]) == len(vars[0]) == num_features
10 | samples = []
11 | for g in range(num_gaussians):
12 | loc = torch.tensor(means[g]).float()
13 | covariance_matrix = torch.eye(num_features).float() * torch.tensor(vars[g]).float()
14 | dist = torch.distributions.multivariate_normal.MultivariateNormal(
15 | loc = loc, covariance_matrix=covariance_matrix)
16 |
17 | for i in range(num_samples//num_gaussians):
18 | sample = dist.sample()
19 | samples.append(sample.unsqueeze(0))
20 |
 21 | samples = torch.cat(samples, dim=0)
22 | return samples
23 |
24 |
25 | def plot_data(data, y=None):
26 | if y is not None:
27 | for sample, target in zip(data, y):
28 | if target==0:
29 | plt.scatter(*sample, color='blue')
30 | elif target==1:
31 | plt.scatter(*sample, color='red')
32 | elif target==2:
33 | plt.scatter(*sample, color='green')
34 | else:
35 | for sample in data:
36 | plt.scatter(*sample, color='black')
37 | plt.show(block=False)
38 | plt.pause(2)
39 | plt.close()
40 |
41 |
 42 | means = [[1, 4], [5, 5], [2, -1]]  # per-task means (one mean per feature)
 43 | vars = [[0.1, 0.1], [0.05, 0.4], [0.5, 0.2]]  # per-task variances (one variance per feature)
44 | data = create_synthetic_data(3, 2, 600, means, vars) # shape: (total_samples, num_features)
45 | plot_data(data)
46 | model = gmm.GaussianMixture(3, 2) # (num_gaussians, num_features)
47 | model.fit(data)
48 | y = model.predict(data)
49 | plot_data(data, y)
50 |
51 |
--------------------------------------------------------------------------------
/Mixtures/train_splitted.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | import os
5 | import argparse
6 | import torch
7 | from torch import nn
8 | import numpy as np
9 | from torch.optim import Adam
10 |
11 | import utils
12 | import metrics
13 | import config_mixtures as cfg
14 | import utils_mixture as mix_utils
15 | from main_bayesian import train_model as train_bayesian, validate_model as validate_bayesian
16 | from main_frequentist import train_model as train_frequentist, validate_model as validate_frequentist
17 |
18 | # CUDA settings
19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20 |
21 |
22 | def train_splitted(num_tasks, bayesian=True, net_type='lenet'):
23 | assert 10 % num_tasks == 0
24 |
25 | # Hyper Parameter settings
26 | train_ens = cfg.train_ens
27 | valid_ens = cfg.valid_ens
28 | n_epochs = cfg.n_epochs
29 | lr_start = cfg.lr_start
30 |
31 | if bayesian:
32 | ckpt_dir = f"checkpoints/MNIST/bayesian/splitted/{num_tasks}-tasks/"
33 | else:
34 | ckpt_dir = f"checkpoints/MNIST/frequentist/splitted/{num_tasks}-tasks/"
35 | if not os.path.exists(ckpt_dir):
36 | os.makedirs(ckpt_dir, exist_ok=True)
37 |
38 | loaders, datasets = mix_utils.get_splitmnist_dataloaders(num_tasks, return_datasets=True)
39 | models = mix_utils.get_splitmnist_models(num_tasks, bayesian=bayesian, pretrained=False, net_type=net_type)
40 |
41 | for task in range(1, num_tasks + 1):
42 | print(f"Training task-{task}..")
43 | trainset, testset, _, _ = datasets[task-1]
44 | train_loader, valid_loader, _ = loaders[task-1]
45 | net = models[task-1]
46 | net = net.to(device)
47 | ckpt_name = ckpt_dir + f"model_{net_type}_{num_tasks}.{task}.pt"
48 |
49 | criterion = (metrics.ELBO(len(trainset)) if bayesian else nn.CrossEntropyLoss()).to(device)
50 | optimizer = Adam(net.parameters(), lr=lr_start)
 51 | valid_loss_min = np.inf
52 | for epoch in range(n_epochs): # loop over the dataset multiple times
53 | utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))
54 |
55 | if bayesian:
56 | train_loss, train_acc, train_kl = train_bayesian(net, optimizer, criterion, train_loader, num_ens=train_ens)
57 | valid_loss, valid_acc = validate_bayesian(net, criterion, valid_loader, num_ens=valid_ens)
58 | print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
59 | epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))
60 | else:
61 | train_loss, train_acc = train_frequentist(net, optimizer, criterion, train_loader)
62 | valid_loss, valid_acc = validate_frequentist(net, criterion, valid_loader)
63 | print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'.format(
64 | epoch, train_loss, train_acc, valid_loss, valid_acc))
65 |
 66 | # save model if validation loss has decreased
 67 | if valid_loss <= valid_loss_min:
 68 | print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
 69 | valid_loss_min, valid_loss))
 70 | torch.save(net.state_dict(), ckpt_name)
 71 | valid_loss_min = valid_loss
72 |
73 | print(f"Done training task-{task}")
74 |
75 |
76 | if __name__ == '__main__':
 77 | parser = argparse.ArgumentParser(description="PyTorch Bayesian Split Model Training")
78 | parser.add_argument('--num_tasks', default=2, type=int, help='number of tasks')
79 | parser.add_argument('--net_type', default='lenet', type=str, help='model')
80 | parser.add_argument('--bayesian', default=1, type=int, help='is_bayesian_model(0/1)')
81 | args = parser.parse_args()
82 |
83 | train_splitted(args.num_tasks, bool(args.bayesian), args.net_type)
84 |
--------------------------------------------------------------------------------
/Mixtures/utils_mixture.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | import os
5 | import torch
6 | import numpy as np
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 |
10 | import data
11 | import utils
12 | import metrics
13 | from main_bayesian import getModel as getBayesianModel
14 | from main_frequentist import getModel as getFrequentistModel
15 | import config_mixtures as cfg
16 | import uncertainty_estimation as ue
17 |
18 |
19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20 |
21 |
22 | class Pass(nn.Module):
23 | def __init__(self):
24 | super(Pass, self).__init__()
25 |
26 | def forward(self, x):
27 | return x
28 |
29 |
30 | def _get_splitmnist_datasets(num_tasks):
31 | datasets = []
32 | for i in range(1, num_tasks + 1):
33 | name = 'SplitMNIST-{}.{}'.format(num_tasks, i)
34 | datasets.append(data.getDataset(name))
35 | return datasets
36 |
37 |
38 | def get_splitmnist_dataloaders(num_tasks, return_datasets=False):
39 | loaders = []
40 | datasets = _get_splitmnist_datasets(num_tasks)
41 | for i in range(1, num_tasks + 1):
42 | trainset, testset, _, _ = datasets[i-1]
43 | curr_loaders = data.getDataloader(
44 | trainset, testset, cfg.valid_size, cfg.batch_size, cfg.num_workers)
45 | loaders.append(curr_loaders) # (train_loader, valid_loader, test_loader)
46 | if return_datasets:
47 | return loaders, datasets
48 | return loaders
49 |
50 |
51 | def get_splitmnist_models(num_tasks, bayesian=True, pretrained=False, weights_dir=None, net_type='lenet'):
52 | inputs = 1
53 | outputs = 10 // num_tasks
54 | models = []
55 | if pretrained:
56 | assert weights_dir
57 | for i in range(1, num_tasks + 1):
58 | if bayesian:
59 | model = getBayesianModel(net_type, inputs, outputs)
60 | else:
61 | model = getFrequentistModel(net_type, inputs, outputs)
62 | models.append(model)
63 | if pretrained:
64 | weight_path = weights_dir + f"model_{net_type}_{num_tasks}.{i}.pt"
65 | models[-1].load_state_dict(torch.load(weight_path))
66 | return models
67 |
68 |
69 | def get_mixture_model(num_tasks, weights_dir, net_type='lenet', include_last_layer=True):
70 | """
71 | Current implementation is based on average value of weights
72 | """
73 | net = getBayesianModel(net_type, 1, 5)
74 | if not include_last_layer:
75 | net.fc3 = Pass()
76 |
77 | task_weights = []
78 | for i in range(1, num_tasks + 1):
79 | weight_path = weights_dir + f"model_{net_type}_{num_tasks}.{i}.pt"
80 | task_weights.append(torch.load(weight_path))
81 |
 82 | mixture_weights = net.state_dict().copy()
 83 |
 84 | # average each parameter tensor across the per-task checkpoints
 85 | for key in mixture_weights:
 86 | concat_weights = torch.cat([w[key].unsqueeze(0) for w in task_weights], dim=0)
 87 | average_weight = torch.mean(concat_weights, dim=0)
 88 | mixture_weights[key] = average_weight
 89 |
90 |
91 | net.load_state_dict(mixture_weights)
92 | return net
93 |
94 |
95 | def predict_regular(net, validloader, bayesian=True, num_ens=10):
96 | """
97 | For both Bayesian and Frequentist models
98 | """
99 | net.eval()
100 | accs = []
101 |
102 | for i, (inputs, labels) in enumerate(validloader):
103 | inputs, labels = inputs.to(device), labels.to(device)
104 | if bayesian:
105 | outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
106 | for j in range(num_ens):
107 | net_out, _ = net(inputs)
108 | outputs[:, :, j] = F.log_softmax(net_out, dim=1).data
109 |
110 | log_outputs = utils.logmeanexp(outputs, dim=2)
111 | accs.append(metrics.acc(log_outputs, labels))
112 | else:
113 | output = net(inputs)
114 | accs.append(metrics.acc(output.detach(), labels))
115 |
116 | return np.mean(accs)
117 |
118 |
119 | def predict_using_uncertainty_separate_models(net1, net2, valid_loader, uncertainty_type='epistemic_softmax', T=25):
120 | """
121 | For Bayesian models
122 | """
123 | accs = []
124 | total_u1 = 0.0
125 | total_u2 = 0.0
126 | set1_selected = 0
127 | set2_selected = 0
128 |
129 | epi_or_ale, soft_or_norm = uncertainty_type.split('_')
130 | soft_or_norm = (soft_or_norm == 'normalized')
131 |
132 | for i, (inputs, labels) in enumerate(valid_loader):
133 | inputs, labels = inputs.to(device), labels.to(device)
134 | pred1, epi1, ale1 = ue.get_uncertainty_per_batch(net1, inputs, T=T, normalized=soft_or_norm)
135 | pred2, epi2, ale2 = ue.get_uncertainty_per_batch(net2, inputs, T=T, normalized=soft_or_norm)
136 |
137 | if epi_or_ale=='epistemic':
138 | u1 = np.sum(epi1, axis=1)
139 | u2 = np.sum(epi2, axis=1)
140 | elif epi_or_ale=='aleatoric':
141 | u1 = np.sum(ale1, axis=1)
142 | u2 = np.sum(ale2, axis=1)
143 | elif epi_or_ale=='both':
144 | u1 = np.sum(epi1, axis=1) + np.sum(ale1, axis=1)
145 | u2 = np.sum(epi2, axis=1) + np.sum(ale2, axis=1)
146 | else:
147 | raise ValueError("Unknown uncertainty type: expected 'epistemic', 'aleatoric' or 'both'")
148 |
149 | total_u1 += np.sum(u1).item()
150 | total_u2 += np.sum(u2).item()
151 |
152 | set1_preferred = u2 > u1 # idx where set1 has less uncertainty
153 | set1_preferred = np.expand_dims(set1_preferred, 1)
154 | preds = np.where(set1_preferred, pred1, pred2)
155 |
156 | set1_selected += np.sum(set1_preferred)
157 | set2_selected += np.sum(~set1_preferred)
158 |
159 | accs.append(metrics.acc(torch.tensor(preds), labels))
160 |
161 | return np.mean(accs), set1_selected/(set1_selected + set2_selected), \
162 | set2_selected/(set1_selected + set2_selected), total_u1, total_u2
163 |
164 |
165 | def predict_using_confidence_separate_models(net1, net2, valid_loader):
166 | """
167 | For Frequentist models
168 | """
169 | accs = []
170 | set1_selected = 0
171 | set2_selected = 0
172 |
173 | for i, (inputs, labels) in enumerate(valid_loader):
174 | inputs, labels = inputs.to(device), labels.to(device)
175 | pred1 = F.softmax(net1(inputs), dim=1)
176 | pred2 = F.softmax(net2(inputs), dim=1)
177 |
178 | set1_preferred = pred1.max(dim=1)[0] > pred2.max(dim=1)[0] # idx where set1 has more confidence
179 | preds = torch.where(set1_preferred.unsqueeze(1), pred1, pred2)
180 |
181 | set1_selected += torch.sum(set1_preferred).float().item()
182 | set2_selected += torch.sum(~set1_preferred).float().item()
183 |
184 | accs.append(metrics.acc(preds.detach(), labels))
185 |
186 | return np.mean(accs), set1_selected/(set1_selected + set2_selected), \
187 | set2_selected/(set1_selected + set2_selected)
188 |
189 |
190 | def wip_predict_using_epistemic_uncertainty_with_mixture_model(model, fc3_1, fc3_2, valid_loader, T=10):
191 | accs = []
192 | total_epistemic_1 = 0.0
193 | total_epistemic_2 = 0.0
194 | set_1_selected = 0
195 | set_2_selected = 0
196 |
197 | for i, (inputs, labels) in enumerate(valid_loader):
198 | inputs, labels = inputs.to(device), labels.to(device)
199 | outputs = []
200 | for n in range(inputs.shape[0]): # loop over the batch ('n' avoids shadowing the outer loop index)
201 | input_image = inputs[n].unsqueeze(0)
202 |
203 | p_hat_1 = []
204 | p_hat_2 = []
205 | preds_1 = []
206 | preds_2 = []
207 | for t in range(T):
208 | net_out_mix, _ = model(input_image)
209 |
210 | # set_1
211 | net_out_1 = fc3_1(net_out_mix)
212 | preds_1.append(net_out_1)
213 | prediction = F.softplus(net_out_1)
214 | prediction = prediction / torch.sum(prediction, dim=1)
215 | p_hat_1.append(prediction.cpu().detach())
216 |
217 | # set_2
218 | net_out_2 = fc3_2(net_out_mix)
219 | preds_2.append(net_out_2)
220 | prediction = F.softplus(net_out_2)
221 | prediction = prediction / torch.sum(prediction, dim=1)
222 | p_hat_2.append(prediction.cpu().detach())
223 |
224 | # set_1
225 | p_hat = torch.cat(p_hat_1, dim=0).numpy()
226 | p_bar = np.mean(p_hat, axis=0)
227 |
228 | preds = torch.cat(preds_1, dim=0)
229 | pred_set_1 = torch.sum(preds, dim=0).unsqueeze(0)
230 |
231 | temp = p_hat - np.expand_dims(p_bar, 0)
232 | epistemic = np.dot(temp.T, temp) / T
233 | epistemic_set_1 = np.sum(np.diag(epistemic)).item()
234 | total_epistemic_1 += epistemic_set_1
235 |
236 | # set_2
237 | p_hat = torch.cat(p_hat_2, dim=0).numpy()
238 | p_bar = np.mean(p_hat, axis=0)
239 |
240 | preds = torch.cat(preds_2, dim=0)
241 | pred_set_2 = torch.sum(preds, dim=0).unsqueeze(0)
242 |
243 | temp = p_hat - np.expand_dims(p_bar, 0)
244 | epistemic = np.dot(temp.T, temp) / T
245 | epistemic_set_2 = np.sum(np.diag(epistemic)).item()
246 | total_epistemic_2 += epistemic_set_2
247 |
248 | if epistemic_set_1 > epistemic_set_2:
249 | set_2_selected += 1
250 | outputs.append(pred_set_2)
251 | else:
252 | set_1_selected += 1
253 | outputs.append(pred_set_1)
254 |
255 | outputs = torch.cat(outputs, dim=0)
256 | accs.append(metrics.acc(outputs.detach(), labels))
257 |
258 | return np.mean(accs), set_1_selected/(set_1_selected + set_2_selected), \
259 | set_2_selected/(set_1_selected + set_2_selected), total_epistemic_1, total_epistemic_2
260 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [](https://www.python.org/downloads/release/python-376/)
3 | [](https://pytorch.org/)
4 | [](https://github.com/kumar-shridhar/PyTorch-BayesianCNN/blob/master/LICENSE)
5 | [](https://arxiv.org/abs/1901.02731)
6 |
7 | We introduce **Bayesian convolutional neural networks with variational inference**, a variant of convolutional neural networks (CNNs) in which the intractable posterior probability distributions over weights are inferred by **Bayes by Backprop**. We demonstrate how our proposed variational inference method achieves performance equivalent to frequentist inference in identical architectures on several datasets (MNIST, CIFAR10, CIFAR100), as described in the [paper](https://arxiv.org/abs/1901.02731).
8 |
9 | ---------------------------------------------------------------------------------------------------------
10 |
11 |
12 | ### Filter weight distributions in a Bayesian vs. Frequentist approach
13 |
14 | 
15 |
16 | ---------------------------------------------------------------------------------------------------------
17 |
18 | ### Fully Bayesian perspective of an entire CNN
19 |
20 | 
21 |
22 | ---------------------------------------------------------------------------------------------------------
23 |
24 |
25 |
26 | ### Layer types
27 |
28 | This repository contains two types of Bayesian layer implementations:
29 | * **BBB (Bayes by Backprop):**
30 | Based on [this paper](https://arxiv.org/abs/1505.05424). This layer samples all the weights individually and then combines them with the inputs to compute a sample from the activations.
31 |
32 | * **BBB_LRT (Bayes by Backprop w/ Local Reparametrization Trick):**
33 | This layer combines Bayes by Backprop with local reparametrization trick from [this paper](https://arxiv.org/abs/1506.02557). This trick makes it possible to directly sample from the distribution over activations.
34 | ---------------------------------------------------------------------------------------------------------
35 |
36 |
37 |
38 | ### How to make a custom Bayesian Network?
39 | To make a custom Bayesian network, inherit from `layers.misc.ModuleWrapper` instead of `torch.nn.Module`, and use `BBBLinear` and `BBBConv2d` from either of the given layer implementations (`BBB` or `BBB_LRT`) instead of `torch.nn.Linear` and `torch.nn.Conv2d`. There is also no need to define a `forward` method; `ModuleWrapper` takes care of it automatically.
40 |
41 | For example:
42 | ```python
43 | class Net(nn.Module):
44 |
45 | def __init__(self):
46 | super().__init__()
47 | self.conv = nn.Conv2d(3, 16, 5, stride=2)
48 | self.bn = nn.BatchNorm2d(16)
49 | self.relu = nn.ReLU()
50 | self.fc = nn.Linear(800, 10)
51 |
52 | def forward(self, x):
53 | x = self.conv(x)
54 | x = self.bn(x)
55 | x = self.relu(x)
56 | x = x.view(-1, 800)
57 | x = self.fc(x)
58 | return x
59 | ```
60 | The above network can be converted to its Bayesian counterpart as follows:
61 | ```python
62 | class Net(ModuleWrapper):
63 |
64 | def __init__(self):
65 | super().__init__()
66 | self.conv = BBBConv2d(3, 16, 5, stride=2)
67 | self.bn = nn.BatchNorm2d(16)
68 | self.relu = nn.ReLU()
69 | self.flatten = FlattenLayer(800)
70 | self.fc = BBBLinear(800, 10)
71 | ```
72 |
73 | #### Notes:
74 | 1. Add `FlattenLayer` before first `BBBLinear` block.
75 | 2. `forward` method of the model will return a tuple as `(logits, kl)`.
76 | 3. `priors` can be passed as an argument to the layers. Default value is:
77 | ```python
78 | priors={
79 | 'prior_mu': 0,
80 | 'prior_sigma': 0.1,
81 | 'posterior_mu_initial': (0, 0.1), # (mean, std) normal_
82 | 'posterior_rho_initial': (-3, 0.1), # (mean, std) normal_
83 | }
84 | ```
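For example, here is a sketch of overriding the default prior when constructing a layer inside a `ModuleWrapper`'s `__init__` (the tighter `prior_sigma` is an arbitrary illustration value):
```python
sharper_priors = {
    'prior_mu': 0,
    'prior_sigma': 0.01,               # tighter prior than the default 0.1
    'posterior_mu_initial': (0, 0.1),  # (mean, std) normal_
    'posterior_rho_initial': (-3, 0.1),
}
self.fc = BBBLinear(800, 10, priors=sharper_priors)
```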
85 |
86 | ---------------------------------------------------------------------------------------------------------
87 |
88 | ### How to perform standard experiments?
89 | Currently, the following datasets and models are supported:
90 | * Datasets: MNIST, CIFAR10, CIFAR100
91 | * Models: AlexNet, LeNet, 3Conv3FC
92 |
93 | #### Bayesian
94 |
95 | `python main_bayesian.py`
96 | * set hyperparameters in `config_bayesian.py`
97 |
98 |
99 | #### Frequentist
100 |
101 | `python main_frequentist.py`
102 | * set hyperparameters in `config_frequentist.py`
103 |
104 | ---------------------------------------------------------------------------------------------------------
105 |
106 |
107 |
108 | ### Directory Structure:
109 | `layers/`: Contains `ModuleWrapper`, `FlattenLayer`, `BBBLinear` and `BBBConv2d`.
110 | `models/BayesianModels/`: Contains standard Bayesian models (BBBLeNet, BBBAlexNet, BBB3Conv3FC).
111 | `models/NonBayesianModels/`: Contains standard non-Bayesian models (LeNet, AlexNet, 3Conv3FC).
112 | `checkpoints/`: Checkpoint directory; trained models are saved here.
113 | `tests/`: Basic unittest cases for layers and models.
114 | `main_bayesian.py`: Train and Evaluate Bayesian models.
115 | `config_bayesian.py`: Hyperparameters for `main_bayesian` file.
116 | `main_frequentist.py`: Train and Evaluate non-Bayesian (Frequentist) models.
117 | `config_frequentist.py`: Hyperparameters for `main_frequentist` file.
118 |
119 | ---------------------------------------------------------------------------------------------------------
120 |
121 |
122 |
123 | ### Uncertainty Estimation:
124 | There are two types of uncertainty: **aleatoric** and **epistemic**.
125 | Aleatoric uncertainty measures the noise inherent in the data, whereas epistemic uncertainty stems from the model itself and reflects its lack of knowledge.
126 | Two estimation methods are provided in `uncertainty_estimation.py`, `'softmax'` and `'normalized'`, based respectively on equation 4 from [this paper](https://openreview.net/pdf?id=Sk_P2Q9sG) and equation 15 from [this paper](https://arxiv.org/pdf/1806.05978.pdf).
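As a rough, hypothetical sketch of the softplus-normalized decomposition, assuming a Bayesian model whose `forward` returns `(logits, kl)` as noted earlier (only the diagonal terms of the covariance in equation 15 are kept; see `uncertainty_estimation.py` for the actual implementation):
```python
import torch
import torch.nn.functional as F

def uncertainty_sketch(net, inputs, T=25):
    p_hat = []
    with torch.no_grad():
        for _ in range(T):  # T stochastic forward passes
            logits, _kl = net(inputs)                     # Bayesian forward returns (logits, kl)
            p = F.softplus(logits)
            p_hat.append(p / p.sum(dim=1, keepdim=True))  # softplus normalization
    p_hat = torch.stack(p_hat)                            # (T, batch, classes)
    p_bar = p_hat.mean(dim=0)

    # epistemic: spread of the T stochastic predictions around their mean
    epistemic = ((p_hat - p_bar) ** 2).mean(dim=0)
    # aleatoric: mean per-pass predictive variance p * (1 - p)
    aleatoric = (p_hat * (1 - p_hat)).mean(dim=0)
    return p_bar, epistemic, aleatoric
```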
127 | Additionally, `uncertainty_estimation.py` can be used to compare the uncertainties of a Bayesian neural network on the `MNIST` and `notMNIST` datasets. The following arguments can be provided:
128 | 1. `net_type`: `lenet`, `alexnet` or `3conv3fc`. Default is `lenet`.
129 | 2. `weights_path`: Weights for the given `net_type`. Default is `'checkpoints/MNIST/bayesian/model_lenet.pt'`.
130 | 3. `not_mnist_dir`: Directory of the `notMNIST` dataset. Default is `'data/'`.
131 | 4. `num_batches`: Number of batches for which uncertainties need to be calculated.
132 |
133 | **Notes**:
134 | 1. You need to download the [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) dataset from [here](http://yaroslavvb.com/upload/notMNIST/notMNIST_small.tar.gz).
135 | 2. The `layer_type` and `activation_type` parameters used in `uncertainty_estimation.py` need to be set in `config_bayesian.py` to match the provided weights.
136 |
137 | ---------------------------------------------------------------------------------------------------------
138 |
139 |
140 |
141 | If you are using this work, please cite:
142 |
143 | ```
144 | @article{shridhar2019comprehensive,
145 | title={A comprehensive guide to bayesian convolutional neural network with variational inference},
146 | author={Shridhar, Kumar and Laumann, Felix and Liwicki, Marcus},
147 | journal={arXiv preprint arXiv:1901.02731},
148 | year={2019}
149 | }
150 | ```
151 |
152 | ```
153 | @article{shridhar2018uncertainty,
154 | title={Uncertainty estimations by softplus normalization in bayesian convolutional neural networks with variational inference},
155 | author={Shridhar, Kumar and Laumann, Felix and Liwicki, Marcus},
156 | journal={arXiv preprint arXiv:1806.05978},
157 | year={2018}
158 | }
160 | ```
161 |
162 | --------------------------------------------------------------------------------------------------------
163 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/__init__.py
--------------------------------------------------------------------------------
/config_bayesian.py:
--------------------------------------------------------------------------------
1 | ############### Configuration file for Bayesian ###############
2 | layer_type = 'lrt' # 'bbb' or 'lrt'
3 | activation_type = 'softplus' # 'softplus' or 'relu'
4 | priors={
5 | 'prior_mu': 0,
6 | 'prior_sigma': 0.1,
7 | 'posterior_mu_initial': (0, 0.1), # (mean, std) normal_
8 | 'posterior_rho_initial': (-5, 0.1), # (mean, std) normal_
9 | }
10 |
11 | n_epochs = 200
12 | lr_start = 0.001
13 | num_workers = 4
14 | valid_size = 0.2
15 | batch_size = 256
16 | train_ens = 1
17 | valid_ens = 1
18 | beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value
19 |
--------------------------------------------------------------------------------
/config_frequentist.py:
--------------------------------------------------------------------------------
1 | ############### Configuration file for Frequentist ###############
2 | n_epochs = 200
3 | lr = 0.001
4 | num_workers = 4
5 | valid_size = 0.2
6 | batch_size = 256
7 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import getDataset
2 | from .data import getDataloader
--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as transforms
6 | from torch.utils.data.sampler import SubsetRandomSampler
7 |
8 |
9 | class CustomDataset(Dataset):
10 | def __init__(self, data, labels, transform=None):
11 | self.data = data
12 | self.labels = labels
13 | self.transform = transform
14 |
15 | def __len__(self):
16 | return len(self.labels)
17 |
18 | def __getitem__(self, idx):
19 | sample = self.data[idx]
20 | label = self.labels[idx]
21 | if self.transform:
22 | sample = self.transform(sample)
23 |
24 | return sample, label
25 |
26 |
27 | def extract_classes(dataset, classes):
28 | idx = torch.zeros_like(dataset.targets, dtype=torch.bool)
29 | for target in classes:
30 | idx = idx | (dataset.targets==target)
31 |
32 | data, targets = dataset.data[idx], dataset.targets[idx]
33 | return data, targets
34 |
35 |
36 | def getDataset(dataset):
37 | transform_split_mnist = transforms.Compose([
38 | transforms.ToPILImage(),
39 | transforms.Resize((32, 32)),
40 | transforms.ToTensor(),
41 | ])
42 |
43 | transform_mnist = transforms.Compose([
44 | transforms.Resize((32, 32)),
45 | transforms.ToTensor(),
46 | ])
47 |
48 | transform_cifar = transforms.Compose([
49 | transforms.Resize((32, 32)),
50 | transforms.RandomHorizontalFlip(),
51 | transforms.ToTensor(),
52 | ])
53 |
54 | if(dataset == 'CIFAR10'):
55 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_cifar)
56 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_cifar)
57 | num_classes = 10
58 | inputs=3
59 |
60 | elif(dataset == 'CIFAR100'):
61 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_cifar)
62 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_cifar)
63 | num_classes = 100
64 | inputs = 3
65 |
66 | elif(dataset == 'MNIST'):
67 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
68 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
69 | num_classes = 10
70 | inputs = 1
71 |
72 | elif(dataset == 'SplitMNIST-2.1'):
73 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
74 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
75 |
76 | train_data, train_targets = extract_classes(trainset, [0, 1, 2, 3, 4])
77 | test_data, test_targets = extract_classes(testset, [0, 1, 2, 3, 4])
78 |
79 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
80 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
81 | num_classes = 5
82 | inputs = 1
83 |
84 | elif(dataset == 'SplitMNIST-2.2'):
85 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
86 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
87 |
88 | train_data, train_targets = extract_classes(trainset, [5, 6, 7, 8, 9])
89 | test_data, test_targets = extract_classes(testset, [5, 6, 7, 8, 9])
90 | train_targets -= 5 # Mapping target 5-9 to 0-4
91 | test_targets -= 5 # Hence, add 5 after prediction
92 |
93 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
94 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
95 | num_classes = 5
96 | inputs = 1
97 |
98 | elif(dataset == 'SplitMNIST-5.1'):
99 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
100 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
101 |
102 | train_data, train_targets = extract_classes(trainset, [0, 1])
103 | test_data, test_targets = extract_classes(testset, [0, 1])
104 |
105 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
106 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
107 | num_classes = 2
108 | inputs = 1
109 |
110 | elif(dataset == 'SplitMNIST-5.2'):
111 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
112 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
113 |
114 | train_data, train_targets = extract_classes(trainset, [2, 3])
115 | test_data, test_targets = extract_classes(testset, [2, 3])
116 | train_targets -= 2 # Mapping target 2-3 to 0-1
117 | test_targets -= 2 # Hence, add 2 after prediction
118 |
119 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
120 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
121 | num_classes = 2
122 | inputs = 1
123 |
124 | elif(dataset == 'SplitMNIST-5.3'):
125 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
126 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
127 |
128 | train_data, train_targets = extract_classes(trainset, [4, 5])
129 | test_data, test_targets = extract_classes(testset, [4, 5])
130 | train_targets -= 4 # Mapping target 4-5 to 0-1
131 | test_targets -= 4 # Hence, add 4 after prediction
132 |
133 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
134 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
135 | num_classes = 2
136 | inputs = 1
137 |
138 | elif(dataset == 'SplitMNIST-5.4'):
139 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
140 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
141 |
142 | train_data, train_targets = extract_classes(trainset, [6, 7])
143 | test_data, test_targets = extract_classes(testset, [6, 7])
144 | train_targets -= 6 # Mapping target 6-7 to 0-1
145 | test_targets -= 6 # Hence, add 6 after prediction
146 |
147 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
148 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
149 | num_classes = 2
150 | inputs = 1
151 |
152 | elif(dataset == 'SplitMNIST-5.5'):
153 | trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
154 | testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist)
155 |
156 | train_data, train_targets = extract_classes(trainset, [8, 9])
157 | test_data, test_targets = extract_classes(testset, [8, 9])
158 | train_targets -= 8 # Mapping target 8-9 to 0-1
159 | test_targets -= 8 # Hence, add 8 after prediction
160 |
161 | trainset = CustomDataset(train_data, train_targets, transform=transform_split_mnist)
162 | testset = CustomDataset(test_data, test_targets, transform=transform_split_mnist)
163 | num_classes = 2
164 | inputs = 1
165 |
166 | return trainset, testset, inputs, num_classes
167 |
168 |
169 | def getDataloader(trainset, testset, valid_size, batch_size, num_workers):
170 | num_train = len(trainset)
171 | indices = list(range(num_train))
172 | np.random.shuffle(indices)
173 | split = int(np.floor(valid_size * num_train))
174 | train_idx, valid_idx = indices[split:], indices[:split]
175 |
176 | train_sampler = SubsetRandomSampler(train_idx)
177 | valid_sampler = SubsetRandomSampler(valid_idx)
178 |
179 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
180 | sampler=train_sampler, num_workers=num_workers)
181 | valid_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
182 | sampler=valid_sampler, num_workers=num_workers)
183 | test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
184 | num_workers=num_workers)
185 |
186 | return train_loader, valid_loader, test_loader
187 |
--------------------------------------------------------------------------------
/experiments/figures/BayesCNNwithdist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/BayesCNNwithdist.png
--------------------------------------------------------------------------------
/experiments/figures/CNNwithdist_git.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/CNNwithdist_git.png
--------------------------------------------------------------------------------
/experiments/figures/fc3-node_0-both-distplot.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/fc3-node_0-both-distplot.gif
--------------------------------------------------------------------------------
/experiments/figures/fc3-node_0-mean-distplot.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/fc3-node_0-mean-distplot.gif
--------------------------------------------------------------------------------
/experiments/figures/fc3-node_0-mean-lineplot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/fc3-node_0-mean-lineplot.jpg
--------------------------------------------------------------------------------
/experiments/figures/fc3-node_0-std-distplot.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/fc3-node_0-std-distplot.gif
--------------------------------------------------------------------------------
/experiments/figures/fc3-node_0-std-lineplot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/experiments/figures/fc3-node_0-std-lineplot.jpg
--------------------------------------------------------------------------------
/layers/BBB/BBBConv.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Parameter
9 |
10 | from metrics import calculate_kl as KL_DIV
11 | from ..misc import ModuleWrapper
12 |
13 |
14 | class BBBConv2d(ModuleWrapper):
15 | def __init__(self, in_channels, out_channels, kernel_size,
16 | stride=1, padding=0, dilation=1, bias=True, priors=None):
17 |
18 | super(BBBConv2d, self).__init__()
19 | self.in_channels = in_channels
20 | self.out_channels = out_channels
21 | self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
22 | self.stride = stride
23 | self.padding = padding
24 | self.dilation = dilation
25 | self.groups = 1
26 | self.use_bias = bias
27 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28 |
29 | if priors is None:
30 | priors = {
31 | 'prior_mu': 0,
32 | 'prior_sigma': 0.1,
33 | 'posterior_mu_initial': (0, 0.1),
34 | 'posterior_rho_initial': (-3, 0.1),
35 | }
36 | self.prior_mu = priors['prior_mu']
37 | self.prior_sigma = priors['prior_sigma']
38 | self.posterior_mu_initial = priors['posterior_mu_initial']
39 | self.posterior_rho_initial = priors['posterior_rho_initial']
40 |
41 | self.W_mu = Parameter(torch.empty((out_channels, in_channels, *self.kernel_size), device=self.device))
42 | self.W_rho = Parameter(torch.empty((out_channels, in_channels, *self.kernel_size), device=self.device))
43 |
44 | if self.use_bias:
45 | self.bias_mu = Parameter(torch.empty((out_channels), device=self.device))
46 | self.bias_rho = Parameter(torch.empty((out_channels), device=self.device))
47 | else:
48 | self.register_parameter('bias_mu', None)
49 | self.register_parameter('bias_rho', None)
50 |
51 | self.reset_parameters()
52 |
53 | def reset_parameters(self):
54 | self.W_mu.data.normal_(*self.posterior_mu_initial)
55 | self.W_rho.data.normal_(*self.posterior_rho_initial)
56 |
57 | if self.use_bias:
58 | self.bias_mu.data.normal_(*self.posterior_mu_initial)
59 | self.bias_rho.data.normal_(*self.posterior_rho_initial)
60 |
61 | def forward(self, input, sample=True):
62 | if self.training or sample:
63 | W_eps = torch.empty(self.W_mu.size()).normal_(0, 1).to(self.device)
64 | self.W_sigma = torch.log1p(torch.exp(self.W_rho))
65 | weight = self.W_mu + W_eps * self.W_sigma
66 |
67 | if self.use_bias:
68 | bias_eps = torch.empty(self.bias_mu.size()).normal_(0, 1).to(self.device)
69 | self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))
70 | bias = self.bias_mu + bias_eps * self.bias_sigma
71 | else:
72 | bias = None
73 | else:
74 | weight = self.W_mu
75 | bias = self.bias_mu if self.use_bias else None
76 |
77 | return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
78 |
79 | def kl_loss(self):
80 | kl = KL_DIV(self.prior_mu, self.prior_sigma, self.W_mu, self.W_sigma)
81 | if self.use_bias:
82 | kl += KL_DIV(self.prior_mu, self.prior_sigma, self.bias_mu, self.bias_sigma)
83 | return kl
84 |
--------------------------------------------------------------------------------
/layers/BBB/BBBLinear.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Parameter
9 |
10 | from metrics import calculate_kl as KL_DIV
11 | from ..misc import ModuleWrapper
12 |
13 |
14 | class BBBLinear(ModuleWrapper):
15 | def __init__(self, in_features, out_features, bias=True, priors=None):
16 | super(BBBLinear, self).__init__()
17 | self.in_features = in_features
18 | self.out_features = out_features
19 | self.use_bias = bias
20 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 |
22 | if priors is None:
23 | priors = {
24 | 'prior_mu': 0,
25 | 'prior_sigma': 0.1,
26 | 'posterior_mu_initial': (0, 0.1),
27 | 'posterior_rho_initial': (-3, 0.1),
28 | }
29 | self.prior_mu = priors['prior_mu']
30 | self.prior_sigma = priors['prior_sigma']
31 | self.posterior_mu_initial = priors['posterior_mu_initial']
32 | self.posterior_rho_initial = priors['posterior_rho_initial']
33 |
34 | self.W_mu = Parameter(torch.empty((out_features, in_features), device=self.device))
35 | self.W_rho = Parameter(torch.empty((out_features, in_features), device=self.device))
36 |
37 | if self.use_bias:
38 | self.bias_mu = Parameter(torch.empty((out_features), device=self.device))
39 | self.bias_rho = Parameter(torch.empty((out_features), device=self.device))
40 | else:
41 | self.register_parameter('bias_mu', None)
42 | self.register_parameter('bias_rho', None)
43 |
44 | self.reset_parameters()
45 |
46 | def reset_parameters(self):
47 | self.W_mu.data.normal_(*self.posterior_mu_initial)
48 | self.W_rho.data.normal_(*self.posterior_rho_initial)
49 |
50 | if self.use_bias:
51 | self.bias_mu.data.normal_(*self.posterior_mu_initial)
52 | self.bias_rho.data.normal_(*self.posterior_rho_initial)
53 |
54 | def forward(self, input, sample=True):
55 | if self.training or sample:
56 | W_eps = torch.empty(self.W_mu.size()).normal_(0, 1).to(self.device)
57 | self.W_sigma = torch.log1p(torch.exp(self.W_rho))
58 | weight = self.W_mu + W_eps * self.W_sigma
59 |
60 | if self.use_bias:
61 | bias_eps = torch.empty(self.bias_mu.size()).normal_(0, 1).to(self.device)
62 | self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))
63 | bias = self.bias_mu + bias_eps * self.bias_sigma
64 | else:
65 | bias = None
66 | else:
67 | weight = self.W_mu
68 | bias = self.bias_mu if self.use_bias else None
69 |
70 | return F.linear(input, weight, bias)
71 |
72 | def kl_loss(self):
73 | kl = KL_DIV(self.prior_mu, self.prior_sigma, self.W_mu, self.W_sigma)
74 | if self.use_bias:
75 | kl += KL_DIV(self.prior_mu, self.prior_sigma, self.bias_mu, self.bias_sigma)
76 | return kl
77 |
--------------------------------------------------------------------------------
/layers/BBB/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/layers/BBB/__init__.py
--------------------------------------------------------------------------------
/layers/BBB_LRT/BBBConv.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Parameter
9 |
10 | import utils
11 | from metrics import calculate_kl as KL_DIV
12 | import config_bayesian as cfg
13 | from ..misc import ModuleWrapper
14 |
15 |
16 | class BBBConv2d(ModuleWrapper):
17 |
18 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
19 | padding=0, dilation=1, bias=True, priors=None):
20 | super(BBBConv2d, self).__init__()
21 | self.in_channels = in_channels
22 | self.out_channels = out_channels
23 |         self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
24 | self.stride = stride
25 | self.padding = padding
26 | self.dilation = dilation
27 | self.groups = 1
28 | self.use_bias = bias
29 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30 |
31 | if priors is None:
32 | priors = {
33 | 'prior_mu': 0,
34 | 'prior_sigma': 0.1,
35 | 'posterior_mu_initial': (0, 0.1),
36 | 'posterior_rho_initial': (-3, 0.1),
37 | }
38 | self.prior_mu = priors['prior_mu']
39 | self.prior_sigma = priors['prior_sigma']
40 | self.posterior_mu_initial = priors['posterior_mu_initial']
41 | self.posterior_rho_initial = priors['posterior_rho_initial']
42 |
43 | self.W_mu = Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
44 | self.W_rho = Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
45 | if self.use_bias:
46 | self.bias_mu = Parameter(torch.Tensor(out_channels))
47 | self.bias_rho = Parameter(torch.Tensor(out_channels))
48 | else:
49 | self.register_parameter('bias_mu', None)
50 | self.register_parameter('bias_rho', None)
51 |
52 | self.reset_parameters()
53 |
54 | def reset_parameters(self):
55 | self.W_mu.data.normal_(*self.posterior_mu_initial)
56 | self.W_rho.data.normal_(*self.posterior_rho_initial)
57 |
58 | if self.use_bias:
59 | self.bias_mu.data.normal_(*self.posterior_mu_initial)
60 | self.bias_rho.data.normal_(*self.posterior_rho_initial)
61 |
62 | def forward(self, x, sample=True):
63 |
64 | self.W_sigma = torch.log1p(torch.exp(self.W_rho))
65 | if self.use_bias:
66 | self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))
67 | bias_var = self.bias_sigma ** 2
68 | else:
69 | self.bias_sigma = bias_var = None
70 |
71 | act_mu = F.conv2d(
72 | x, self.W_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
73 | act_var = 1e-16 + F.conv2d(
74 | x ** 2, self.W_sigma ** 2, bias_var, self.stride, self.padding, self.dilation, self.groups)
75 | act_std = torch.sqrt(act_var)
76 |
77 | if self.training or sample:
78 | eps = torch.empty(act_mu.size()).normal_(0, 1).to(self.device)
79 | return act_mu + act_std * eps
80 | else:
81 | return act_mu
82 |
83 | def kl_loss(self):
84 | kl = KL_DIV(self.prior_mu, self.prior_sigma, self.W_mu, self.W_sigma)
85 | if self.use_bias:
86 | kl += KL_DIV(self.prior_mu, self.prior_sigma, self.bias_mu, self.bias_sigma)
87 | return kl
88 |
--------------------------------------------------------------------------------
/layers/BBB_LRT/BBBLinear.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Parameter
9 |
10 | import utils
11 | from metrics import calculate_kl as KL_DIV
12 | import config_bayesian as cfg
13 | from ..misc import ModuleWrapper
14 |
15 |
16 | class BBBLinear(ModuleWrapper):
17 |
18 | def __init__(self, in_features, out_features, bias=True, priors=None):
19 | super(BBBLinear, self).__init__()
20 | self.in_features = in_features
21 | self.out_features = out_features
22 | self.use_bias = bias
23 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24 |
25 | if priors is None:
26 | priors = {
27 | 'prior_mu': 0,
28 | 'prior_sigma': 0.1,
29 | 'posterior_mu_initial': (0, 0.1),
30 | 'posterior_rho_initial': (-3, 0.1),
31 | }
32 | self.prior_mu = priors['prior_mu']
33 | self.prior_sigma = priors['prior_sigma']
34 | self.posterior_mu_initial = priors['posterior_mu_initial']
35 | self.posterior_rho_initial = priors['posterior_rho_initial']
36 |
37 | self.W_mu = Parameter(torch.Tensor(out_features, in_features))
38 | self.W_rho = Parameter(torch.Tensor(out_features, in_features))
39 | if self.use_bias:
40 | self.bias_mu = Parameter(torch.Tensor(out_features))
41 | self.bias_rho = Parameter(torch.Tensor(out_features))
42 | else:
43 | self.register_parameter('bias_mu', None)
44 | self.register_parameter('bias_rho', None)
45 |
46 | self.reset_parameters()
47 |
48 | def reset_parameters(self):
49 | self.W_mu.data.normal_(*self.posterior_mu_initial)
50 | self.W_rho.data.normal_(*self.posterior_rho_initial)
51 |
52 | if self.use_bias:
53 | self.bias_mu.data.normal_(*self.posterior_mu_initial)
54 | self.bias_rho.data.normal_(*self.posterior_rho_initial)
55 |
56 | def forward(self, x, sample=True):
57 |
58 | self.W_sigma = torch.log1p(torch.exp(self.W_rho))
59 | if self.use_bias:
60 | self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))
61 | bias_var = self.bias_sigma ** 2
62 | else:
63 | self.bias_sigma = bias_var = None
64 |
65 | act_mu = F.linear(x, self.W_mu, self.bias_mu)
66 | act_var = 1e-16 + F.linear(x ** 2, self.W_sigma ** 2, bias_var)
67 | act_std = torch.sqrt(act_var)
68 |
69 | if self.training or sample:
70 | eps = torch.empty(act_mu.size()).normal_(0, 1).to(self.device)
71 | return act_mu + act_std * eps
72 | else:
73 | return act_mu
74 |
75 | def kl_loss(self):
76 | kl = KL_DIV(self.prior_mu, self.prior_sigma, self.W_mu, self.W_sigma)
77 | if self.use_bias:
78 | kl += KL_DIV(self.prior_mu, self.prior_sigma, self.bias_mu, self.bias_sigma)
79 | return kl
80 |
--------------------------------------------------------------------------------
/layers/BBB_LRT/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/layers/BBB_LRT/__init__.py
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .BBB.BBBLinear import BBBLinear as BBB_Linear
2 | from .BBB.BBBConv import BBBConv2d as BBB_Conv2d
3 |
4 | from .BBB_LRT.BBBLinear import BBBLinear as BBB_LRT_Linear
5 | from .BBB_LRT.BBBConv import BBBConv2d as BBB_LRT_Conv2d
6 |
7 | from .misc import FlattenLayer, ModuleWrapper
8 |
--------------------------------------------------------------------------------
/layers/__pycache__/BBBConv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/layers/__pycache__/BBBConv.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/BBBLinear.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/layers/__pycache__/BBBLinear.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/layers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/layers/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/misc.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class ModuleWrapper(nn.Module):
5 | """Wrapper for nn.Module with support for arbitrary flags and a universal forward pass"""
6 |
7 | def __init__(self):
8 | super(ModuleWrapper, self).__init__()
9 |
10 | def set_flag(self, flag_name, value):
11 | setattr(self, flag_name, value)
12 | for m in self.children():
13 | if hasattr(m, 'set_flag'):
14 | m.set_flag(flag_name, value)
15 |
16 | def forward(self, x):
17 | for module in self.children():
18 | x = module(x)
19 |
20 | kl = 0.0
21 | for module in self.modules():
22 | if hasattr(module, 'kl_loss'):
23 | kl = kl + module.kl_loss()
24 |
25 | return x, kl
26 |
27 |
28 | class FlattenLayer(ModuleWrapper):
29 |
30 | def __init__(self, num_features):
31 | super(FlattenLayer, self).__init__()
32 | self.num_features = num_features
33 |
34 | def forward(self, x):
35 | return x.view(-1, self.num_features)
36 |
--------------------------------------------------------------------------------
/main_bayesian.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 |
6 | import torch
7 | import numpy as np
8 | from torch.optim import Adam, lr_scheduler
9 | from torch.nn import functional as F
10 |
11 | import data
12 | import utils
13 | import metrics
14 | import config_bayesian as cfg
15 | from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
16 | from models.BayesianModels.BayesianAlexNet import BBBAlexNet
17 | from models.BayesianModels.BayesianLeNet import BBBLeNet
18 |
19 | # CUDA settings
20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 |
22 | def getModel(net_type, inputs, outputs, priors, layer_type, activation_type):
23 | if (net_type == 'lenet'):
24 | return BBBLeNet(outputs, inputs, priors, layer_type, activation_type)
25 | elif (net_type == 'alexnet'):
26 | return BBBAlexNet(outputs, inputs, priors, layer_type, activation_type)
27 | elif (net_type == '3conv3fc'):
28 | return BBB3Conv3FC(outputs, inputs, priors, layer_type, activation_type)
29 | else:
30 |         raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC]')
31 |
32 |
33 | def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
34 | net.train()
35 | training_loss = 0.0
36 | accs = []
37 | kl_list = []
38 | for i, (inputs, labels) in enumerate(trainloader, 1):
39 |
40 | optimizer.zero_grad()
41 |
42 | inputs, labels = inputs.to(device), labels.to(device)
43 | outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
44 |
45 | kl = 0.0
46 | for j in range(num_ens):
47 | net_out, _kl = net(inputs)
48 | kl += _kl
49 | outputs[:, :, j] = F.log_softmax(net_out, dim=1)
50 |
51 | kl = kl / num_ens
52 | kl_list.append(kl.item())
53 | log_outputs = utils.logmeanexp(outputs, dim=2)
54 |
55 | beta = metrics.get_beta(i-1, len(trainloader), beta_type, epoch, num_epochs)
56 | loss = criterion(log_outputs, labels, kl, beta)
57 | loss.backward()
58 | optimizer.step()
59 |
60 | accs.append(metrics.acc(log_outputs.data, labels))
61 | training_loss += loss.cpu().data.numpy()
62 | return training_loss/len(trainloader), np.mean(accs), np.mean(kl_list)
63 |
64 |
65 | def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
66 | """Calculate ensemble accuracy and NLL Loss"""
67 |     net.train()  # kept in train mode so the Bayesian layers keep sampling weights
68 | valid_loss = 0.0
69 | accs = []
70 |
71 | for i, (inputs, labels) in enumerate(validloader):
72 | inputs, labels = inputs.to(device), labels.to(device)
73 | outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
74 | kl = 0.0
75 | for j in range(num_ens):
76 | net_out, _kl = net(inputs)
77 | kl += _kl
78 | outputs[:, :, j] = F.log_softmax(net_out, dim=1).data
79 |
80 | log_outputs = utils.logmeanexp(outputs, dim=2)
81 |
82 | beta = metrics.get_beta(i-1, len(validloader), beta_type, epoch, num_epochs)
83 | valid_loss += criterion(log_outputs, labels, kl, beta).item()
84 | accs.append(metrics.acc(log_outputs, labels))
85 |
86 | return valid_loss/len(validloader), np.mean(accs)
87 |
88 |
89 | def run(dataset, net_type):
90 |
91 | # Hyper Parameter settings
92 | layer_type = cfg.layer_type
93 | activation_type = cfg.activation_type
94 | priors = cfg.priors
95 |
96 | train_ens = cfg.train_ens
97 | valid_ens = cfg.valid_ens
98 | n_epochs = cfg.n_epochs
99 | lr_start = cfg.lr_start
100 | num_workers = cfg.num_workers
101 | valid_size = cfg.valid_size
102 | batch_size = cfg.batch_size
103 | beta_type = cfg.beta_type
104 |
105 | trainset, testset, inputs, outputs = data.getDataset(dataset)
106 | train_loader, valid_loader, test_loader = data.getDataloader(
107 | trainset, testset, valid_size, batch_size, num_workers)
108 | net = getModel(net_type, inputs, outputs, priors, layer_type, activation_type).to(device)
109 |
110 | ckpt_dir = f'checkpoints/{dataset}/bayesian'
111 | ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}_{activation_type}.pt'
112 |
113 | if not os.path.exists(ckpt_dir):
114 | os.makedirs(ckpt_dir, exist_ok=True)
115 |
116 | criterion = metrics.ELBO(len(trainset)).to(device)
117 | optimizer = Adam(net.parameters(), lr=lr_start)
118 | lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
119 |     valid_loss_min = np.inf
120 |     for epoch in range(n_epochs):  # loop over the dataset multiple times
121 | 
122 |         train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
123 |         valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
124 |         lr_sched.step(valid_loss)
125 | 
126 |         print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
127 |             epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))
128 | 
129 |         # save model if validation loss has decreased
130 |         if valid_loss <= valid_loss_min:
131 |             print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
132 |                 valid_loss_min, valid_loss))
133 |             torch.save(net.state_dict(), ckpt_name)
134 |             valid_loss_min = valid_loss
135 | 
135 |
136 | if __name__ == '__main__':
137 | parser = argparse.ArgumentParser(description = "PyTorch Bayesian Model Training")
138 | parser.add_argument('--net_type', default='lenet', type=str, help='model')
139 | parser.add_argument('--dataset', default='MNIST', type=str, help='dataset = [MNIST/CIFAR10/CIFAR100]')
140 | args = parser.parse_args()
141 |
142 | run(args.dataset, args.net_type)
143 |
--------------------------------------------------------------------------------
/main_frequentist.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 |
6 | import torch
7 | import numpy as np
8 | import torch.nn as nn
9 | from torch.optim import Adam, lr_scheduler
10 |
11 | import data
12 | import utils
13 | import metrics
14 | import config_frequentist as cfg
15 | from models.NonBayesianModels.AlexNet import AlexNet
16 | from models.NonBayesianModels.LeNet import LeNet
17 | from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
18 |
19 | # CUDA settings
20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 |
22 |
23 | def getModel(net_type, inputs, outputs):
24 | if (net_type == 'lenet'):
25 | return LeNet(outputs, inputs)
26 | elif (net_type == 'alexnet'):
27 | return AlexNet(outputs, inputs)
28 | elif (net_type == '3conv3fc'):
29 | return ThreeConvThreeFC(outputs, inputs)
30 | else:
31 |         raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC]')
32 |
33 |
34 | def train_model(net, optimizer, criterion, train_loader):
35 | train_loss = 0.0
36 | net.train()
37 | accs = []
38 | for data, target in train_loader:
39 | data, target = data.to(device), target.to(device)
40 | optimizer.zero_grad()
41 | output = net(data)
42 | loss = criterion(output, target)
43 | loss.backward()
44 | optimizer.step()
45 | train_loss += loss.item()*data.size(0)
46 | accs.append(metrics.acc(output.detach(), target))
47 | return train_loss, np.mean(accs)
48 |
49 |
50 | def validate_model(net, criterion, valid_loader):
51 | valid_loss = 0.0
52 | net.eval()
53 | accs = []
54 | for data, target in valid_loader:
55 | data, target = data.to(device), target.to(device)
56 | output = net(data)
57 | loss = criterion(output, target)
58 | valid_loss += loss.item()*data.size(0)
59 | accs.append(metrics.acc(output.detach(), target))
60 | return valid_loss, np.mean(accs)
61 |
62 |
63 | def run(dataset, net_type):
64 |
65 | # Hyper Parameter settings
66 | n_epochs = cfg.n_epochs
67 | lr = cfg.lr
68 | num_workers = cfg.num_workers
69 | valid_size = cfg.valid_size
70 | batch_size = cfg.batch_size
71 |
72 | trainset, testset, inputs, outputs = data.getDataset(dataset)
73 | train_loader, valid_loader, test_loader = data.getDataloader(
74 | trainset, testset, valid_size, batch_size, num_workers)
75 | net = getModel(net_type, inputs, outputs).to(device)
76 |
77 | ckpt_dir = f'checkpoints/{dataset}/frequentist'
78 | ckpt_name = f'checkpoints/{dataset}/frequentist/model_{net_type}.pt'
79 |
80 | if not os.path.exists(ckpt_dir):
81 | os.makedirs(ckpt_dir, exist_ok=True)
82 |
83 | criterion = nn.CrossEntropyLoss()
84 | optimizer = Adam(net.parameters(), lr=lr)
85 | lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
86 |     valid_loss_min = np.inf
87 | for epoch in range(1, n_epochs+1):
88 |
89 | train_loss, train_acc = train_model(net, optimizer, criterion, train_loader)
90 | valid_loss, valid_acc = validate_model(net, criterion, valid_loader)
91 | lr_sched.step(valid_loss)
92 |
93 | train_loss = train_loss/len(train_loader.dataset)
94 | valid_loss = valid_loss/len(valid_loader.dataset)
95 |
96 | print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'.format(
97 | epoch, train_loss, train_acc, valid_loss, valid_acc))
98 |
99 | # save model if validation loss has decreased
100 | if valid_loss <= valid_loss_min:
101 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
102 | valid_loss_min, valid_loss))
103 | torch.save(net.state_dict(), ckpt_name)
104 | valid_loss_min = valid_loss
105 |
106 |
107 | if __name__ == '__main__':
108 | parser = argparse.ArgumentParser(description = "PyTorch Frequentist Model Training")
109 | parser.add_argument('--net_type', default='lenet', type=str, help='model')
110 | parser.add_argument('--dataset', default='MNIST', type=str, help='dataset = [MNIST/CIFAR10/CIFAR100]')
111 | args = parser.parse_args()
112 |
113 | run(args.dataset, args.net_type)
114 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import torch
5 |
6 |
7 | class ELBO(nn.Module):
8 | def __init__(self, train_size):
9 | super(ELBO, self).__init__()
10 | self.train_size = train_size
11 |
12 | def forward(self, input, target, kl, beta):
13 | assert not target.requires_grad
14 | return F.nll_loss(input, target, reduction='mean') * self.train_size + beta * kl
15 |
16 |
17 | # def lr_linear(epoch_num, decay_start, total_epochs, start_value):
18 | # if epoch_num < decay_start:
19 | # return start_value
20 | # return start_value*float(total_epochs-epoch_num)/float(total_epochs-decay_start)
21 |
22 |
23 | def acc(outputs, targets):
24 | return np.mean(outputs.cpu().numpy().argmax(axis=1) == targets.data.cpu().numpy())
25 |
26 |
27 | def calculate_kl(mu_q, sig_q, mu_p, sig_p):
28 | kl = 0.5 * (2 * torch.log(sig_p / sig_q) - 1 + (sig_q / sig_p).pow(2) + ((mu_p - mu_q) / sig_p).pow(2)).sum()
29 | return kl
30 |
31 |
32 | def get_beta(batch_idx, m, beta_type, epoch, num_epochs):
33 | if type(beta_type) is float:
34 | return beta_type
35 |
36 | if beta_type == "Blundell":
37 | beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
38 | elif beta_type == "Soenderby":
39 | if epoch is None or num_epochs is None:
40 | raise ValueError('Soenderby method requires both epoch and num_epochs to be passed.')
41 | beta = min(epoch / (num_epochs // 4), 1)
42 | elif beta_type == "Standard":
43 | beta = 1 / m
44 | else:
45 | beta = 0
46 | return beta
47 |
--------------------------------------------------------------------------------
/models/BayesianModels/Bayesian3Conv3FC.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from layers import BBB_Linear, BBB_Conv2d
4 | from layers import BBB_LRT_Linear, BBB_LRT_Conv2d
5 | from layers import FlattenLayer, ModuleWrapper
6 |
7 | class BBB3Conv3FC(ModuleWrapper):
8 |     """
9 |     Simple neural network with 3 convolutional layers
10 |     and 3 fully-connected layers, all implemented as
11 |     Bayesian layers.
12 |     """
13 | def __init__(self, outputs, inputs, priors, layer_type='lrt', activation_type='softplus'):
14 | super(BBB3Conv3FC, self).__init__()
15 |
16 | self.num_classes = outputs
17 | self.layer_type = layer_type
18 | self.priors = priors
19 |
20 | if layer_type=='lrt':
21 | BBBLinear = BBB_LRT_Linear
22 | BBBConv2d = BBB_LRT_Conv2d
23 | elif layer_type=='bbb':
24 | BBBLinear = BBB_Linear
25 | BBBConv2d = BBB_Conv2d
26 | else:
27 | raise ValueError("Undefined layer_type")
28 |
29 | if activation_type=='softplus':
30 | self.act = nn.Softplus
31 | elif activation_type=='relu':
32 | self.act = nn.ReLU
33 | else:
34 | raise ValueError("Only softplus or relu supported")
35 |
36 | self.conv1 = BBBConv2d(inputs, 32, 5, padding=2, bias=True, priors=self.priors)
37 | self.act1 = self.act()
38 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
39 |
40 | self.conv2 = BBBConv2d(32, 64, 5, padding=2, bias=True, priors=self.priors)
41 | self.act2 = self.act()
42 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
43 |
44 | self.conv3 = BBBConv2d(64, 128, 5, padding=1, bias=True, priors=self.priors)
45 | self.act3 = self.act()
46 | self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
47 |
48 | self.flatten = FlattenLayer(2 * 2 * 128)
49 | self.fc1 = BBBLinear(2 * 2 * 128, 1000, bias=True, priors=self.priors)
50 | self.act4 = self.act()
51 |
52 | self.fc2 = BBBLinear(1000, 1000, bias=True, priors=self.priors)
53 | self.act5 = self.act()
54 |
55 | self.fc3 = BBBLinear(1000, outputs, bias=True, priors=self.priors)
56 |
--------------------------------------------------------------------------------
/models/BayesianModels/BayesianAlexNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from layers import BBB_Linear, BBB_Conv2d
4 | from layers import BBB_LRT_Linear, BBB_LRT_Conv2d
5 | from layers import FlattenLayer, ModuleWrapper
6 |
7 |
8 | class BBBAlexNet(ModuleWrapper):
9 | '''The architecture of AlexNet with Bayesian Layers'''
10 |
11 | def __init__(self, outputs, inputs, priors, layer_type='lrt', activation_type='softplus'):
12 | super(BBBAlexNet, self).__init__()
13 |
14 | self.num_classes = outputs
15 | self.layer_type = layer_type
16 | self.priors = priors
17 |
18 | if layer_type=='lrt':
19 | BBBLinear = BBB_LRT_Linear
20 | BBBConv2d = BBB_LRT_Conv2d
21 | elif layer_type=='bbb':
22 | BBBLinear = BBB_Linear
23 | BBBConv2d = BBB_Conv2d
24 | else:
25 | raise ValueError("Undefined layer_type")
26 |
27 | if activation_type=='softplus':
28 | self.act = nn.Softplus
29 | elif activation_type=='relu':
30 | self.act = nn.ReLU
31 | else:
32 | raise ValueError("Only softplus or relu supported")
33 |
34 | self.conv1 = BBBConv2d(inputs, 64, 11, stride=4, padding=5, bias=True, priors=self.priors)
35 | self.act1 = self.act()
36 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
37 |
38 | self.conv2 = BBBConv2d(64, 192, 5, padding=2, bias=True, priors=self.priors)
39 | self.act2 = self.act()
40 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
41 |
42 | self.conv3 = BBBConv2d(192, 384, 3, padding=1, bias=True, priors=self.priors)
43 | self.act3 = self.act()
44 |
45 | self.conv4 = BBBConv2d(384, 256, 3, padding=1, bias=True, priors=self.priors)
46 | self.act4 = self.act()
47 |
48 | self.conv5 = BBBConv2d(256, 128, 3, padding=1, bias=True, priors=self.priors)
49 | self.act5 = self.act()
50 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
51 |
52 | self.flatten = FlattenLayer(1 * 1 * 128)
53 | self.classifier = BBBLinear(1 * 1 * 128, outputs, bias=True, priors=self.priors)
54 |
--------------------------------------------------------------------------------
/models/BayesianModels/BayesianLeNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from layers import BBB_Linear, BBB_Conv2d
4 | from layers import BBB_LRT_Linear, BBB_LRT_Conv2d
5 | from layers import FlattenLayer, ModuleWrapper
6 |
7 |
8 | class BBBLeNet(ModuleWrapper):
9 | '''The architecture of LeNet with Bayesian Layers'''
10 |
11 | def __init__(self, outputs, inputs, priors, layer_type='lrt', activation_type='softplus'):
12 | super(BBBLeNet, self).__init__()
13 |
14 | self.num_classes = outputs
15 | self.layer_type = layer_type
16 | self.priors = priors
17 |
18 | if layer_type=='lrt':
19 | BBBLinear = BBB_LRT_Linear
20 | BBBConv2d = BBB_LRT_Conv2d
21 | elif layer_type=='bbb':
22 | BBBLinear = BBB_Linear
23 | BBBConv2d = BBB_Conv2d
24 | else:
25 | raise ValueError("Undefined layer_type")
26 |
27 | if activation_type=='softplus':
28 | self.act = nn.Softplus
29 | elif activation_type=='relu':
30 | self.act = nn.ReLU
31 | else:
32 | raise ValueError("Only softplus or relu supported")
33 |
34 | self.conv1 = BBBConv2d(inputs, 6, 5, padding=0, bias=True, priors=self.priors)
35 | self.act1 = self.act()
36 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
37 |
38 | self.conv2 = BBBConv2d(6, 16, 5, padding=0, bias=True, priors=self.priors)
39 | self.act2 = self.act()
40 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
41 |
42 | self.flatten = FlattenLayer(5 * 5 * 16)
43 | self.fc1 = BBBLinear(5 * 5 * 16, 120, bias=True, priors=self.priors)
44 | self.act3 = self.act()
45 |
46 | self.fc2 = BBBLinear(120, 84, bias=True, priors=self.priors)
47 | self.act4 = self.act()
48 |
49 | self.fc3 = BBBLinear(84, outputs, bias=True, priors=self.priors)
50 |
--------------------------------------------------------------------------------
/models/BayesianModels/__pycache__/Bayesian3Conv3FC.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/models/BayesianModels/__pycache__/Bayesian3Conv3FC.cpython-37.pyc
--------------------------------------------------------------------------------
/models/BayesianModels/__pycache__/BayesianAlexNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/models/BayesianModels/__pycache__/BayesianAlexNet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/BayesianModels/__pycache__/BayesianLeNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/models/BayesianModels/__pycache__/BayesianLeNet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/NonBayesianModels/AlexNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | def conv_init(m):
6 | classname = m.__class__.__name__
7 | if classname.find('Conv') != -1:
8 | #nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))
9 | nn.init.normal_(m.weight, mean=0, std=1)
10 |         nn.init.constant_(m.bias, 0)
11 |
12 | class AlexNet(nn.Module):
13 |
14 | def __init__(self, num_classes, inputs=3):
15 | super(AlexNet, self).__init__()
16 | self.features = nn.Sequential(
17 | nn.Conv2d(inputs, 64, kernel_size=11, stride=4, padding=5),
18 | nn.ReLU(inplace=True),
19 | nn.Dropout(p=0.5),
20 | nn.MaxPool2d(kernel_size=2, stride=2),
21 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
22 | nn.ReLU(inplace=True),
23 | nn.MaxPool2d(kernel_size=2, stride=2),
24 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.Dropout(p=0.5),
27 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.Dropout(p=0.5),
32 | nn.MaxPool2d(kernel_size=2, stride=2),
33 | )
34 | self.classifier = nn.Linear(256, num_classes)
35 |
36 | def forward(self, x):
37 | x = self.features(x)
38 | x = x.view(x.size(0), -1)
39 | x = self.classifier(x)
40 | return x
41 |
--------------------------------------------------------------------------------
/models/NonBayesianModels/LeNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | def conv_init(m):
6 | classname = m.__class__.__name__
7 | if classname.find('Conv') != -1:
8 | #nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))
9 | nn.init.normal_(m.weight, mean=0, std=1)
10 |         nn.init.constant_(m.bias, 0)
11 |
12 | class LeNet(nn.Module):
13 | def __init__(self, num_classes, inputs=3):
14 | super(LeNet, self).__init__()
15 | self.conv1 = nn.Conv2d(inputs, 6, 5)
16 | self.conv2 = nn.Conv2d(6, 16, 5)
17 | self.fc1 = nn.Linear(16*5*5, 120)
18 | self.fc2 = nn.Linear(120, 84)
19 | self.fc3 = nn.Linear(84, num_classes)
20 |
21 | def forward(self, x):
22 | out = F.relu(self.conv1(x))
23 | out = F.max_pool2d(out, 2)
24 | out = F.relu(self.conv2(out))
25 | out = F.max_pool2d(out, 2)
26 | out = out.view(out.size(0), -1)
27 | out = F.relu(self.fc1(out))
28 | out = F.relu(self.fc2(out))
29 | out = self.fc3(out)
30 |
31 | return(out)
32 |
--------------------------------------------------------------------------------
/models/NonBayesianModels/ThreeConvThreeFC.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from layers.misc import FlattenLayer
3 |
4 |
5 | def conv_init(m):
6 | classname = m.__class__.__name__
7 | if classname.find('Conv') != -1:
8 | #nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))
9 | nn.init.normal_(m.weight, mean=0, std=1)
10 |         nn.init.constant_(m.bias, 0)
11 |
12 | class ThreeConvThreeFC(nn.Module):
13 | """
14 | To train on CIFAR-10:
15 | https://arxiv.org/pdf/1207.0580.pdf
16 | """
17 | def __init__(self, outputs, inputs):
18 | super(ThreeConvThreeFC, self).__init__()
19 | self.features = nn.Sequential(
20 | nn.Conv2d(inputs, 32, 5, stride=1, padding=2),
21 | nn.Softplus(),
22 | nn.MaxPool2d(kernel_size=3, stride=2),
23 | nn.Conv2d(32, 64, 5, stride=1, padding=2),
24 | nn.Softplus(),
25 | nn.MaxPool2d(kernel_size=3, stride=2),
26 | nn.Conv2d(64, 128, 5, stride=1, padding=1),
27 | nn.Softplus(),
28 | nn.MaxPool2d(kernel_size=3, stride=2),
29 | )
30 | self.classifier = nn.Sequential(
31 | FlattenLayer(2 * 2 * 128),
32 | nn.Linear(2 * 2 * 128, 1000),
33 | nn.Softplus(),
34 | nn.Linear(1000, 1000),
35 | nn.Softplus(),
36 | nn.Linear(1000, outputs)
37 | )
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | x = self.classifier(x)
42 | return x
43 |
--------------------------------------------------------------------------------
/models/NonBayesianModels/__pycache__/AlexNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/models/NonBayesianModels/__pycache__/AlexNet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/NonBayesianModels/__pycache__/LeNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/models/NonBayesianModels/__pycache__/LeNet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/NonBayesianModels/__pycache__/ThreeConvThreeFC.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kumar-shridhar/PyTorch-BayesianCNN/d93bad543c3226cd0fe05c0cb0ba033c41b3caa6/models/NonBayesianModels/__pycache__/ThreeConvThreeFC.cpython-37.pyc
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import pytest
5 | import unittest
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 |
10 | import utils
11 | from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
12 | from models.BayesianModels.BayesianAlexNet import BBBAlexNet
13 | from models.BayesianModels.BayesianLeNet import BBBLeNet
14 | from models.NonBayesianModels.AlexNet import AlexNet
15 | from models.NonBayesianModels.LeNet import LeNet
16 | from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
17 | from layers.BBB.BBBConv import BBBConv2d
18 | from layers.BBB.BBBLinear import BBBLinear
19 | from layers.misc import FlattenLayer
20 |
21 | cuda_available = torch.cuda.is_available()
22 | bayesian_models = [BBBLeNet, BBBAlexNet, BBB3Conv3FC]
23 | non_bayesian_models = [LeNet, AlexNet, ThreeConvThreeFC]
24 |
25 | class TestModelForwardpass:
26 |
27 | @pytest.mark.parametrize("model", bayesian_models)
28 | def test_cpu_bayesian(self, model):
29 | batch_size = np.random.randint(1, 256)
30 | batch = torch.randn((batch_size, 3, 32, 32))
31 |         net = model(10, 3, priors=None)
32 | out = net(batch)
33 | assert out[0].shape[0]==batch_size
34 |
35 | @pytest.mark.parametrize("model", non_bayesian_models)
36 | def test_cpu_frequentist(self, model):
37 | batch_size = np.random.randint(1, 256)
38 | batch = torch.randn((batch_size, 3, 32, 32))
39 | net = model(10, 3)
40 | out = net(batch)
41 | assert out.shape[0]==batch_size
42 |
43 | @pytest.mark.skipif(not cuda_available, reason="CUDA not available")
44 | @pytest.mark.parametrize("model", bayesian_models)
45 | def test_gpu_bayesian(self, model):
46 | batch_size = np.random.randint(1, 256)
47 | batch = torch.randn((batch_size, 3, 32, 32))
48 |         net = model(10, 3, priors=None)
49 | if cuda_available:
50 | net = net.cuda()
51 | batch = batch.cuda()
52 | out = net(batch)
53 | assert out[0].shape[0]==batch_size
54 |
55 | @pytest.mark.skipif(not cuda_available, reason="CUDA not available")
56 | @pytest.mark.parametrize("model", non_bayesian_models)
57 | def test_gpu_frequentist(self, model):
58 | batch_size = np.random.randint(1, 256)
59 | batch = torch.randn((batch_size, 3, 32, 32))
60 | net = model(10, 3)
61 | if cuda_available:
62 | net = net.cuda()
63 | batch = batch.cuda()
64 | out = net(batch)
65 | assert out.shape[0]==batch_size
66 |
67 |
68 | class TestBayesianLayers:
69 |
70 | def test_flatten(self):
71 | batch_size = np.random.randint(1, 256)
72 | batch = torch.randn((batch_size, 64, 4, 4))
73 |
74 | layer = FlattenLayer(4 * 4 * 64)
75 | batch = layer(batch)
76 |
77 | assert batch.shape[0]==batch_size
78 | assert batch.shape[1]==(4 * 4 *64)
79 |
80 | def test_conv(self):
81 | batch_size = np.random.randint(1, 256)
82 | batch = torch.randn((batch_size, 16, 24, 24))
83 |
84 |         layer = BBBConv2d(16, 6, 4, padding=0, bias=False)
85 | batch = layer(batch)
86 |
87 | assert batch.shape[0]==batch_size
88 | assert batch.shape[1]==6
89 |
90 | def test_linear(self):
91 | batch_size = np.random.randint(1, 256)
92 | batch = torch.randn((batch_size, 128))
93 |
94 |         layer = BBBLinear(128, 64, bias=False)
95 | batch = layer(batch)
96 |
97 | assert batch.shape[0]==batch_size
98 | assert batch.shape[1]==64
99 |
--------------------------------------------------------------------------------
/uncertainty_estimation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | import seaborn as sns
6 | from PIL import Image
7 | import torchvision
8 | from torch.nn import functional as F
9 | import torchvision.transforms as transforms
10 | import matplotlib.pyplot as plt
11 |
12 | import data
13 | from main_bayesian import getModel
14 | import config_bayesian as cfg
15 |
16 |
17 | # CUDA settings
18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19 |
20 | mnist_set = None
21 | notmnist_set = None
22 |
23 | transform = transforms.Compose([
24 | transforms.ToPILImage(),
25 | transforms.Resize((32, 32)),
26 | transforms.ToTensor(),
27 | ])
28 |
29 |
30 | def init_dataset(notmnist_dir):
31 | global mnist_set
32 | global notmnist_set
33 | mnist_set, _, _, _ = data.getDataset('MNIST')
34 | notmnist_set = torchvision.datasets.ImageFolder(root=notmnist_dir)
35 |
36 |
37 | def get_uncertainty_per_image(model, input_image, T=15, normalized=False):
38 | input_image = input_image.unsqueeze(0)
39 | input_images = input_image.repeat(T, 1, 1, 1)
40 |
41 | net_out, _ = model(input_images)
42 | pred = torch.mean(net_out, dim=0).cpu().detach().numpy()
43 | if normalized:
44 | prediction = F.softplus(net_out)
45 | p_hat = prediction / torch.sum(prediction, dim=1).unsqueeze(1)
46 | else:
47 | p_hat = F.softmax(net_out, dim=1)
48 | p_hat = p_hat.detach().cpu().numpy()
49 | p_bar = np.mean(p_hat, axis=0)
50 |
51 | temp = p_hat - np.expand_dims(p_bar, 0)
52 | epistemic = np.dot(temp.T, temp) / T
53 | epistemic = np.diag(epistemic)
54 |
55 | aleatoric = np.diag(p_bar) - (np.dot(p_hat.T, p_hat) / T)
56 | aleatoric = np.diag(aleatoric)
57 |
58 | return pred, epistemic, aleatoric
59 |
60 |
61 | def get_uncertainty_per_batch(model, batch, T=15, normalized=False):
62 | batch_predictions = []
63 | net_outs = []
64 | batches = batch.unsqueeze(0).repeat(T, 1, 1, 1, 1)
65 |
66 | preds = []
67 | epistemics = []
68 | aleatorics = []
69 |
70 | for i in range(T): # for T batches
71 |         net_out, _ = model(batches[i].to(device))
72 | net_outs.append(net_out)
73 | if normalized:
74 | prediction = F.softplus(net_out)
75 | prediction = prediction / torch.sum(prediction, dim=1).unsqueeze(1)
76 | else:
77 | prediction = F.softmax(net_out, dim=1)
78 | batch_predictions.append(prediction)
79 |
80 | for sample in range(batch.shape[0]):
81 | # for each sample in a batch
82 | pred = torch.cat([a_batch[sample].unsqueeze(0) for a_batch in net_outs], dim=0)
83 | pred = torch.mean(pred, dim=0)
84 | preds.append(pred)
85 |
86 | p_hat = torch.cat([a_batch[sample].unsqueeze(0) for a_batch in batch_predictions], dim=0).detach().cpu().numpy()
87 | p_bar = np.mean(p_hat, axis=0)
88 |
89 | temp = p_hat - np.expand_dims(p_bar, 0)
90 | epistemic = np.dot(temp.T, temp) / T
91 | epistemic = np.diag(epistemic)
92 | epistemics.append(epistemic)
93 |
94 | aleatoric = np.diag(p_bar) - (np.dot(p_hat.T, p_hat) / T)
95 | aleatoric = np.diag(aleatoric)
96 | aleatorics.append(aleatoric)
97 |
98 | epistemic = np.vstack(epistemics) # (batch_size, categories)
99 | aleatoric = np.vstack(aleatorics) # (batch_size, categories)
100 | preds = torch.cat([p.unsqueeze(0) for p in preds]).cpu().detach().numpy()  # (batch_size, categories)
101 |
102 | return preds, epistemic, aleatoric
103 |
104 |
105 | def get_sample(dataset, sample_type='mnist'):
106 | idx = np.random.randint(len(dataset.targets))
107 | if sample_type == 'mnist':
108 | sample = dataset.data[idx]
109 | truth = dataset.targets[idx]
110 | else:
111 | path, truth = dataset.samples[idx]
112 | sample = torch.from_numpy(np.array(Image.open(path)))
113 |
114 | sample = sample.unsqueeze(0)
115 | sample = transform(sample)
116 | return sample.to(device), truth
117 |
118 |
119 | def run(net_type, weight_path, notmnist_dir):
120 | init_dataset(notmnist_dir)
121 |
122 | layer_type = cfg.layer_type
123 | activation_type = cfg.activation_type
124 |
125 | net = getModel(net_type, 1, 10, priors=None, layer_type=layer_type, activation_type=activation_type)
126 | net.load_state_dict(torch.load(weight_path, map_location=device))
127 | net.train()  # train mode keeps the variational layers sampling weights on every pass
128 | net.to(device)
129 |
130 | fig = plt.figure()
131 | fig.suptitle('Uncertainty Estimation', fontsize='x-large')
132 | mnist_img = fig.add_subplot(321)
133 | notmnist_img = fig.add_subplot(322)
134 | epi_stats_norm = fig.add_subplot(323)
135 | ale_stats_norm = fig.add_subplot(324)
136 | epi_stats_soft = fig.add_subplot(325)
137 | ale_stats_soft = fig.add_subplot(326)
138 |
139 | sample_mnist, truth_mnist = get_sample(mnist_set)
140 | pred_mnist, epi_mnist_norm, ale_mnist_norm = get_uncertainty_per_image(net, sample_mnist, T=25, normalized=True)
141 | pred_mnist, epi_mnist_soft, ale_mnist_soft = get_uncertainty_per_image(net, sample_mnist, T=25, normalized=False)
142 | mnist_img.imshow(sample_mnist.squeeze().cpu(), cmap='gray')
143 | mnist_img.axis('off')
144 | mnist_img.set_title('MNIST Truth: {} Prediction: {}'.format(int(truth_mnist), int(np.argmax(pred_mnist))))
145 |
146 | sample_notmnist, truth_notmnist = get_sample(notmnist_set, sample_type='notmnist')
147 | pred_notmnist, epi_notmnist_norm, ale_notmnist_norm = get_uncertainty_per_image(net, sample_notmnist, T=25, normalized=True)
148 | pred_notmnist, epi_notmnist_soft, ale_notmnist_soft = get_uncertainty_per_image(net, sample_notmnist, T=25, normalized=False)
149 | notmnist_img.imshow(sample_notmnist.squeeze().cpu(), cmap='gray')
150 | notmnist_img.axis('off')
151 | notmnist_img.set_title('notMNIST Truth: {}({}) Prediction: {}({})'.format(
152 | int(truth_notmnist), chr(65 + truth_notmnist), int(np.argmax(pred_notmnist)), chr(65 + np.argmax(pred_notmnist))))
153 |
154 | x = list(range(10))
155 | df = pd.DataFrame({  # 'df' avoids shadowing the imported data module
156 | 'epistemic_norm': np.hstack([epi_mnist_norm, epi_notmnist_norm]),
157 | 'aleatoric_norm': np.hstack([ale_mnist_norm, ale_notmnist_norm]),
158 | 'epistemic_soft': np.hstack([epi_mnist_soft, epi_notmnist_soft]),
159 | 'aleatoric_soft': np.hstack([ale_mnist_soft, ale_notmnist_soft]),
160 | 'category': np.hstack([x, x]),
161 | 'dataset': np.hstack([['MNIST']*10, ['notMNIST']*10])
162 | })
163 | print(df)
164 | sns.barplot(x='category', y='epistemic_norm', hue='dataset', data=df, ax=epi_stats_norm)
165 | sns.barplot(x='category', y='aleatoric_norm', hue='dataset', data=df, ax=ale_stats_norm)
166 | epi_stats_norm.set_title('Epistemic Uncertainty (Normalized)')
167 | ale_stats_norm.set_title('Aleatoric Uncertainty (Normalized)')
168 |
169 | sns.barplot(x='category', y='epistemic_soft', hue='dataset', data=df, ax=epi_stats_soft)
170 | sns.barplot(x='category', y='aleatoric_soft', hue='dataset', data=df, ax=ale_stats_soft)
171 | epi_stats_soft.set_title('Epistemic Uncertainty (Softmax)')
172 | ale_stats_soft.set_title('Aleatoric Uncertainty (Softmax)')
173 |
174 | plt.show()
175 |
176 |
177 | if __name__ == '__main__':
178 | parser = argparse.ArgumentParser(description="PyTorch Uncertainty Estimation between MNIST and notMNIST")
179 | parser.add_argument('--net_type', default='lenet', type=str, help='model architecture')
180 | parser.add_argument('--weights_path', default='checkpoints/MNIST/bayesian/model_lenet.pt', type=str, help='path to trained model weights')
181 | parser.add_argument('--notmnist_dir', default='data/notMNIST_small/', type=str, help='directory containing the notMNIST dataset')
182 | args = parser.parse_args()
183 |
184 | run(args.net_type, args.weights_path, args.notmnist_dir)
185 |
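The epistemic/aleatoric split computed in get_uncertainty_per_image and get_uncertainty_per_batch follows the predictive-covariance decomposition of Kwon et al. (2018): across T stochastic passes with class-probability vectors p_t and mean p_bar, epistemic = (1/T) * sum_t (p_t - p_bar)(p_t - p_bar)^T and aleatoric = (1/T) * sum_t (diag(p_t) - p_t p_t^T), keeping only the diagonals. A self-contained NumPy sketch with toy numbers (not repo code) verifying that the two parts sum to the total predictive variance:

import numpy as np

T, C = 4, 3  # T Monte Carlo passes, C classes
rng = np.random.default_rng(0)
p_hat = rng.dirichlet(np.ones(C), size=T)  # rows are per-pass class probabilities
p_bar = p_hat.mean(axis=0)

temp = p_hat - p_bar
epistemic = np.diag(temp.T @ temp / T)                     # spread across passes
aleatoric = np.diag(np.diag(p_bar) - p_hat.T @ p_hat / T)  # expected per-pass noise

# the two diagonals sum exactly to the diagonal of the total covariance
total = np.diag(np.diag(p_bar) - np.outer(p_bar, p_bar))
assert np.allclose(epistemic + aleatoric, total)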
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch.nn import functional as F
5 |
6 | import config_bayesian as cfg
7 |
8 |
9 | # cifar10 classes
10 | cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
11 | 'dog', 'frog', 'horse', 'ship', 'truck']
12 |
13 |
14 | def logmeanexp(x, dim=None, keepdim=False):
15 | """Stable computation of log(mean(exp(x))"""
16 |
17 |
18 | if dim is None:
19 | x, dim = x.view(-1), 0
20 | x_max, _ = torch.max(x, dim, keepdim=True)
21 | x = x_max + torch.log(torch.mean(torch.exp(x - x_max), dim, keepdim=True))
22 | return x if keepdim else x.squeeze(dim)
23 |
24 |
25 | def adjust_learning_rate(optimizer, lr):
26 | """Sets the learning rate of every parameter group to the given value."""
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = lr
29 |
30 |
31 | def save_array_to_file(numpy_array, filename):
32 | """Appends the flattened array to the file as one space-separated line."""
33 | with open(filename, 'a') as f:
34 | np.savetxt(f, numpy_array.flatten(), newline=" ", fmt="%.3f")
35 | f.write("\n")
36 |
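A quick numerical illustration of why logmeanexp shifts by the maximum (a sketch, not repo code): the naive form overflows for large inputs, while the shifted form stays exact and agrees with the naive one in-range.

import torch
from utils import logmeanexp

x = torch.tensor([1000.0, 1000.0, 1000.0])
print(torch.log(torch.mean(torch.exp(x))))  # inf: exp(1000) overflows
print(logmeanexp(x))                        # tensor(1000.)

small = torch.tensor([0.0, 1.0, 2.0])       # both forms agree on small inputs
assert torch.allclose(logmeanexp(small), torch.log(torch.mean(torch.exp(small))))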
--------------------------------------------------------------------------------