├── .gitignore
├── LICENSE
├── README.md
├── Update Records.md
├── demos
│   ├── Bayesian Neural Network Classification.ipynb
│   ├── Bayesian Neural Network Regression.ipynb
│   ├── Convert to Bayesian Neural Network.ipynb
│   ├── Custom KL loss with Iris Data.ipynb
│   └── Freeze Bayesian Neural Network.ipynb
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── functional.rst
│   ├── index.rst
│   ├── make.bat
│   ├── modules.rst
│   └── utils.rst
├── requirements.txt
└── torchbnn
    ├── __init__.py
    ├── functional.py
    ├── modules
    │   ├── __init__.py
    │   ├── batchnorm.py
    │   ├── conv.py
    │   ├── linear.py
    │   ├── loss.py
    │   └── module.py
    └── utils
        ├── __init__.py
        └── freeze_model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | build/*
2 | _*
3 | _*/
4 | dist/*
5 | torchbnn.egg-info/*
6 | demos/data/MNIST/*
7 | demos/data/cifar*
8 | */.*
9 | MANIFEST.in
10 | setup.cfg
11 | setup.py
12 | .gitattributes
13 | .*/
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Harry Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bayesian-Neural-Network-Pytorch
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | This is a lightweight repository of Bayesian neural networks for PyTorch.
10 |
11 | ## Usage
12 |
13 | ### :clipboard: Dependencies
14 |
15 | - torch 1.2.0
16 | - python 3.6
17 |
18 |
19 |
20 | ### :hammer: Installation
21 |
22 | - `pip install torchbnn` or
23 | - `git clone https://github.com/Harry24k/bayesian-neural-network-pytorch`
24 |
25 | ```python
26 | import torchbnn
27 | ```
28 |
29 |
30 |
31 | ### :rocket: Demos
32 |
33 | * **Bayesian Neural Network Regression** ([code](https://github.com/Harry24k/bayesian-neural-network-pytorch/blob/master/demos/Bayesian%20Neural%20Network%20Regression.ipynb)):
34 | In this demo, a two-layer Bayesian neural network is constructed and trained on simple custom data. It shows how a Bayesian neural network works and the randomness of its predictions.
35 | * **Bayesian Neural Network Classification** ([code](https://github.com/Harry24k/bayesian-neural-network-pytorch/blob/master/demos/Bayesian%20Neural%20Network%20Classification.ipynb)):
36 | In this demo, a two-layer Bayesian neural network is constructed and trained to classify the Iris data. It shows how a Bayesian neural network works and the randomness of its predictions.
37 | * **Convert to Bayesian Neural Network** ([code](https://github.com/Harry24k/bayesian-neural-network-pytorch/blob/master/demos/Convert%20to%20Bayesian%20Neural%20Network.ipynb)):
38 | This demo converts a basic neural network into a Bayesian neural network and back using `transform_model` from the `torchhk` package (the former `nonbayes_to_bayes` and `bayes_to_nonbayes` utilities are deprecated since v1.0).
39 | * **Freeze Bayesian Neural Network** ([code](https://github.com/Harry24k/bayesian-neural-network-pytorch/blob/master/demos/Freeze%20Bayesian%20Neural%20Network.ipynb)):
40 | Freezing a Bayesian neural network forces it to return the same output for the same input; this demo shows the effect of `freeze` and `unfreeze`.
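The freeze behavior can be illustrated in plain Python with a single Bayesian weight. This is a minimal sketch under stated assumptions, not torchbnn's implementation: the class `ToyBayesWeight` is hypothetical, the weight is sampled as `mu + exp(log_sigma) * eps` with `eps` drawn from N(0, 1) (the reparameterization trick), and `eps = None` stands for the unfrozen state, in which fresh noise is drawn at every forward pass.

```python
import math
import random


class ToyBayesWeight:
    """Hypothetical single-weight Bayesian 'layer' illustrating freeze/unfreeze."""

    def __init__(self, mu: float, log_sigma: float, seed: int = 0):
        self.mu = mu
        self.log_sigma = log_sigma
        self.eps = None  # None = unfrozen: new noise is sampled at every forward pass
        self.rng = random.Random(seed)

    def freeze(self) -> None:
        # Draw the noise once and keep reusing it, making forward() deterministic.
        self.eps = self.rng.gauss(0.0, 1.0)

    def unfreeze(self) -> None:
        # Drop the stored noise so every forward pass samples again.
        self.eps = None

    def forward(self, x: float) -> float:
        eps = self.eps if self.eps is not None else self.rng.gauss(0.0, 1.0)
        weight = self.mu + math.exp(self.log_sigma) * eps  # reparameterized sample
        return weight * x


layer = ToyBayesWeight(mu=1.0, log_sigma=math.log(0.1))
print(layer.forward(2.0), layer.forward(2.0))  # two different samples

layer.freeze()
print(layer.forward(2.0), layer.forward(2.0))  # identical outputs
```

Because the unfrozen layer draws new noise on each call, the first two printed values differ; after `freeze()` the same noise is reused, so the outputs match until `unfreeze()` is called.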
41 |
42 |
43 | ## Citation
44 | If you use this package, please cite the following BibTeX (Semantic Scholar, Google Scholar):
45 |
46 | ```
47 | @article{lee2022graddiv,
48 | title={Graddiv: Adversarial robustness of randomized neural networks via gradient diversity regularization},
49 | author={Lee, Sungyoon and Kim, Hoki and Lee, Jaewook},
50 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
51 | year={2022},
52 | publisher={IEEE}
53 | }
54 | ```
55 |
56 | ## :mag_right: Update Records
57 |
58 | See the [update records](Update%20Records.md) of this package.
59 |
60 |
61 | ## Thanks to
62 |
63 | * @kumar-shridhar [github:PyTorch-BayesianCNN](https://github.com/kumar-shridhar/PyTorch-BayesianCNN)
64 | * @xuanqing94 [github:BayesianDefense](https://github.com/xuanqing94/BayesianDefense)
65 |
--------------------------------------------------------------------------------
/Update Records.md:
--------------------------------------------------------------------------------
1 | ### ~~v0.1~~
2 |
3 | * ~~**modules** : BayesLinear, BayesConv2d, BayesBatchNorm2d are added.~~
4 | * ~~**utils** : convert_model(nonbayes_to_bayes, bayes_to_nonbayes) is added.~~
5 | * ~~**functional.py** : bayesian_kl_loss is added.~~
6 |
7 |
8 |
9 | ### ~~v0.2~~
10 |
11 | * ~~**prior_sigma** is used when initializing modules and functions instead of **prior_log_sigma**.~~
12 | * ~~**modules** are re-defined with prior_sigma instead of prior_log_sigma.~~
13 | * ~~**utils/convert_model.py** is also changed with prior_sigma instead of prior_log_sigma.~~
14 | * ~~**modules** : The base initialization method is changed from the original torch method to the method of Adv-BNN.~~
15 | * ~~**functional.py** : **bayesian_kl_loss** is changed to be similar to the ones in **torch.functional**.~~
16 | * ~~**modules/loss.py** : **BKLLoss** is added based on bayesian_kl_loss, similar to the ones in **torch.loss**.~~
17 |
18 |
19 |
20 | ### ~~v0.3~~
21 |
22 | * ~~**functional.py** :~~
23 | * ~~**'bayesian_kl_loss' returns torch.Tensor([0]) as default** : In the previous version, bayesian_kl_loss returned 0 of int type if there were no Bayesian layers. However, since all torch losses return tensors and .item() is used to convert them to Python numbers, it is changed to return torch.Tensor([0]) if there are no Bayesian layers.~~
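The KL term is commonly computed with the closed-form divergence between a Gaussian posterior N(mu_q, sigma_q^2) and a Gaussian prior N(mu_p, sigma_p^2): KL = log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2. The sketch below applies it per parameter and returns 0 when there are no Bayesian parameters, as described above. It is an illustration under assumptions, not torchbnn's code: it uses plain floats instead of tensors, the helpers `gaussian_kl` and `toy_bayesian_kl` are hypothetical, and its `'mean'` option (dividing by the number of parameters) only mirrors the `reduction` argument seen in the demos.

```python
import math
from typing import List, Tuple


def gaussian_kl(mu_q: float, sigma_q: float, mu_p: float, sigma_p: float) -> float:
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2))."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)


def toy_bayesian_kl(params: List[Tuple[float, float]], prior_mu: float = 0.0,
                    prior_sigma: float = 0.1, reduction: str = "mean") -> float:
    """Hypothetical helper: `params` holds one (mu, sigma) pair per Bayesian weight."""
    if not params:
        return 0.0  # no Bayesian parameters: return zero instead of failing
    total = sum(gaussian_kl(mu, sigma, prior_mu, prior_sigma) for mu, sigma in params)
    if reduction == "mean":
        return total / len(params)
    if reduction == "sum":
        return total
    raise ValueError(f"unknown reduction: {reduction}")


print(toy_bayesian_kl([]))  # 0.0
print(toy_bayesian_kl([(0.0, 0.1), (1.0, 0.1)], reduction="sum"))
```

In the demos, a term of this kind is weighted and added to the task loss, e.g. `cost = ce + kl_weight*kl`.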
24 |
25 |
26 | ### ~~v0.4~~
27 |
28 | * ~~**functional.py** :~~
29 | * ~~**'bayesian_kl_loss' is modified** : In some cases, a device (CUDA/CPU) mismatch error occurred. Thus, losses are now initialized with torch.Tensor([0]) on the same device as the model.~~
30 |
31 |
32 |
33 | ### ~~v0.5~~
34 |
35 | * ~~**utils/convert_model.py** :~~
36 | * ~~**'nonbayes_to_bayes', 'bayes_to_nonbayes' methods are modified** : Before this version, they always replaced the original model. From now on, this is controlled by the 'inplace' argument: set 'inplace=True' to replace the input model, or 'inplace=False' to get a new model. 'inplace=True' is recommended because it saves memory and avoids potential problems with deepcopy.~~
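The two modes can be sketched with a plain dictionary standing in for a model. The helper `replace_layers` and the dict-based model are hypothetical, not the torchbnn API; the sketch only illustrates the `inplace` semantics: `inplace=True` modifies and returns the input object itself, while `inplace=False` works on a `copy.deepcopy` and leaves the input untouched.

```python
import copy
from typing import Dict


def replace_layers(model: Dict[str, str], old: str, new: str,
                   inplace: bool = True) -> Dict[str, str]:
    """Hypothetical helper: swap every layer of kind `old` for kind `new`."""
    # inplace=False: operate on an independent deep copy so the input is untouched.
    target = model if inplace else copy.deepcopy(model)
    for name, kind in target.items():
        if kind == old:
            target[name] = new  # reassigning existing keys is safe during iteration
    return target


model = {"fc1": "Linear", "act": "ReLU", "fc2": "Linear"}
bayes = replace_layers(model, "Linear", "BayesLinear", inplace=False)
print(model["fc1"], bayes["fc1"])  # Linear BayesLinear
```

The deep copy is the source of the extra memory cost of `inplace=False`.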
37 |
38 |
39 |
40 | ### ~~v0.6~~
41 |
42 | * ~~**utils/freeze_model.py** :~~
43 | * ~~**'freeze', 'unfreeze' methods are added** : Bayesian modules return different outputs even for the same input because their forward propagation is randomized. Sometimes, however, we need to freeze this randomized process to analyze the model in depth. In that case, the freeze method turns the Bayesian model into a deterministic (non-Bayesian) model with the same parameters.~~
44 | * ~~**modules** : To support the **freeze** method, freeze, weight_eps and bias_eps are added to each module. If freeze is False (default), weight_eps and bias_eps are resampled from normal noise at every forward pass. If freeze is True, weight_eps and bias_eps are not changed.~~
45 |
46 |
47 |
48 | ### ~~v0.8~~
49 |
50 | * ~~**modules** : To support the **freeze** method, weight_eps and bias_eps are changed to buffers with the register_buffer method. Through this change, a Bayesian neural network can be saved and loaded even while frozen.~~
51 | * ~~**BayesModule is added** : Bayesian version of torch.nn.Module. Not being used currently.~~
52 | * ~~**utils/freeze_model.py** :~~
53 | * ~~**'freeze', 'unfreeze' methods are modified** : The previous methods did not work on single-layer networks.~~
54 | * ~~**Demos are uploaded** : "Bayesian Neural Network with Iris Data".~~
55 |
56 |
57 |
58 |
59 | ### ~~v0.9~~
60 |
61 | * ~~**modules** :~~
62 | * ~~**Variable 'freeze' is deleted** : The flag indicating whether a Bayesian module is frozen is deleted. Instead, the state is determined by checking whether 'eps' is set to None. For example, if 'weight_eps' is not None, the BayesLinear module is frozen. The reason for this update is to fix a backpropagation error caused by modifying eps in place.~~
63 | * ~~**Methods 'freeze' and 'unfreeze' are added** : These methods set 'eps' to random normal values and to None, respectively.~~
64 | * ~~**utils/freeze_model.py** :~~
65 | * ~~**freeze, unfreeze methods are modified** : These methods in utils are changed due to the above.~~
66 | * ~~**Demos are uploaded** : "Convert to Bayesian Neural Network", "Freeze Bayesian Neural Network".~~
67 |
68 |
69 |
70 |
71 | ### ~~v1.0~~
72 |
73 | * ~~**modules** : BayesLinear, BayesConv2d are modified.~~
74 | * ~~**BayesLinear** : Bias is set to False if the bias in args is None or False. Otherwise, it is set to True.~~
75 | * ~~**BayesConv2d** : Bias is set to False if the bias in args is None or False. Otherwise, it is set to True. In addition, it is re-defined with prior_sigma instead of prior_log_sigma.~~
76 |
77 | * ~~**utils/convert_model.py** :~~
78 | * ~~Deprecated. Please refer to the [modified demo](https://github.com/Harry24k/bayesian-neural-network-pytorch/blob/master/demos/Convert%20to%20Bayesian%20Neural%20Network.ipynb).~~
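The bias rule for BayesLinear and BayesConv2d can be expressed as a single check. `resolve_bias` is a hypothetical helper, not a torchbnn function. It shows why conversions that pass an existing layer's `.bias` work: that value is either None (the layer has no bias) or a parameter tensor, and anything other than None or False is treated as True.

```python
from typing import Any


def resolve_bias(bias: Any) -> bool:
    """Hypothetical helper: map a 'bias' argument to a boolean flag."""
    # None or False disables the bias; anything else (True, a tensor, ...) enables it.
    return not (bias is None or bias is False)


print(resolve_bias(None), resolve_bias(False), resolve_bias(True))  # False False True
```

Using identity checks (`is None`, `is False`) rather than truthiness avoids calling `bool()` on a multi-element tensor, which PyTorch rejects as ambiguous.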
79 |
80 |
81 |
82 |
83 | ### v1.1
84 |
85 | * **Pip Package Re-uploaded**
86 |
87 |
88 |
89 |
90 | ### v1.2
91 |
92 | * **[Bug fixed](https://github.com/Harry24k/bayesian-neural-network-pytorch/issues/4)**
93 |
--------------------------------------------------------------------------------
/demos/Bayesian Neural Network Classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Demo - Bayesian Neural Network Classification"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "from sklearn import datasets\n",
18 | "\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "import torch.optim as optim\n",
22 | "\n",
23 | "import torchbnn as bnn"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import matplotlib.pyplot as plt\n",
33 | "%matplotlib inline"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## 1. Load Iris Data"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "iris = datasets.load_iris()"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 4,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "X = iris.data\n",
59 | "Y = iris.target "
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 5,
65 | "metadata": {},
66 | "outputs": [
67 | {
68 | "data": {
69 | "text/plain": [
70 | "(torch.Size([150, 4]), torch.Size([150]))"
71 | ]
72 | },
73 | "execution_count": 5,
74 | "metadata": {},
75 | "output_type": "execute_result"
76 | }
77 | ],
78 | "source": [
79 | "x, y = torch.from_numpy(X).float(), torch.from_numpy(Y).long()\n",
80 | "x.shape, y.shape"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "## 2. Define Model"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 6,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "model = nn.Sequential(\n",
97 | " bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=4, out_features=100),\n",
98 | " nn.ReLU(),\n",
99 | " bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=100, out_features=3),\n",
100 | ")"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 7,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "ce_loss = nn.CrossEntropyLoss()\n",
110 | "kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)\n",
111 | "kl_weight = 0.01\n",
112 | "\n",
113 | "optimizer = optim.Adam(model.parameters(), lr=0.01)"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | "## 3. Train Model"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 8,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "kl_weight = 0.1"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 9,
135 | "metadata": {},
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "- Accuracy: 96.666667 %\n",
142 | "- CE : 0.11, KL : 1.31\n"
143 | ]
144 | }
145 | ],
146 | "source": [
147 | "for step in range(3000):\n",
148 | " pre = model(x)\n",
149 | " ce = ce_loss(pre, y)\n",
150 | " kl = kl_loss(model)\n",
151 | " cost = ce + kl_weight*kl\n",
152 | " \n",
153 | " optimizer.zero_grad()\n",
154 | " cost.backward()\n",
155 | " optimizer.step()\n",
156 | " \n",
157 | "_, predicted = torch.max(pre.data, 1)\n",
158 | "total = y.size(0)\n",
159 | "correct = (predicted == y).sum()\n",
160 | "print('- Accuracy: %f %%' % (100 * float(correct) / total))\n",
161 | "print('- CE : %2.2f, KL : %2.2f' % (ce.item(), kl.item()))"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "## 4. Test Model"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 10,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "def draw_plot(predicted) :\n",
178 | " fig = plt.figure(figsize = (16, 5))\n",
179 | "\n",
180 | " ax1 = fig.add_subplot(1, 2, 1)\n",
181 | " ax2 = fig.add_subplot(1, 2, 2)\n",
182 | "\n",
183 | " z1_plot = ax1.scatter(X[:, 0], X[:, 1], c = Y)\n",
184 | " z2_plot = ax2.scatter(X[:, 0], X[:, 1], c = predicted)\n",
185 | "\n",
186 | " plt.colorbar(z1_plot,ax=ax1)\n",
187 | " plt.colorbar(z2_plot,ax=ax2)\n",
188 | "\n",
189 | " ax1.set_title(\"REAL\")\n",
190 | " ax2.set_title(\"PREDICT\")\n",
191 | "\n",
192 | " plt.show()"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 11,
198 | "metadata": {},
199 | "outputs": [
200 | {
201 | "data": {
202 | "image/png": "\n",
203 | "text/plain": [
204 | ""
205 | ]
206 | },
207 | "metadata": {},
208 | "output_type": "display_data"
209 | }
210 | ],
211 | "source": [
212 | "pre = model(x)\n",
213 | "_, predicted = torch.max(pre.data, 1)\n",
214 | "draw_plot(predicted)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 12,
220 | "metadata": {},
221 | "outputs": [
222 | {
223 | "data": {
224 | "image/png": "\n",
225 | "text/plain": [
226 | ""
227 | ]
228 | },
229 | "metadata": {},
230 | "output_type": "display_data"
231 | }
232 | ],
233 | "source": [
234 | "# Bayesian Neural Network will return different outputs even if inputs are same.\n",
235 | "# In other words, different plots will be shown every time forward method is called.\n",
236 | "pre = model(x)\n",
237 | "_, predicted = torch.max(pre.data, 1)\n",
238 | "draw_plot(predicted)"
239 | ]
240 | }
241 | ],
242 | "metadata": {
243 | "kernelspec": {
244 | "display_name": "Python 3",
245 | "language": "python",
246 | "name": "python3"
247 | },
248 | "language_info": {
249 | "codemirror_mode": {
250 | "name": "ipython",
251 | "version": 3
252 | },
253 | "file_extension": ".py",
254 | "mimetype": "text/x-python",
255 | "name": "python",
256 | "nbconvert_exporter": "python",
257 | "pygments_lexer": "ipython3",
258 | "version": "3.6.5"
259 | }
260 | },
261 | "nbformat": 4,
262 | "nbformat_minor": 2
263 | }
264 |
--------------------------------------------------------------------------------
/demos/Bayesian Neural Network Regression.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Demo - Bayesian Neural Network Regression"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "from sklearn import datasets\n",
18 | "\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "import torch.optim as optim\n",
22 | "\n",
23 | "import torchbnn as bnn"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import matplotlib.pyplot as plt\n",
33 | "%matplotlib inline"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## 1. Generate Sample Data"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [
48 | {
49 | "data": {
50 | "image/png": "\n",
51 | "text/plain": [
52 | ""
53 | ]
54 | },
55 | "metadata": {},
56 | "output_type": "display_data"
57 | }
58 | ],
59 | "source": [
60 | "x = torch.linspace(-2, 2, 500)\n",
61 | "y = x.pow(3) - x.pow(2) + 3*torch.rand(x.size())\n",
62 | "x = torch.unsqueeze(x, dim=1)\n",
63 | "y = torch.unsqueeze(y, dim=1)\n",
64 | "\n",
65 | "plt.scatter(x.data.numpy(), y.data.numpy())\n",
66 | "plt.show()"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "## 2. Define Model"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 4,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "model = nn.Sequential(\n",
83 | " bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=1, out_features=100),\n",
84 | " nn.ReLU(),\n",
85 | " bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=100, out_features=1),\n",
86 | ")"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 5,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "mse_loss = nn.MSELoss()\n",
96 | "kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)\n",
97 | "kl_weight = 0.01\n",
98 | "\n",
99 | "optimizer = optim.Adam(model.parameters(), lr=0.01)"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "## 3. Train Model"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 6,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "kl_weight = 0.1"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 7,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "name": "stdout",
125 | "output_type": "stream",
126 | "text": [
127 | "- MSE : 0.89, KL : 5.19\n"
128 | ]
129 | }
130 | ],
131 | "source": [
132 | "for step in range(3000):\n",
133 | " pre = model(x)\n",
134 | " mse = mse_loss(pre, y)\n",
135 | " kl = kl_loss(model)\n",
136 | " cost = mse + kl_weight*kl\n",
137 | " \n",
138 | " optimizer.zero_grad()\n",
139 | " cost.backward()\n",
140 | " optimizer.step()\n",
141 | " \n",
142 | "print('- MSE : %2.2f, KL : %2.2f' % (mse.item(), kl.item()))"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {},
148 | "source": [
149 | "## 4. Test Model"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 8,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "x_test = torch.linspace(-2, 2, 500)\n",
159 | "y_test = x_test.pow(3) - x_test.pow(2) + 3*torch.rand(x_test.size())\n",
160 | "\n",
161 | "x_test = torch.unsqueeze(x_test, dim=1)\n",
162 | "y_test = torch.unsqueeze(y_test, dim=1)"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 9,
168 | "metadata": {},
169 | "outputs": [
170 | {
171 | "data": {
172 | "image/png": "\n",
173 | "text/plain": [
174 | ""
175 | ]
176 | },
177 | "metadata": {},
178 | "output_type": "display_data"
179 | }
180 | ],
181 | "source": [
182 | "plt.xlabel(r'$x$')\n",
183 | "plt.ylabel(r'$y$')\n",
184 | "\n",
185 | "plt.scatter(x_test.data.numpy(), y_test.data.numpy(), color='k', s=2) \n",
186 | "\n",
187 | "y_predict = model(x_test)\n",
188 | "plt.plot(x_test.data.numpy(), y_predict.data.numpy(), 'r-', linewidth=5, label='First Prediction')\n",
189 | "\n",
190 | "y_predict = model(x_test)\n",
191 | "plt.plot(x_test.data.numpy(), y_predict.data.numpy(), 'b-', linewidth=5, label='Second Prediction')\n",
192 | "\n",
193 | "y_predict = model(x_test)\n",
194 | "plt.plot(x_test.data.numpy(), y_predict.data.numpy(), 'g-', linewidth=5, label='Third Prediction')\n",
195 | "\n",
196 | "plt.legend()\n",
197 | "\n",
198 | "plt.show()"
199 | ]
200 | }
201 | ],
202 | "metadata": {
203 | "kernelspec": {
204 | "display_name": "Python 3",
205 | "language": "python",
206 | "name": "python3"
207 | },
208 | "language_info": {
209 | "codemirror_mode": {
210 | "name": "ipython",
211 | "version": 3
212 | },
213 | "file_extension": ".py",
214 | "mimetype": "text/x-python",
215 | "name": "python",
216 | "nbconvert_exporter": "python",
217 | "pygments_lexer": "ipython3",
218 | "version": "3.6.5"
219 | }
220 | },
221 | "nbformat": 4,
222 | "nbformat_minor": 2
223 | }
224 |
--------------------------------------------------------------------------------
/demos/Convert to Bayesian Neural Network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Demo - Convert to Bayesian Neural Network"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "\n",
18 | "import torch\n",
19 | "import torch.nn as nn\n",
20 | "import torch.optim as optim\n",
21 | "\n",
22 | "import torchbnn as bnn\n",
23 | "from torchhk import transform_model"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import matplotlib.pyplot as plt\n",
33 | "%matplotlib inline"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## 2. Define Model"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "class CNN(nn.Module):\n",
50 | " def __init__(self):\n",
51 | " super(CNN, self).__init__()\n",
52 | " \n",
53 | " self.conv_layer = nn.Sequential(\n",
54 | " nn.Conv2d(1,3,3),\n",
55 | " nn.ReLU(),\n",
56 | " nn.MaxPool2d(2,2)\n",
57 | " )\n",
58 | " \n",
59 | " self.fc_layer = nn.Sequential(\n",
60 | " nn.Linear(3*2*2,3*2),\n",
61 | " nn.ReLU(),\n",
62 | " nn.Linear(3*2,2)\n",
63 | " ) \n",
64 | " \n",
65 | " def forward(self,x):\n",
66 | " out = self.conv_layer(x)\n",
67 | " out = out.view(-1,3*2*2)\n",
68 | " out = self.fc_layer(out)\n",
69 | "\n",
70 | " return out"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 4,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "model = CNN()"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 5,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "data": {
89 | "text/plain": [
90 | "CNN(\n",
91 | " (conv_layer): Sequential(\n",
92 | " (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
93 | " (1): ReLU()\n",
94 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
95 | " )\n",
96 | " (fc_layer): Sequential(\n",
97 | " (0): Linear(in_features=12, out_features=6, bias=True)\n",
98 | " (1): ReLU()\n",
99 | " (2): Linear(in_features=6, out_features=2, bias=True)\n",
100 | " )\n",
101 | ")"
102 | ]
103 | },
104 | "execution_count": 5,
105 | "metadata": {},
106 | "output_type": "execute_result"
107 | }
108 | ],
109 | "source": [
110 | "model"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "## 3. Convert Model"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "### 3.1. Nonbayes to Bayes"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 6,
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "name": "stderr",
134 | "output_type": "stream",
135 | "text": [
136 | "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
137 | " * Caution : The Input Model is CHANGED because inplace=True.\n",
138 | " warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
139 | ]
140 | },
141 | {
142 | "data": {
143 | "text/plain": [
144 | "CNN(\n",
145 | " (conv_layer): Sequential(\n",
146 | " (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
147 | " (1): ReLU()\n",
148 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
149 | " )\n",
150 | " (fc_layer): Sequential(\n",
151 | " (0): Linear(in_features=12, out_features=6, bias=True)\n",
152 | " (1): ReLU()\n",
153 | " (2): Linear(in_features=6, out_features=2, bias=True)\n",
154 | " )\n",
155 | ")"
156 | ]
157 | },
158 | "execution_count": 6,
159 | "metadata": {},
160 | "output_type": "execute_result"
161 | }
162 | ],
163 | "source": [
164 | "# Convert Conv2d -> BayesConv2d\n",
165 | "transform_model(model, nn.Conv2d, bnn.BayesConv2d, \n",
166 | " args={\"prior_mu\":0, \"prior_sigma\":0.1, \"in_channels\" : \".in_channels\",\n",
167 | " \"out_channels\" : \".out_channels\", \"kernel_size\" : \".kernel_size\",\n",
168 | " \"stride\" : \".stride\", \"padding\" : \".padding\", \"bias\":\".bias\"\n",
169 | " }, \n",
170 | " attrs={\"weight_mu\" : \".weight\"})"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 7,
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "name": "stderr",
180 | "output_type": "stream",
181 | "text": [
182 | "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
183 | " * Caution : The Input Model is CHANGED because inplace=True.\n",
184 | " warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
185 | ]
186 | },
187 | {
188 | "data": {
189 | "text/plain": [
190 | "CNN(\n",
191 | " (conv_layer): Sequential(\n",
192 | " (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
193 | " (1): ReLU()\n",
194 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
195 | " )\n",
196 | " (fc_layer): Sequential(\n",
197 | " (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)\n",
198 | " (1): ReLU()\n",
199 | " (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)\n",
200 | " )\n",
201 | ")"
202 | ]
203 | },
204 | "execution_count": 7,
205 | "metadata": {},
206 | "output_type": "execute_result"
207 | }
208 | ],
209 | "source": [
210 | "# Convert Linear -> BayesLinear\n",
211 | "transform_model(model, nn.Linear, bnn.BayesLinear, \n",
212 | " args={\"prior_mu\":0, \"prior_sigma\":0.1, \"in_features\" : \".in_features\",\n",
213 | " \"out_features\" : \".out_features\", \"bias\":\".bias\"\n",
214 | " }, \n",
215 | " attrs={\"weight_mu\" : \".weight\"})"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 8,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/plain": [
226 | "CNN(\n",
227 | " (conv_layer): Sequential(\n",
228 | " (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
229 | " (1): ReLU()\n",
230 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
231 | " )\n",
232 | " (fc_layer): Sequential(\n",
233 | " (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)\n",
234 | " (1): ReLU()\n",
235 | " (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)\n",
236 | " )\n",
237 | ")"
238 | ]
239 | },
240 | "execution_count": 8,
241 | "metadata": {},
242 | "output_type": "execute_result"
243 | }
244 | ],
245 | "source": [
246 | "model"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | "### 3.2. Bayes to Nonbayes"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 9,
259 | "metadata": {},
260 | "outputs": [
261 | {
262 | "name": "stderr",
263 | "output_type": "stream",
264 | "text": [
265 | "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
266 | " * Caution : The Input Model is CHANGED because inplace=True.\n",
267 | " warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
268 | ]
269 | },
270 | {
271 | "data": {
272 | "text/plain": [
273 | "CNN(\n",
274 | " (conv_layer): Sequential(\n",
275 | " (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
276 | " (1): ReLU()\n",
277 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
278 | " )\n",
279 | " (fc_layer): Sequential(\n",
280 | " (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)\n",
281 | " (1): ReLU()\n",
282 | " (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)\n",
283 | " )\n",
284 | ")"
285 | ]
286 | },
287 | "execution_count": 9,
288 | "metadata": {},
289 | "output_type": "execute_result"
290 | }
291 | ],
292 | "source": [
293 | "# Convert BayesConv2d -> Conv2d\n",
294 | "transform_model(model, bnn.BayesConv2d, nn.Conv2d,\n",
295 | " args={\"in_channels\" : \".in_channels\", \"out_channels\" : \".out_channels\",\n",
296 | " \"kernel_size\" : \".kernel_size\", \"stride\" : \".stride\",\n",
297 | " \"padding\" : \".padding\", \"bias\":\".bias\"\n",
298 | " }, \n",
299 | " attrs={\"weight\" : \".weight_mu\"})"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": 10,
305 | "metadata": {},
306 | "outputs": [
307 | {
308 | "name": "stderr",
309 | "output_type": "stream",
310 | "text": [
311 | "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
312 | " * Caution : The Input Model is CHANGED because inplace=True.\n",
313 | " warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
314 | ]
315 | },
316 | {
317 | "data": {
318 | "text/plain": [
319 | "CNN(\n",
320 | " (conv_layer): Sequential(\n",
321 | " (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
322 | " (1): ReLU()\n",
323 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
324 | " )\n",
325 | " (fc_layer): Sequential(\n",
326 | " (0): Linear(in_features=12, out_features=6, bias=True)\n",
327 | " (1): ReLU()\n",
328 | " (2): Linear(in_features=6, out_features=2, bias=True)\n",
329 | " )\n",
330 | ")"
331 | ]
332 | },
333 | "execution_count": 10,
334 | "metadata": {},
335 | "output_type": "execute_result"
336 | }
337 | ],
338 | "source": [
339 | "# Convert BayesLinear -> Linear\n",
340 | "transform_model(model, bnn.BayesLinear, nn.Linear, \n",
341 | " args={\"in_features\" : \".in_features\", \"out_features\" : \".out_features\",\n",
342 | " \"bias\":\".bias\"\n",
343 | " }, \n",
344 | " attrs={\"weight\" : \".weight_mu\"})"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": 11,
350 | "metadata": {},
351 | "outputs": [
352 | {
353 | "data": {
354 | "text/plain": [
355 | "CNN(\n",
356 | " (conv_layer): Sequential(\n",
357 | " (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
358 | " (1): ReLU()\n",
359 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
360 | " )\n",
361 | " (fc_layer): Sequential(\n",
362 | " (0): Linear(in_features=12, out_features=6, bias=True)\n",
363 | " (1): ReLU()\n",
364 | " (2): Linear(in_features=6, out_features=2, bias=True)\n",
365 | " )\n",
366 | ")"
367 | ]
368 | },
369 | "execution_count": 11,
370 | "metadata": {},
371 | "output_type": "execute_result"
372 | }
373 | ],
374 | "source": [
375 | "model"
376 | ]
377 | }
378 | ],
379 | "metadata": {
380 | "kernelspec": {
381 | "display_name": "Python 3",
382 | "language": "python",
383 | "name": "python3"
384 | },
385 | "language_info": {
386 | "codemirror_mode": {
387 | "name": "ipython",
388 | "version": 3
389 | },
390 | "file_extension": ".py",
391 | "mimetype": "text/x-python",
392 | "name": "python",
393 | "nbconvert_exporter": "python",
394 | "pygments_lexer": "ipython3",
395 | "version": "3.6.5"
396 | }
397 | },
398 | "nbformat": 4,
399 | "nbformat_minor": 2
400 | }
401 |
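The conversion cells above keep only the posterior means. `weight` is mapped to `.weight_mu`, while `bias` only receives the boolean `.bias` flag. Unless `transform_model` copies the bias elsewhere, the converted layers' biases would therefore be freshly initialized rather than set to `bias_mu`. Replacing sampled weights with their means is exact for a single affine layer, because E[Wx + b] = mu_W x + mu_b. After a nonlinearity such as ReLU, the mean-weight network is only an approximation of the predictive mean.

The stdlib-only sketch below checks the single-layer identity by Monte Carlo. `sample_linear` and `mean_linear` are hypothetical helpers, not torchbnn or torchhk APIs. Weights are sampled as mu + exp(log_sigma) * eps, which is the reparameterization the Bayes modules in this repository use.

```python
import math
import random

def sample_linear(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng):
    """One stochastic forward pass: W = mu + exp(log_sigma) * eps, eps ~ N(0, 1)."""
    out = []
    for row_mu, row_ls, bm, bls in zip(w_mu, w_log_sigma, b_mu, b_log_sigma):
        acc = bm + math.exp(bls) * rng.gauss(0.0, 1.0)          # sampled bias
        for xi, m, ls in zip(x, row_mu, row_ls):
            acc += (m + math.exp(ls) * rng.gauss(0.0, 1.0)) * xi  # sampled weight
        out.append(acc)
    return out

def mean_linear(x, w_mu, b_mu):
    """Deterministic forward pass that uses only the posterior means (weights and bias)."""
    return [bm + sum(m * xi for m, xi in zip(row_mu, x)) for row_mu, bm in zip(w_mu, b_mu)]

# A hypothetical 2 -> 1 layer with sigma = 0.05, the prior_sigma used in the demos.
w_mu = [[0.3, -0.7]]
w_log_sigma = [[math.log(0.05)] * 2]
b_mu = [0.1]
b_log_sigma = [math.log(0.05)]
x = [1.0, 1.0]

rng = random.Random(0)
n_samples = 20000
mc = [0.0]
for _ in range(n_samples):
    y = sample_linear(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng)
    mc = [a + yi / n_samples for a, yi in zip(mc, y)]  # running Monte Carlo mean

det = mean_linear(x, w_mu, b_mu)
print(f"mean-weight output: {det[0]:.4f}, Monte Carlo mean: {mc[0]:.4f}")
```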
--------------------------------------------------------------------------------
/demos/Freeze Bayesian Neural Network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Demo - Freeze Bayesian Neural Network"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "\n",
18 | "import torch\n",
19 | "import torch.nn as nn\n",
20 | "import torch.optim as optim\n",
21 | "\n",
22 | "import torchbnn as bnn\n",
23 | "from torchbnn.utils import freeze, unfreeze"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import matplotlib.pyplot as plt\n",
33 | "%matplotlib inline"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## 1. Define Model"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "model = nn.Sequential(\n",
50 | " bnn.BayesLinear(prior_mu=0, prior_sigma=0.05, in_features=2, out_features=2),\n",
51 | " nn.ReLU(),\n",
52 | " bnn.BayesLinear(prior_mu=0, prior_sigma=0.05, in_features=2, out_features=1),\n",
53 | ")"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "## 2. Forward Model"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "data": {
70 | "text/plain": [
71 | "tensor([[-0.4672]], grad_fn=)"
72 | ]
73 | },
74 | "execution_count": 4,
75 | "metadata": {},
76 | "output_type": "execute_result"
77 | }
78 | ],
79 | "source": [
80 | "model(torch.ones(1, 2))"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 5,
86 | "metadata": {},
87 | "outputs": [
88 | {
89 | "data": {
90 | "text/plain": [
91 | "tensor([[-0.3220]], grad_fn=)"
92 | ]
93 | },
94 | "execution_count": 5,
95 | "metadata": {},
96 | "output_type": "execute_result"
97 | }
98 | ],
99 | "source": [
100 | "model(torch.ones(1, 2))"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 | "## 3. Freeze Model"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 6,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "freeze(model)"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 7,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "data": {
126 | "text/plain": [
127 | "tensor([[-0.4340]], grad_fn=)"
128 | ]
129 | },
130 | "execution_count": 7,
131 | "metadata": {},
132 | "output_type": "execute_result"
133 | }
134 | ],
135 | "source": [
136 | "model(torch.ones(1, 2))"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 8,
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "data": {
146 | "text/plain": [
147 | "tensor([[-0.4340]], grad_fn=)"
148 | ]
149 | },
150 | "execution_count": 8,
151 | "metadata": {},
152 | "output_type": "execute_result"
153 | }
154 | ],
155 | "source": [
156 | "model(torch.ones(1, 2))"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 9,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "freeze(model)"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 10,
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": [
176 | "tensor([[-0.2875]], grad_fn=)"
177 | ]
178 | },
179 | "execution_count": 10,
180 | "metadata": {},
181 | "output_type": "execute_result"
182 | }
183 | ],
184 | "source": [
185 | "model(torch.ones(1, 2))"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 11,
191 | "metadata": {},
192 | "outputs": [
193 | {
194 | "data": {
195 | "text/plain": [
196 | "tensor([[-0.2875]], grad_fn=)"
197 | ]
198 | },
199 | "execution_count": 11,
200 | "metadata": {},
201 | "output_type": "execute_result"
202 | }
203 | ],
204 | "source": [
205 | "model(torch.ones(1, 2))"
206 | ]
207 | },
208 | {
209 | "cell_type": "markdown",
210 | "metadata": {},
211 | "source": [
212 | "## 4. Unfreeze Model"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 12,
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "unfreeze(model)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 13,
227 | "metadata": {},
228 | "outputs": [
229 | {
230 | "data": {
231 | "text/plain": [
232 | "tensor([[-0.4530]], grad_fn=)"
233 | ]
234 | },
235 | "execution_count": 13,
236 | "metadata": {},
237 | "output_type": "execute_result"
238 | }
239 | ],
240 | "source": [
241 | "model(torch.ones(1, 2))"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 14,
247 | "metadata": {},
248 | "outputs": [
249 | {
250 | "data": {
251 | "text/plain": [
252 | "tensor([[-0.4920]], grad_fn=)"
253 | ]
254 | },
255 | "execution_count": 14,
256 | "metadata": {},
257 | "output_type": "execute_result"
258 | }
259 | ],
260 | "source": [
261 | "model(torch.ones(1, 2))"
262 | ]
263 | }
264 | ],
265 | "metadata": {
266 | "kernelspec": {
267 | "display_name": "Python 3",
268 | "language": "python",
269 | "name": "python3"
270 | },
271 | "language_info": {
272 | "codemirror_mode": {
273 | "name": "ipython",
274 | "version": 3
275 | },
276 | "file_extension": ".py",
277 | "mimetype": "text/x-python",
278 | "name": "python",
279 | "nbconvert_exporter": "python",
280 | "pygments_lexer": "ipython3",
281 | "version": "3.6.5"
282 | }
283 | },
284 | "nbformat": 4,
285 | "nbformat_minor": 2
286 | }
287 |
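The cells above show three behaviors. After `freeze(model)`, repeated forward passes return the same value. A second `freeze(model)` fixes a new sample. `unfreeze(model)` returns to sampling on every call. In the layer code below, `freeze()` stores `weight_eps = torch.randn_like(weight_log_sigma)`, and the forward pass computes `mu + exp(log_sigma) * eps` with that cached noise. When the cache is `None`, fresh noise is drawn instead.

Here is a minimal stdlib sketch of that pattern for a single scalar weight. `FreezableGaussianWeight` is a hypothetical class, not part of torchbnn.

```python
import math
import random

class FreezableGaussianWeight:
    """Hypothetical scalar weight mirroring the freeze()/unfreeze() pattern
    of the Bayes layers: w = mu + exp(log_sigma) * eps."""

    def __init__(self, mu, sigma, seed=0):
        self.mu = mu
        self.log_sigma = math.log(sigma)
        self.eps = None                   # None means "draw new noise on every call"
        self._rng = random.Random(seed)

    def freeze(self):
        # Draw one noise sample and reuse it until unfreeze() or the next freeze().
        self.eps = self._rng.gauss(0.0, 1.0)

    def unfreeze(self):
        self.eps = None

    def sample(self):
        eps = self._rng.gauss(0.0, 1.0) if self.eps is None else self.eps
        return self.mu + math.exp(self.log_sigma) * eps

w = FreezableGaussianWeight(mu=0.0, sigma=0.05)
unfrozen = [w.sample(), w.sample()]   # two independent draws
w.freeze()
frozen_a = [w.sample(), w.sample()]   # identical values while frozen
w.freeze()
frozen_b = w.sample()                 # a new fixed sample after re-freezing
w.unfreeze()
after = [w.sample(), w.sample()]      # independent draws again
```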
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = torchbnn
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file only contains a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | import os
16 | import sys
17 | sys.path.insert(0, os.path.abspath('..'))
18 |
19 |
20 | # -- Project information -----------------------------------------------------
21 |
22 | project = 'torchbnn'
23 | copyright = '2020, harrykim'
24 | author = 'harrykim'
25 |
26 | # The short X.Y version
27 | version = 'v1.1'
28 | # The full version, including alpha/beta/rc tags
29 | release = 'v1.1'
30 |
31 |
32 | # -- General configuration ---------------------------------------------------
33 |
34 | # If your documentation needs a minimal Sphinx version, state it here.
35 | #
36 | # needs_sphinx = '1.0'
37 |
38 | # Add any Sphinx extension module names here, as strings. They can be
39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
40 | # ones.
41 | extensions = [
42 | 'sphinx.ext.napoleon',
43 | 'sphinx.ext.autodoc',
44 | 'sphinx.ext.todo',
45 | 'sphinx.ext.mathjax',
46 | 'sphinx.ext.viewcode',
47 | 'sphinx.ext.githubpages',
48 | ]
49 |
50 | # Add any paths that contain templates here, relative to this directory.
51 | templates_path = ['_templates']
52 |
53 | # The suffix(es) of source filenames.
54 | # You can specify multiple suffix as a list of string:
55 | #
56 | # source_suffix = ['.rst', '.md']
57 | source_suffix = '.rst'
58 |
59 | # The master toctree document.
60 | master_doc = 'index'
61 |
62 | # The language for content autogenerated by Sphinx. Refer to documentation
63 | # for a list of supported languages.
64 | #
65 | # This is also used if you do content translation via gettext catalogs.
66 | # Usually you set "language" from the command line for these cases.
67 | language = None
68 |
69 | # List of patterns, relative to source directory, that match files and
70 | # directories to ignore when looking for source files.
71 | # This pattern also affects html_static_path and html_extra_path .
72 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
73 |
74 | # The name of the Pygments (syntax highlighting) style to use.
75 | pygments_style = 'sphinx'
76 |
77 |
78 | # -- Options for HTML output -------------------------------------------------
79 |
80 | # The theme to use for HTML and HTML Help pages. See the documentation for
81 | # a list of builtin themes.
82 | #
83 | html_theme = 'sphinx_rtd_theme'
84 |
85 | # Theme options are theme-specific and customize the look and feel of a theme
86 | # further. For a list of options available for each theme, see the
87 | # documentation.
88 | #
89 | # html_theme_options = {}
90 |
91 | # Add any paths that contain custom static files (such as style sheets) here,
92 | # relative to this directory. They are copied after the builtin static files,
93 | # so a file named "default.css" will overwrite the builtin "default.css".
94 | html_static_path = ['_static']
95 |
96 | # Custom sidebar templates, must be a dictionary that maps document names
97 | # to template names.
98 | #
99 | # The default sidebars (for documents that don't match any pattern) are
100 | # defined by theme itself. Builtin themes are using these templates by
101 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
102 | # 'searchbox.html']``.
103 | #
104 | # html_sidebars = {}
105 |
106 |
107 | # -- Options for HTMLHelp output ---------------------------------------------
108 |
109 | # Output file base name for HTML help builder.
110 | htmlhelp_basename = 'torchbnndoc'
111 |
112 |
113 | # -- Options for LaTeX output ------------------------------------------------
114 |
115 | latex_elements = {
116 | # The paper size ('letterpaper' or 'a4paper').
117 | #
118 | # 'papersize': 'letterpaper',
119 |
120 | # The font size ('10pt', '11pt' or '12pt').
121 | #
122 | # 'pointsize': '10pt',
123 |
124 | # Additional stuff for the LaTeX preamble.
125 | #
126 | # 'preamble': '',
127 |
128 | # Latex figure (float) alignment
129 | #
130 | # 'figure_align': 'htbp',
131 | }
132 |
133 | # Grouping the document tree into LaTeX files. List of tuples
134 | # (source start file, target name, title,
135 | # author, documentclass [howto, manual, or own class]).
136 | latex_documents = [
137 | (master_doc, 'torchbnn.tex', 'torchbnn Documentation',
138 | 'harrykim', 'manual'),
139 | ]
140 |
141 |
142 | # -- Options for manual page output ------------------------------------------
143 |
144 | # One entry per manual page. List of tuples
145 | # (source start file, name, description, authors, manual section).
146 | man_pages = [
147 | (master_doc, 'torchbnn', 'torchbnn Documentation',
148 | [author], 1)
149 | ]
150 |
151 |
152 | # -- Options for Texinfo output ----------------------------------------------
153 |
154 | # Grouping the document tree into Texinfo files. List of tuples
155 | # (source start file, target name, title, author,
156 | # dir menu entry, description, category)
157 | texinfo_documents = [
158 | (master_doc, 'torchbnn', 'torchbnn Documentation',
159 | author, 'torchbnn', 'Bayesian neural network library for PyTorch.',
160 | 'Miscellaneous'),
161 | ]
162 |
163 |
164 | # -- Extension configuration -------------------------------------------------
165 |
166 | # -- Options for todo extension ----------------------------------------------
167 |
168 | # If true, `todo` and `todoList` produce output, else they produce nothing.
169 | todo_include_todos = True
--------------------------------------------------------------------------------
/docs/functional.rst:
--------------------------------------------------------------------------------
1 | .. torchattacks documentation master file, created by
2 | sphinx-quickstart on Sun Apr 12 23:38:13 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Functional
7 | ========================================
8 |
9 | Bayesian KL Loss
10 | ~~~~~~~~~~~~~~~~~~~~~
11 | .. automodule:: torchbnn.functional
12 | :members:
13 | :undoc-members:
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. torchbnn documentation master file, created by
2 | sphinx-quickstart on Tue Apr 21 13:12:54 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | torchbnn v1.1
7 | ====================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | modules
14 | utils
15 | functional
16 |
17 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=torchbnn
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | .. torchattacks documentation master file, created by
2 | sphinx-quickstart on Sun Apr 12 23:38:13 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Modules
7 | ========================================
8 |
9 | Bayes Module
10 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
11 | .. automodule:: torchbnn.modules.module
12 | :members:
13 |
14 | Bayes Linear
15 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
16 | .. automodule:: torchbnn.modules.linear
17 | :members:
18 |
19 | Bayes Conv
20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21 | .. automodule:: torchbnn.modules.conv
22 | :members:
23 |
24 | Bayes Batchnorm
25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26 | .. automodule:: torchbnn.modules.batchnorm
27 | :members:
28 |
29 | BKLLoss
30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31 | .. automodule:: torchbnn.modules.loss
32 | :members:
--------------------------------------------------------------------------------
/docs/utils.rst:
--------------------------------------------------------------------------------
1 | .. torchattacks documentation master file, created by
2 | sphinx-quickstart on Sun Apr 12 23:38:13 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Utils
7 | ========================================
8 |
9 | Freeze Model
10 | ~~~~~~~~~~~~~~~~~~~~~
11 | .. automodule:: torchbnn.utils.freeze_model
12 | :members:
13 | :undoc-members:
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.2.0
--------------------------------------------------------------------------------
/torchbnn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from . import utils
3 |
4 | __version__ = '1.2'
--------------------------------------------------------------------------------
/torchbnn/functional.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from .modules import *
5 |
6 | def _kl_loss(mu_0, log_sigma_0, mu_1, log_sigma_1) :
7 | """
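8 | A method for calculating the KL divergence KL(N(mu_0, sigma_0^2) || N(mu_1, sigma_1^2)) between two normal distributions, summed over all elements.
9 |
10 | Arguments:
11 | mu_0 (Tensor): mean of the first (posterior) normal distribution.
12 | log_sigma_0 (Tensor): log of the standard deviation of the first (posterior) normal distribution.
13 | mu_1 (Float): mean of the second (prior) normal distribution.
14 | log_sigma_1 (Float): log of the standard deviation of the second (prior) normal distribution. Must be a Python float, since it is passed to ``math.exp``.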
15 |
16 | """
17 | kl = log_sigma_1 - log_sigma_0 + \
18 | (torch.exp(log_sigma_0)**2 + (mu_0-mu_1)**2)/(2*math.exp(log_sigma_1)**2) - 0.5
19 | return kl.sum()
20 |
21 | def bayesian_kl_loss(model, reduction='mean', last_layer_only=False) :
22 | """
23 | A method for calculating the KL divergence between posterior and prior over every supported Bayesian layer (BayesLinear, BayesConv2d, BayesBatchNorm2d) in the model.
24 |
25 |
26 | Arguments:
27 | model (nn.Module): the model whose KL divergence is computed.
28 | reduction (string, optional): Specifies the reduction to apply to the output:
29 | ``'mean'``: the summed KL divergence is divided by the total number of
30 | Bayesian parameters (weight and bias means).
31 | ``'sum'``: the summed KL divergence is returned.
32 | last_layer_only (Bool): if True, return only the last KL term computed while iterating over ``model.modules()``; for a layer with a bias, this is the bias term alone.
33 |
34 | """
35 | device = torch.device("cuda" if next(model.parameters()).is_cuda else "cpu")
36 | kl = torch.Tensor([0]).to(device)
37 | kl_sum = torch.Tensor([0]).to(device)
38 | n = torch.Tensor([0]).to(device)
39 |
40 | for m in model.modules() :
41 | if isinstance(m, (BayesLinear, BayesConv2d)):
42 | kl = _kl_loss(m.weight_mu, m.weight_log_sigma, m.prior_mu, m.prior_log_sigma)
43 | kl_sum += kl
44 | n += len(m.weight_mu.view(-1))
45 |
46 | if m.bias :
47 | kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma)
48 | kl_sum += kl
49 | n += len(m.bias_mu.view(-1))
50 |
51 | if isinstance(m, BayesBatchNorm2d):
52 | if m.affine :
53 | kl = _kl_loss(m.weight_mu, m.weight_log_sigma, m.prior_mu, m.prior_log_sigma)
54 | kl_sum += kl
55 | n += len(m.weight_mu.view(-1))
56 |
57 | kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma)
58 | kl_sum += kl
59 | n += len(m.bias_mu.view(-1))
60 |
61 | if last_layer_only or n == 0 :
62 | return kl
63 |
64 | if reduction == 'mean' :
65 | return kl_sum/n
66 | elif reduction == 'sum' :
67 | return kl_sum
68 | else :
69 | raise ValueError(reduction + " is not valid")
70 |
71 |
72 |
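`_kl_loss` evaluates the closed-form KL divergence between N(mu_0, sigma_0^2) and N(mu_1, sigma_1^2), namely log(sigma_1 / sigma_0) + (sigma_0^2 + (mu_0 - mu_1)^2) / (2 sigma_1^2) - 1/2, summed over elements. `bayesian_kl_loss` with `reduction='mean'` then divides the summed KL by the number of Bayesian parameters.

The stdlib-only sketch below does three things. It reproduces the formula for scalars, compares it with a Monte Carlo estimate of E_q[log q(w) - log p(w)], and mirrors the two reductions. The helper names are hypothetical, and the posterior and prior values are arbitrary illustration values.

```python
import math
import random

def kl_normal(mu_0, log_sigma_0, mu_1, log_sigma_1):
    """Closed-form KL(N(mu_0, s_0^2) || N(mu_1, s_1^2)) for scalars,
    written term by term like _kl_loss above."""
    return (log_sigma_1 - log_sigma_0
            + (math.exp(log_sigma_0) ** 2 + (mu_0 - mu_1) ** 2)
            / (2 * math.exp(log_sigma_1) ** 2)
            - 0.5)

def log_normal_pdf(x, mu, log_sigma):
    sigma = math.exp(log_sigma)
    return -0.5 * math.log(2 * math.pi) - log_sigma - (x - mu) ** 2 / (2 * sigma ** 2)

def kl_monte_carlo(mu_0, log_sigma_0, mu_1, log_sigma_1, n, rng):
    """Estimate E_{w ~ q}[log q(w) - log p(w)] with w = mu_0 + sigma_0 * eps."""
    total = 0.0
    for _ in range(n):
        w = mu_0 + math.exp(log_sigma_0) * rng.gauss(0.0, 1.0)
        total += log_normal_pdf(w, mu_0, log_sigma_0) - log_normal_pdf(w, mu_1, log_sigma_1)
    return total / n

# Three posterior weights sharing one prior N(0, 0.1^2).
mus = [0.2, -0.1, 0.05]
log_sigmas = [math.log(0.05), math.log(0.08), math.log(0.1)]
prior_mu, prior_log_sigma = 0.0, math.log(0.1)

per_weight = [kl_normal(m, ls, prior_mu, prior_log_sigma) for m, ls in zip(mus, log_sigmas)]
kl_sum = sum(per_weight)              # analogue of reduction='sum'
kl_mean = kl_sum / len(per_weight)    # analogue of reduction='mean': divide by parameter count

rng = random.Random(0)
mc = kl_monte_carlo(mus[0], log_sigmas[0], prior_mu, prior_log_sigma, 50000, rng)
```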
--------------------------------------------------------------------------------
/torchbnn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear import BayesLinear
2 | from .conv import BayesConv2d
3 | from .batchnorm import BayesBatchNorm2d
4 | from .loss import BKLLoss
--------------------------------------------------------------------------------
/torchbnn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import Module, Parameter
5 | import torch.nn.init as init
6 | import torch.nn.functional as F
7 |
8 | class _BayesBatchNorm(Module):
9 | r"""
10 | Applies Bayesian Batch Normalization over a 2D or 3D input
11 |
12 | Arguments:
13 | prior_mu (Float): mean of prior normal distribution.
14 | prior_sigma (Float): sigma of prior normal distribution.
15 |
16 | .. note:: Other arguments follow ``torch.nn`` batch normalization in PyTorch 1.2.0.
17 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
18 |
19 | """
20 |
21 | _version = 2
22 | __constants__ = ['prior_mu', 'prior_sigma', 'track_running_stats',
23 | 'momentum', 'eps', 'weight', 'bias',
24 | 'running_mean', 'running_var', 'num_batches_tracked',
25 | 'num_features', 'affine']
26 |
27 | def __init__(self, prior_mu, prior_sigma, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
28 | super(_BayesBatchNorm, self).__init__()
29 | self.num_features = num_features
30 | self.eps = eps
31 | self.momentum = momentum
32 | self.affine = affine
33 | self.track_running_stats = track_running_stats
34 | self.prior_mu = prior_mu # stored even when affine=False, since extra_repr() formats it
35 | self.prior_sigma = prior_sigma
36 | self.prior_log_sigma = math.log(prior_sigma)
37 | if self.affine:
38 |
39 | self.weight_mu = Parameter(torch.Tensor(num_features))
40 | self.weight_log_sigma = Parameter(torch.Tensor(num_features))
41 | self.register_buffer('weight_eps', None)
42 |
43 | self.bias_mu = Parameter(torch.Tensor(num_features))
44 | self.bias_log_sigma = Parameter(torch.Tensor(num_features))
45 | self.register_buffer('bias_eps', None)
46 | else:
47 | self.register_parameter('weight_mu', None)
48 | self.register_parameter('weight_log_sigma', None)
49 | self.register_buffer('weight_eps', None)
50 | self.register_parameter('bias_mu', None)
51 | self.register_parameter('bias_log_sigma', None)
52 | self.register_buffer('bias_eps', None)
53 | if self.track_running_stats:
54 | self.register_buffer('running_mean', torch.zeros(num_features))
55 | self.register_buffer('running_var', torch.ones(num_features))
56 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
57 | else:
58 | self.register_parameter('running_mean', None)
59 | self.register_parameter('running_var', None)
60 | self.register_parameter('num_batches_tracked', None)
61 | self.reset_parameters()
62 |
63 | def reset_running_stats(self):
64 | if self.track_running_stats:
65 | self.running_mean.zero_()
66 | self.running_var.fill_(1)
67 | self.num_batches_tracked.zero_()
68 |
69 | def reset_parameters(self):
70 | self.reset_running_stats()
71 | if self.affine:
72 | # Initialization method of Adv-BNN.
73 | self.weight_mu.data.uniform_()
74 | self.weight_log_sigma.data.fill_(self.prior_log_sigma)
75 | self.bias_mu.data.zero_()
76 | self.bias_log_sigma.data.fill_(self.prior_log_sigma)
77 |
78 | # Initialization method of the original torch nn.BatchNorm.
79 | # init.ones_(self.weight_mu)
80 | # self.weight_log_sigma.data.fill_(self.prior_log_sigma)
81 | # init.zeros_(self.bias_mu)
82 | # self.bias_log_sigma.data.fill_(self.prior_log_sigma)
83 |
84 | def freeze(self) :
85 | if self.affine :
86 | self.weight_eps = torch.randn_like(self.weight_log_sigma)
87 | self.bias_eps = torch.randn_like(self.bias_log_sigma)
88 |
89 | def unfreeze(self) :
90 | if self.affine :
91 | self.weight_eps = None
92 | self.bias_eps = None
93 |
94 | def _check_input_dim(self, input):
95 | raise NotImplementedError
96 |
97 | def forward(self, input):
98 | self._check_input_dim(input)
99 |
100 | if self.momentum is None:
101 | exponential_average_factor = 0.0
102 | else:
103 | exponential_average_factor = self.momentum
104 |
105 | if self.training and self.track_running_stats:
106 | if self.num_batches_tracked is not None:
107 | self.num_batches_tracked += 1
108 | if self.momentum is None:
109 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
110 | else:
111 | exponential_average_factor = self.momentum
112 |
113 | if self.affine :
114 | if self.weight_eps is None :
115 | weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
116 | bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
117 | else :
118 | weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
119 | bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
120 | else :
121 | weight = None
122 | bias = None
123 |
124 | return F.batch_norm(
125 | input, self.running_mean, self.running_var, weight, bias,
126 | self.training or not self.track_running_stats,
127 | exponential_average_factor, self.eps)
128 |
129 | def extra_repr(self):
130 | return '{prior_mu}, {prior_sigma}, {num_features}, ' \
131 | 'eps={eps}, momentum={momentum}, affine={affine}, ' \
132 | 'track_running_stats={track_running_stats}'.format(**self.__dict__)
133 |
134 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
135 | missing_keys, unexpected_keys, error_msgs):
136 | version = local_metadata.get('version', None)
137 |
138 | if (version is None or version < 2) and self.track_running_stats:
139 | num_batches_tracked_key = prefix + 'num_batches_tracked'
140 | if num_batches_tracked_key not in state_dict:
141 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
142 |
143 | super(_BayesBatchNorm, self)._load_from_state_dict(
144 | state_dict, prefix, local_metadata, strict,
145 | missing_keys, unexpected_keys, error_msgs)
146 |
147 | class BayesBatchNorm2d(_BayesBatchNorm):
148 | r"""
149 | Applies Bayesian Batch Normalization over a 4D input (a mini-batch of 2D inputs with an additional channel dimension).
150 |
151 | Arguments:
152 | prior_mu (Float): mean of prior normal distribution.
153 | prior_sigma (Float): sigma of prior normal distribution.
154 |
155 | .. note:: Other arguments follow ``torch.nn`` batch normalization in PyTorch 1.2.0.
156 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
157 |
158 | """
159 |
160 | def _check_input_dim(self, input):
161 | if input.dim() != 4:
162 | raise ValueError('expected 4D input (got {}D input)'
163 | .format(input.dim()))
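In training mode, `BayesBatchNorm2d` normalizes each channel with its batch statistics via `F.batch_norm`. It then applies a scale sampled as `weight_mu + exp(weight_log_sigma) * eps` and a shift sampled as `bias_mu + exp(bias_log_sigma) * eps`, with one noise value per channel.

The stdlib sketch below spells this out for a single channel and ignores the running-statistics update. `bayes_batchnorm_channel` is a hypothetical helper. Passing `noise=(0.0, 0.0)` is an illustration device that reduces the layer to an ordinary batch norm with scale `weight_mu` and shift `bias_mu`.

```python
import math
import random

def bayes_batchnorm_channel(values, weight_mu, weight_log_sigma, bias_mu, bias_log_sigma,
                            rng=None, eps=1e-5, noise=None):
    """Training-mode batch norm for one channel with a sampled affine transform.

    values: all activations of this channel across batch and spatial positions.
    noise:  optional fixed (eps_w, eps_b) pair, mimicking a frozen layer;
            otherwise fresh N(0, 1) noise is drawn from rng.
    """
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n   # biased variance, used for normalization
    if noise is None:
        eps_w, eps_b = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    else:
        eps_w, eps_b = noise
    gamma = weight_mu + math.exp(weight_log_sigma) * eps_w   # sampled scale
    beta = bias_mu + math.exp(bias_log_sigma) * eps_b        # sampled shift
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]

data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Zero noise: output has mean bias_mu and standard deviation close to weight_mu.
out = bayes_batchnorm_channel(data, 1.5, math.log(0.05), 0.2, math.log(0.05), noise=(0.0, 0.0))
out_mean = sum(out) / len(out)
out_std = math.sqrt(sum((o - out_mean) ** 2 for o in out) / len(out))

# Sampled noise: two calls give different outputs, as an unfrozen layer would.
rng = random.Random(0)
sample_1 = bayes_batchnorm_channel(data, 1.5, math.log(0.05), 0.2, math.log(0.05), rng=rng)
sample_2 = bayes_batchnorm_channel(data, 1.5, math.log(0.05), 0.2, math.log(0.05), rng=rng)
```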
--------------------------------------------------------------------------------
/torchbnn/modules/conv.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.init as init
5 | from torch.nn import Module, Parameter
6 | import torch.nn.functional as F
7 |
8 | from torch.nn.modules.utils import _single, _pair, _triple
9 |
10 |
11 | class _BayesConvNd(Module):
12 | r"""
13 | Applies Bayesian Convolution
14 |
15 | Arguments:
16 | prior_mu (Float): mean of prior normal distribution.
17 | prior_sigma (Float): sigma of prior normal distribution.
18 |
19 | .. note:: Other arguments follow ``torch.nn`` convolution in PyTorch 1.2.0.
20 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
21 | """
22 | __constants__ = ['prior_mu', 'prior_sigma', 'stride', 'padding', 'dilation',
23 | 'groups', 'bias', 'padding_mode', 'output_padding', 'in_channels',
24 | 'out_channels', 'kernel_size']
25 |
26 | def __init__(self, prior_mu, prior_sigma, in_channels, out_channels, kernel_size, stride,
27 | padding, dilation, transposed, output_padding,
28 | groups, bias, padding_mode):
29 | super(_BayesConvNd, self).__init__()
30 | if in_channels % groups != 0:
31 | raise ValueError('in_channels must be divisible by groups')
32 | if out_channels % groups != 0:
33 | raise ValueError('out_channels must be divisible by groups')
34 | self.in_channels = in_channels
35 | self.out_channels = out_channels
36 | self.kernel_size = kernel_size
37 | self.stride = stride
38 | self.padding = padding
39 | self.dilation = dilation
40 | self.transposed = transposed
41 | self.output_padding = output_padding
42 | self.groups = groups
43 | self.padding_mode = padding_mode
44 |
45 | self.prior_mu = prior_mu
46 | self.prior_sigma = prior_sigma
47 | self.prior_log_sigma = math.log(prior_sigma)
48 |
49 | if transposed:
50 | self.weight_mu = Parameter(torch.Tensor(
51 | in_channels, out_channels // groups, *kernel_size))
52 | self.weight_log_sigma = Parameter(torch.Tensor(
53 | in_channels, out_channels // groups, *kernel_size))
54 | self.register_buffer('weight_eps', None)
55 | else:
56 | self.weight_mu = Parameter(torch.Tensor(
57 | out_channels, in_channels // groups, *kernel_size))
58 | self.weight_log_sigma = Parameter(torch.Tensor(
59 | out_channels, in_channels // groups, *kernel_size))
60 | self.register_buffer('weight_eps', None)
61 |
62 | if bias is None or bias is False :
63 | self.bias = False
64 | else :
65 | self.bias = True
66 |
67 | if self.bias:
68 | self.bias_mu = Parameter(torch.Tensor(out_channels))
69 | self.bias_log_sigma = Parameter(torch.Tensor(out_channels))
70 | self.register_buffer('bias_eps', None)
71 | else:
72 | self.register_parameter('bias_mu', None)
73 | self.register_parameter('bias_log_sigma', None)
74 | self.register_buffer('bias_eps', None)
75 |
76 | self.reset_parameters()
77 |
78 | def reset_parameters(self):
79 | # Initialization method of Adv-BNN.
80 | n = self.in_channels
81 | n *= self.kernel_size[0] * self.kernel_size[1] # fan-in; also correct for non-square kernels
82 | stdv = 1.0 / math.sqrt(n)
83 | self.weight_mu.data.uniform_(-stdv, stdv)
84 | self.weight_log_sigma.data.fill_(self.prior_log_sigma)
85 |
86 | if self.bias :
87 | self.bias_mu.data.uniform_(-stdv, stdv)
88 | self.bias_log_sigma.data.fill_(self.prior_log_sigma)
89 |
90 | # Initialization method of the original torch nn.conv.
91 | # init.kaiming_uniform_(self.weight_mu, a=math.sqrt(5))
92 | # self.weight_log_sigma.data.fill_(self.prior_log_sigma)
93 |
94 | # if self.bias :
95 | # fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight_mu)
96 | # bound = 1 / math.sqrt(fan_in)
97 | # init.uniform_(self.bias_mu, -bound, bound)
98 |
99 | # self.bias_log_sigma.data.fill_(self.prior_log_sigma)
100 |
101 | def freeze(self) :
102 | self.weight_eps = torch.randn_like(self.weight_log_sigma)
103 | if self.bias :
104 | self.bias_eps = torch.randn_like(self.bias_log_sigma)
105 |
106 | def unfreeze(self) :
107 | self.weight_eps = None
108 | if self.bias :
109 | self.bias_eps = None
110 |
111 | def extra_repr(self):
112 | s = ('{prior_mu}, {prior_sigma}'
113 | ', {in_channels}, {out_channels}, kernel_size={kernel_size}'
114 | ', stride={stride}')
115 | if self.padding != (0,) * len(self.padding):
116 | s += ', padding={padding}'
117 | if self.dilation != (1,) * len(self.dilation):
118 | s += ', dilation={dilation}'
119 | if self.output_padding != (0,) * len(self.output_padding):
120 | s += ', output_padding={output_padding}'
121 | if self.groups != 1:
122 | s += ', groups={groups}'
123 | if self.bias is False:
124 | s += ', bias=False'
125 | return s.format(**self.__dict__)
126 |
127 | def __setstate__(self, state):
128 | super(_BayesConvNd, self).__setstate__(state)
129 | if not hasattr(self, 'padding_mode'):
130 | self.padding_mode = 'zeros'
131 |
132 | class BayesConv2d(_BayesConvNd):
133 | r"""
134 | Applies Bayesian convolution over 2D inputs.
135 |
136 | Arguments:
137 | prior_mu (Float): mean of the prior normal distribution.
138 | prior_sigma (Float): standard deviation of the prior normal distribution.
139 |
140 | .. note:: Other arguments follow ``Conv2d`` of PyTorch 1.2.0.
141 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
142 |
143 | """
144 | def __init__(self, prior_mu, prior_sigma, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
145 | kernel_size = _pair(kernel_size)
146 | stride = _pair(stride)
147 | padding = _pair(padding)
148 | dilation = _pair(dilation)
149 | super(BayesConv2d, self).__init__(
150 | prior_mu, prior_sigma, in_channels, out_channels, kernel_size, stride,
151 | padding, dilation, False, _pair(0), groups, bias, padding_mode)
152 |
153 | def conv2d_forward(self, input, weight):
154 |
155 | if self.bias:
156 | if self.bias_eps is None :
157 | bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
158 | else :
159 | bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
160 | else :
161 | bias = None
162 |
163 | if self.padding_mode == 'circular':
164 | expanded_padding = (self.padding[1], self.padding[1],
165 | self.padding[0], self.padding[0])
166 | return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
167 | weight, bias, self.stride,
168 | _pair(0), self.dilation, self.groups)
169 | return F.conv2d(input, weight, bias, self.stride,
170 | self.padding, self.dilation, self.groups)
171 |
172 | def forward(self, input):
173 | r"""
174 | Overridden.
175 | """
176 | if self.weight_eps is None :
177 | weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
178 | else :
179 | weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
180 |
181 | return self.conv2d_forward(input, weight)
--------------------------------------------------------------------------------
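The Adv-BNN initialisation in `reset_parameters` draws `weight_mu` uniformly from [-stdv, stdv] with stdv = 1/sqrt(n), where n is the fan-in, and fills `weight_log_sigma` with log(prior_sigma). The plain-Python sketch below reproduces that arithmetic without PyTorch, assuming `groups=1` and taking the fan-in as `in_channels` times every kernel dimension; `adv_bnn_init` is a hypothetical helper name, not part of torchbnn.

```python
import math
import random


def adv_bnn_init(in_channels, out_channels, kernel_size, prior_sigma, seed=0):
    """Hypothetical plain-Python sketch of the Adv-BNN style initialisation.

    Returns (stdv, weight_mu, weight_log_sigma); weights are flattened to one
    row of n values per output channel (assumes groups=1).
    """
    # Fan-in: input channels times every kernel dimension (works for non-square kernels).
    n = in_channels
    for k in kernel_size:
        n *= k
    stdv = 1.0 / math.sqrt(n)

    rng = random.Random(seed)
    # weight_mu ~ U(-stdv, stdv)
    weight_mu = [[rng.uniform(-stdv, stdv) for _ in range(n)] for _ in range(out_channels)]
    # weight_log_sigma starts at log(prior_sigma), so exp(log_sigma) == prior_sigma initially.
    prior_log_sigma = math.log(prior_sigma)
    weight_log_sigma = [[prior_log_sigma] * n for _ in range(out_channels)]
    return stdv, weight_mu, weight_log_sigma


# A 3x5 kernel over 3 input channels gives a fan-in of 3 * 3 * 5 = 45.
stdv, mu, log_sigma = adv_bnn_init(in_channels=3, out_channels=4, kernel_size=(3, 5), prior_sigma=0.1)
```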
/torchbnn/modules/linear.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import Module, Parameter
5 | import torch.nn.init as init
6 | import torch.nn.functional as F
7 |
8 | class BayesLinear(Module):
9 | r"""
10 | Applies a Bayesian linear transformation to the incoming data.
11 |
12 | Arguments:
13 | prior_mu (Float): mean of the prior normal distribution.
14 | prior_sigma (Float): standard deviation of the prior normal distribution.
15 |
16 | .. note:: Other arguments follow ``Linear`` of PyTorch 1.2.0.
17 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
18 |
19 | """
20 | __constants__ = ['prior_mu', 'prior_sigma', 'bias', 'in_features', 'out_features']
21 |
22 | def __init__(self, prior_mu, prior_sigma, in_features, out_features, bias=True):
23 | super(BayesLinear, self).__init__()
24 | self.in_features = in_features
25 | self.out_features = out_features
26 |
27 | self.prior_mu = prior_mu
28 | self.prior_sigma = prior_sigma
29 | self.prior_log_sigma = math.log(prior_sigma)
30 |
31 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
32 | self.weight_log_sigma = Parameter(torch.Tensor(out_features, in_features))
33 | self.register_buffer('weight_eps', None)
34 |
35 | if bias is None or bias is False :
36 | self.bias = False
37 | else :
38 | self.bias = True
39 |
40 | if self.bias:
41 | self.bias_mu = Parameter(torch.Tensor(out_features))
42 | self.bias_log_sigma = Parameter(torch.Tensor(out_features))
43 | self.register_buffer('bias_eps', None)
44 | else:
45 | self.register_parameter('bias_mu', None)
46 | self.register_parameter('bias_log_sigma', None)
47 | self.register_buffer('bias_eps', None)
48 |
49 | self.reset_parameters()
50 |
51 | def reset_parameters(self):
52 | # Initialization method of Adv-BNN
53 | stdv = 1. / math.sqrt(self.weight_mu.size(1))
54 | self.weight_mu.data.uniform_(-stdv, stdv)
55 | self.weight_log_sigma.data.fill_(self.prior_log_sigma)
56 | if self.bias :
57 | self.bias_mu.data.uniform_(-stdv, stdv)
58 | self.bias_log_sigma.data.fill_(self.prior_log_sigma)
59 |
60 | # Initialization method of the original torch nn.linear.
61 | # init.kaiming_uniform_(self.weight_mu, a=math.sqrt(5))
62 | # self.weight_log_sigma.data.fill_(self.prior_log_sigma)
63 |
64 | # if self.bias :
65 | # fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight_mu)
66 | # bound = 1 / math.sqrt(fan_in)
67 | # init.uniform_(self.bias_mu, -bound, bound)
68 |
69 | # self.bias_log_sigma.data.fill_(self.prior_log_sigma)
70 |
71 | def freeze(self) :
72 | self.weight_eps = torch.randn_like(self.weight_log_sigma)
73 | if self.bias :
74 | self.bias_eps = torch.randn_like(self.bias_log_sigma)
75 |
76 | def unfreeze(self) :
77 | self.weight_eps = None
78 | if self.bias :
79 | self.bias_eps = None
80 |
81 | def forward(self, input):
82 | r"""
83 | Overridden.
84 | """
85 | if self.weight_eps is None :
86 | weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
87 | else :
88 | weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
89 |
90 | if self.bias:
91 | if self.bias_eps is None :
92 | bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
93 | else :
94 | bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
95 | else :
96 | bias = None
97 |
98 | return F.linear(input, weight, bias)
99 |
100 | def extra_repr(self):
101 | r"""
102 | Overridden.
103 | """
104 | return 'prior_mu={}, prior_sigma={}, in_features={}, out_features={}, bias={}'.format(self.prior_mu, self.prior_sigma, self.in_features, self.out_features, self.bias)
--------------------------------------------------------------------------------
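`BayesLinear.forward` uses the reparameterisation trick: each weight is `mu + exp(log_sigma) * eps`, where `eps` is drawn fresh on every call unless `freeze()` has stored it, and the sampled weights go through `F.linear`, i.e. `x @ W.T + b`. A plain-Python sketch for a single input vector follows; `bayes_linear_forward` is a hypothetical helper, and Python's `random` module stands in for `torch.randn_like`.

```python
import math
import random


def bayes_linear_forward(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng, w_eps=None, b_eps=None):
    """Hypothetical sketch of BayesLinear.forward for one input vector.

    w_mu / w_log_sigma: [out_features][in_features]; b_mu / b_log_sigma: [out_features].
    If w_eps / b_eps are None, fresh N(0, 1) noise is drawn (unfrozen mode);
    otherwise the given noise is reused (frozen mode).
    """
    out_features, in_features = len(w_mu), len(w_mu[0])
    if w_eps is None:
        w_eps = [[rng.gauss(0.0, 1.0) for _ in range(in_features)] for _ in range(out_features)]
    if b_eps is None:
        b_eps = [rng.gauss(0.0, 1.0) for _ in range(out_features)]

    y = []
    for j in range(out_features):
        # Reparameterisation: w = mu + exp(log_sigma) * eps
        w_row = [w_mu[j][i] + math.exp(w_log_sigma[j][i]) * w_eps[j][i] for i in range(in_features)]
        b_j = b_mu[j] + math.exp(b_log_sigma[j]) * b_eps[j]
        # F.linear semantics: y_j = sum_i x_i * W[j][i] + b_j
        y.append(sum(xi * wi for xi, wi in zip(x, w_row)) + b_j)
    return y


rng = random.Random(0)
w_mu = [[1.0, 2.0], [0.5, -1.0]]
w_log_sigma = [[math.log(0.1)] * 2 for _ in range(2)]
b_mu = [0.0, 1.0]
b_log_sigma = [math.log(0.1)] * 2
x = [1.0, 1.0]

# Zero noise recovers the mean network.
mean_out = bayes_linear_forward(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng,
                                w_eps=[[0.0, 0.0], [0.0, 0.0]], b_eps=[0.0, 0.0])
# Fixed noise (frozen mode) gives the same output on every call.
w_eps = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(2)]
b_eps = [rng.gauss(0.0, 1.0) for _ in range(2)]
frozen_1 = bayes_linear_forward(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng, w_eps=w_eps, b_eps=b_eps)
frozen_2 = bayes_linear_forward(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng, w_eps=w_eps, b_eps=b_eps)
# No stored noise (unfrozen mode): a new sample on every call.
sample_out = bayes_linear_forward(x, w_mu, w_log_sigma, b_mu, b_log_sigma, rng)
```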
/torchbnn/modules/loss.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from torch.nn import Module
4 | from torch.nn import functional as F
5 | from torch.nn import _reduction as _Reduction
6 |
7 | from .. import functional as BF
8 |
9 | class _Loss(Module):
10 | def __init__(self, reduction='mean'):
11 | super(_Loss, self).__init__()
12 | self.reduction = reduction
13 |
14 | class BKLLoss(_Loss):
15 | """
16 | Loss for computing the KL divergence of a Bayesian neural network model.
17 |
18 | Arguments:
19 | reduction (string, optional): Specifies the reduction to apply to the output:
20 | ``'mean'``: the sum of the output will be divided by the number of
21 | elements of the output.
22 | ``'sum'``: the output will be summed.
23 | last_layer_only (Bool): if True, return only the last layer's KL divergence.
24 | """
25 | __constants__ = ['reduction']
26 |
27 | def __init__(self, reduction='mean', last_layer_only=False):
28 | super(BKLLoss, self).__init__(reduction)
29 | self.last_layer_only = last_layer_only
30 |
31 | def forward(self, model):
32 | """
33 | Arguments:
34 | model (nn.Module): the model whose KL divergence is computed.
35 | """
36 | return BF.bayesian_kl_loss(model, reduction=self.reduction, last_layer_only=self.last_layer_only)
--------------------------------------------------------------------------------
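`BKLLoss` delegates to `BF.bayesian_kl_loss`, whose source is not shown in this section. Assuming it uses the standard closed-form KL divergence between each parameter's Gaussian posterior N(mu, exp(log_sigma)^2) and the Gaussian prior N(prior_mu, prior_sigma^2), and that `'mean'` divides the summed KL by the number of parameters as the docstring describes, the computation can be sketched in plain Python as follows. `gaussian_kl` and `kl_loss` are hypothetical names.

```python
import math


def gaussian_kl(mu_q, log_sigma_q, mu_p, sigma_p):
    """Closed-form KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) for one scalar parameter."""
    sigma_q = math.exp(log_sigma_q)
    return (math.log(sigma_p) - log_sigma_q
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)


def kl_loss(mus, log_sigmas, prior_mu, prior_sigma, reduction="mean"):
    """Hypothetical sketch: sum per-parameter KL terms; 'mean' divides by the parameter count."""
    terms = [gaussian_kl(m, s, prior_mu, prior_sigma) for m, s in zip(mus, log_sigmas)]
    total = sum(terms)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(terms)
    raise ValueError(f"unknown reduction: {reduction!r}")


# A posterior equal to the prior costs nothing; shifting the mean by one prior sigma-unit of 10 costs 50.
same = gaussian_kl(0.0, math.log(0.1), 0.0, 0.1)
shifted = gaussian_kl(1.0, math.log(0.1), 0.0, 0.1)
```

Because the posterior `log_sigma` is initialised to `log(prior_sigma)` in both layers above, the KL term starts out driven only by how far `weight_mu` is from `prior_mu`.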
/torchbnn/modules/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module
3 |
4 |
5 | class BayesModule(Module) :
6 | r"""
7 | Base class for Bayesian modules.
8 | Currently this module is not used as the base class of the Bayesian modules because it does not provide many utilities yet.
9 | However, it may be used in the near future for convenience.
10 | """
11 |
12 | def freeze(self):
13 | r"""Sets the module in frozen mode.
14 | This only affects Bayesian modules: it fixes their epsilons, e.g. weight_eps and bias_eps.
15 | Thus, Bayesian neural networks will return the same results for the same inputs.
16 | """
17 | for module in self.children():
18 | if hasattr(module, 'freeze'):
19 | module.freeze()
20 | return self
21 |
22 | def unfreeze(self):
23 | r"""Sets the module in unfrozen mode.
24 | This only affects Bayesian modules: it releases their epsilons, e.g. weight_eps and bias_eps.
25 | Thus, Bayesian neural networks will return different results even for the same inputs.
26 | """
27 | for module in self.children():
28 | if hasattr(module, 'unfreeze'):
29 | module.unfreeze()
30 | return self
--------------------------------------------------------------------------------
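`BayesModule.freeze` and `unfreeze` are meant to push the call down to child modules. The sketch below shows the same idea as a full recursive walk over a module tree, in plain Python with hypothetical `Node` and `BayesLeaf` classes standing in for `nn.Module` and a Bayesian layer; only nodes that define `freeze`/`unfreeze` are touched, so plain containers are simply traversed.

```python
import random


class Node:
    """Hypothetical stand-in for nn.Module: it only stores child nodes."""

    def __init__(self, *children):
        self._children = list(children)

    def children(self):
        return iter(self._children)


class BayesLeaf(Node):
    """Hypothetical Bayesian layer: frozen while `eps` holds a fixed sample."""

    def __init__(self, rng):
        super().__init__()
        self.rng = rng
        self.eps = None

    def freeze(self):
        # Stand-in for torch.randn_like(self.weight_log_sigma)
        self.eps = self.rng.gauss(0.0, 1.0)

    def unfreeze(self):
        self.eps = None


def freeze_tree(node):
    # Freeze this node if it knows how, then recurse into every child.
    if hasattr(node, "freeze"):
        node.freeze()
    for child in node.children():
        freeze_tree(child)


def unfreeze_tree(node):
    # Same walk, releasing the stored noise.
    if hasattr(node, "unfreeze"):
        node.unfreeze()
    for child in node.children():
        unfreeze_tree(child)


rng = random.Random(0)
leaves = [BayesLeaf(rng) for _ in range(3)]
model = Node(leaves[0], Node(leaves[1], Node(leaves[2])))  # nested containers
freeze_tree(model)
```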
/torchbnn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .freeze_model import freeze, unfreeze
--------------------------------------------------------------------------------
/torchbnn/utils/freeze_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ..modules import *
4 |
5 | bayes_layer = (BayesLinear, BayesConv2d, BayesBatchNorm2d)
6 |
7 | def freeze(module):
8 | """
9 | Freezes a Bayesian model by fixing the epsilons of every Bayesian layer it contains.
10 |
11 | Arguments:
12 | module (nn.Module): the model to be frozen.
13 |
14 | """
15 |
16 | if isinstance(module, bayes_layer) :
17 | module.freeze()
18 | for submodule in module.children() :
19 | freeze(submodule)
20 |
21 |
22 | def unfreeze(module):
23 | """
24 | Unfreezes a Bayesian model so that every Bayesian layer samples new epsilons on each forward pass.
25 |
26 | Arguments:
27 | module (nn.Module): the model to be unfrozen.
28 |
29 | """
30 | if isinstance(module, bayes_layer) :
31 | module.unfreeze()
32 | for submodule in module.children() :
33 | unfreeze(submodule)
--------------------------------------------------------------------------------
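`freeze` and `unfreeze` in `torchbnn.utils` switch every Bayesian layer between a fixed epsilon (deterministic outputs) and fresh noise on each call. The toy sketch below illustrates both modes on a hypothetical one-weight model, `OneWeightBNN`, written in plain Python. Averaging many unfrozen forward passes (a Monte Carlo estimate of the predictive mean) is shown as one common way to use the stochastic mode; it is an illustration, not something this repository prescribes.

```python
import math
import random


class OneWeightBNN:
    """Hypothetical one-parameter Bayesian model y = w * x with w ~ N(mu, sigma^2)."""

    def __init__(self, mu, sigma, seed=0):
        self.mu, self.log_sigma = mu, math.log(sigma)
        self.eps = None  # None -> unfrozen, float -> frozen
        self.rng = random.Random(seed)

    def freeze(self):
        # Draw the noise once and keep it.
        self.eps = self.rng.gauss(0.0, 1.0)

    def unfreeze(self):
        self.eps = None

    def __call__(self, x):
        eps = self.rng.gauss(0.0, 1.0) if self.eps is None else self.eps
        return (self.mu + math.exp(self.log_sigma) * eps) * x


model = OneWeightBNN(mu=1.5, sigma=0.1)

model.freeze()
frozen_a, frozen_b = model(2.0), model(2.0)  # same weight sample, same output

model.unfreeze()
samples = [model(2.0) for _ in range(10_000)]
mc_mean = sum(samples) / len(samples)  # should land close to mu * x = 3.0
```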