├── figures
│   ├── NPN.jpg
│   ├── nn-vs-npn.png
│   ├── nn-vs-npn-op.png
│   └── acc-var-MNIST.jpg
├── boston_housing_nor_val.pkl
├── boston_housing_nor_train.pkl
├── utils.py
├── .gitignore
├── datasets_boston_housing.py
├── regress_mlp.sh
├── regress_npn.sh
├── mlp.sh
├── npn.sh
├── cnn_mlp.sh
├── cnn_npn.sh
├── mlp-att.sh
├── npnlite.sh
├── README.md
├── npn.py
└── main_mlp.py

/figures/NPN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js05212/PyTorch-for-NPN/HEAD/figures/NPN.jpg -------------------------------------------------------------------------------- /figures/nn-vs-npn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js05212/PyTorch-for-NPN/HEAD/figures/nn-vs-npn.png -------------------------------------------------------------------------------- /figures/nn-vs-npn-op.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js05212/PyTorch-for-NPN/HEAD/figures/nn-vs-npn-op.png -------------------------------------------------------------------------------- /boston_housing_nor_val.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js05212/PyTorch-for-NPN/HEAD/boston_housing_nor_val.pkl -------------------------------------------------------------------------------- /figures/acc-var-MNIST.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js05212/PyTorch-for-NPN/HEAD/figures/acc-var-MNIST.jpg -------------------------------------------------------------------------------- /boston_housing_nor_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js05212/PyTorch-for-NPN/HEAD/boston_housing_nor_train.pkl -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def plain_log(filename, text): 2 | fp = open(filename,'a') 3 | fp.write(text) 4 | fp.close() 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tmp* 2 | *~ 3 | __pycache__ 4 | *.model 5 | *.pyc 6 | *.py_* 7 | *.sh_* 8 | README 9 | hao-att.sh 10 | cp.sh 11 | weight_info* 12 | main_npn.py 13 | main.py 14 | npn_example.py 15 | -------------------------------------------------------------------------------- /datasets_boston_housing.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | import pickle 3 | import numpy as np 4 | 5 | class Dataset_boston_housing(data.Dataset): 6 | # boston housing dataset 7 | def __init__(self, path): 8 | # read data pickle as a dict, where keys are X, X_labels 9 | with open(path, 'rb') as f: 10 | self.data = pickle.load(f) 11 | 12 | def __len__(self): 13 | return self.data['X'].shape[0] 14 | 15 | def __getitem__(self, index): 16 | index = index % self.data['X'].shape[0] 17 | return self.data['X'][index].astype('float32'), np.asarray([self.data['X_labels'][index]]).astype('float32') 18 | -------------------------------------------------------------------------------- /regress_mlp.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 |
dropout='0.0' 3 | lr='1e0' 4 | output_s='0e1' 5 | type='regress_mlp' 6 | loss='mse' 7 | evaluate='' 8 | checkpoint='none' 9 | save_head='tmp' 10 | epo='1000' 11 | save_interval='100' 12 | batch_size='16' 13 | log_file='tmp_mlp' 14 | seed='1112' 15 | CUDA_VISIBLE_DEVICES=$gpuid python3.5 main_mlp.py \ 16 | --lr $lr \ 17 | --epochs $epo \ 18 | --batch-size $batch_size \ 19 | --log_file $log_file \ 20 | --dropout $dropout \ 21 | --save_interval $save_interval \ 22 | --loss $loss \ 23 | --save_head $save_head \ 24 | $evaluate \ 25 | --checkpoint $checkpoint \ 26 | --type $type \ 27 | --output_s $output_s \ 28 | --seed $seed 29 | -------------------------------------------------------------------------------- /regress_npn.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.0' 3 | lr='1e0' 4 | output_s='0e1' 5 | type='regress_npn' 6 | loss='gaussian' 7 | evaluate='' 8 | checkpoint='none' 9 | save_head='tmp' 10 | epo='1000' 11 | save_interval='100' 12 | batch_size='16' 13 | log_file='tmp_mlp' 14 | seed='1112' 15 | CUDA_VISIBLE_DEVICES=$gpuid python3.5 main_mlp.py \ 16 | --lr $lr \ 17 | --epochs $epo \ 18 | --batch-size $batch_size \ 19 | --log_file $log_file \ 20 | --dropout $dropout \ 21 | --save_interval $save_interval \ 22 | --loss $loss \ 23 | --save_head $save_head \ 24 | $evaluate \ 25 | --checkpoint $checkpoint \ 26 | --type $type \ 27 | --output_s $output_s \ 28 | --seed $seed 29 | -------------------------------------------------------------------------------- /mlp.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.2' 3 | lr='5e0' # 1e-4 4 | lambda='1e-1' 5 | type='mlp' 6 | num_train='100' 7 | loss='default' # default 8 | evaluate='' 9 | checkpoint='none' 10 | save_head='tmp' 11 | epo='500' 12 | save_interval='100' 13 | batch_size='128' 14 | log_file='tmp_mlp' 15 | seed='2' 16 | CUDA_VISIBLE_DEVICES=$gpuid python main_mlp.py \ 17 | --lr $lr \ 18 | --epochs $epo \ 19 | --num_train $num_train \ 20 | --batch-size $batch_size \ 21 | --log_file $log_file \ 22 | --dropout $dropout \ 23 | --save_interval $save_interval \ 24 | --loss $loss \ 25 | --save_head $save_head \ 26 | $evaluate \ 27 | --checkpoint $checkpoint \ 28 | --type $type \ 29 | --output_s $lambda \ 30 | --seed $seed 31 | -------------------------------------------------------------------------------- /npn.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.2' 3 | lr='5e0' # 1e-4 4 | lambda='1e-1' 5 | type='npn' 6 | num_train='100' 7 | loss='default' # default 8 | evaluate='' 9 | checkpoint='none' 10 | save_head='tmp' 11 | epo='500' 12 | save_interval='100' 13 | batch_size='128' 14 | log_file='tmp_mlp' 15 | seed='2' 16 | CUDA_VISIBLE_DEVICES=$gpuid python main_mlp.py \ 17 | --lr $lr \ 18 | --epochs $epo \ 19 | --num_train $num_train \ 20 | --batch-size $batch_size \ 21 | --log_file $log_file \ 22 | --dropout $dropout \ 23 | --save_interval $save_interval \ 24 | --loss $loss \ 25 | --save_head $save_head \ 26 | $evaluate \ 27 | --checkpoint $checkpoint \ 28 | --type $type \ 29 | --output_s $lambda \ 30 | --seed $seed 31 | -------------------------------------------------------------------------------- /cnn_mlp.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.2' 3 | lr='5e0' # 1e-4 4 | lambda='1e-1' 5 | type='cnn' 6 | num_train='100' 7 | loss='default' # default 8 | evaluate='' 9 | 
checkpoint='none' 10 | save_head='tmp' 11 | epo='500' 12 | save_interval='100' 13 | batch_size='128' 14 | log_file='tmp_mlp' 15 | seed='2' 16 | CUDA_VISIBLE_DEVICES=$gpuid python main_mlp.py \ 17 | --lr $lr \ 18 | --epochs $epo \ 19 | --num_train $num_train \ 20 | --batch-size $batch_size \ 21 | --log_file $log_file \ 22 | --dropout $dropout \ 23 | --save_interval $save_interval \ 24 | --loss $loss \ 25 | --save_head $save_head \ 26 | $evaluate \ 27 | --checkpoint $checkpoint \ 28 | --type $type \ 29 | --output_s $lambda \ 30 | --seed $seed 31 | -------------------------------------------------------------------------------- /cnn_npn.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.2' 3 | lr='5e0' # 1e-4 4 | lambda='1e-1' 5 | type='npncnn' 6 | num_train='100' 7 | loss='default' # default 8 | evaluate='' 9 | checkpoint='none' 10 | save_head='tmp' 11 | epo='500' 12 | save_interval='100' 13 | batch_size='128' 14 | log_file='tmp_mlp' 15 | seed='2' 16 | CUDA_VISIBLE_DEVICES=$gpuid python main_mlp.py \ 17 | --lr $lr \ 18 | --epochs $epo \ 19 | --num_train $num_train \ 20 | --batch-size $batch_size \ 21 | --log_file $log_file \ 22 | --dropout $dropout \ 23 | --save_interval $save_interval \ 24 | --loss $loss \ 25 | --save_head $save_head \ 26 | $evaluate \ 27 | --checkpoint $checkpoint \ 28 | --type $type \ 29 | --output_s $lambda \ 30 | --seed $seed 31 | -------------------------------------------------------------------------------- /mlp-att.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.2' 3 | lr='5e0' # 1e-4 4 | lambda='1e-1' 5 | type='npn' 6 | num_train='100' 7 | loss='default' # default 8 | evaluate='' 9 | checkpoint='none' 10 | save_head='tmp' 11 | epo='500' 12 | save_interval='100' 13 | batch_size='128' 14 | log_file='tmp_mlp' 15 | seed='1112' 16 | CUDA_VISIBLE_DEVICES=$gpuid python3.5 main_mlp.py \ 17 | --lr $lr \ 18 | --epochs $epo \ 19 | --num_train $num_train \ 20 | --batch-size $batch_size \ 21 | --log_file $log_file \ 22 | --dropout $dropout \ 23 | --save_interval $save_interval \ 24 | --loss $loss \ 25 | --save_head $save_head \ 26 | $evaluate \ 27 | --checkpoint $checkpoint \ 28 | --type $type \ 29 | --output_s $lambda \ 30 | --seed $seed 31 | -------------------------------------------------------------------------------- /npnlite.sh: -------------------------------------------------------------------------------- 1 | gpuid='0' 2 | dropout='0.2' 3 | lr='5e0' # 1e-4 4 | lambda='1e-1' 5 | type='npn_lite' 6 | num_train='100' 7 | loss='default' # default 8 | evaluate='' 9 | checkpoint='none' 10 | save_head='tmp' 11 | epo='500' 12 | save_interval='100' 13 | batch_size='128' 14 | log_file='tmp_mlp' 15 | seed='2' 16 | CUDA_VISIBLE_DEVICES=$gpuid python main_mlp.py \ 17 | --lr $lr \ 18 | --epochs $epo \ 19 | --num_train $num_train \ 20 | --batch-size $batch_size \ 21 | --log_file $log_file \ 22 | --dropout $dropout \ 23 | --save_interval $save_interval \ 24 | --loss $loss \ 25 | --save_head $save_head \ 26 | $evaluate \ 27 | --checkpoint $checkpoint \ 28 | --type $type \ 29 | --output_s $lambda \ 30 | --seed $seed 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Natural-Parameter Networks (NPN) in PyTorch 2 | ============ 3 | 4 | This is the PyTorch code for the NIPS paper ['Natural-Parameter Networks: A Class of Probabilistic 
Neural Networks'](http://wanghao.in/paper/NIPS16_NPN.pdf). 5 | 6 | It is a class of probabilistic neural networks that treat both weights and neurons as distributions rather than just points in high-dimensional space. Distributions are first-class citizens in the networks, and the design allows them to feed forward and backpropagate through the network. Given an input data point, NPN outputs a predicted distribution that carries information on both the prediction and its uncertainty. 7 | 8 | NPN can be used either independently or as a building block for [Bayesian Deep Learning](http://wanghao.in/paper/TKDE16_BDL.pdf) (BDL). 9 | 10 | Note that this is the code for Gaussian NPN to run on the MNIST and Boston 11 | Housing datasets. For Gamma NPN or Poisson NPN, please refer to the other repo. 12 | 13 | ## Neural networks vs. natural-parameter networks in two figures: 14 | 15 | ### Distributions as first-class citizens: 16 | 17 |
![Neural network vs. NPN](./figures/nn-vs-npn.png) 18 | 19 |
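In code, "distributions as first-class citizens" means that every activation is a (mean, variance) pair rather than a single tensor. Below is a minimal sketch of this interface (assuming `npn.py` from this repo is importable and a PyTorch version recent enough to operate on plain tensors; under PyTorch 0.2 the tensors would need to be wrapped in `Variable`):

```python
import torch
from npn import NPNSigmoid

act = NPNSigmoid()
o_m = torch.zeros(2, 3)        # mean of a pre-activation
o_s = 0.1 * torch.ones(2, 3)   # variance of the same pre-activation
a_m, a_s = act((o_m, o_s))     # closed-form output moments of a sigmoid applied to N(o_m, o_s)
```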
20 | 21 | ### Closed-form operations to handle uncertainty: 22 | 23 |
![Closed-form operations in NPN](./figures/nn-vs-npn-op.png) 24 | 25 |
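Concretely, a Gaussian NPN linear layer propagates the input mean x_m and variance x_s through weight means W_m, weight variances W_s, and biases (b_m, b_s) in closed form; this is exactly the forward pass of `NPNLinear` in `npn.py` below:

```
o_m = x_m W_m + b_m
o_s = x_s W_s + x_s (W_m ∘ W_m) + (x_m ∘ x_m) W_s + b_s
```

where ∘ denotes the element-wise product. The activations (`NPNSigmoid`, `NPNRelu`) likewise map input moments to closed-form approximations of the output moments.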
26 | 27 | ## Example results on uncertainty-aware prediction: 28 | ### Output both prediction and uncertainty for regression: 29 |
![Predictive distribution of NPN on a toy regression task](./figures/NPN.jpg) 30 | 31 |
32 | Above is the predictive distribution for NPN. The shaded regions correspond 33 | to 3 standard deviations. The black curve is the data-generating function, and the blue curves 34 | show the means of the predictive distributions. Red stars are the training data. 35 | 36 | ### Accuracy versus uncertainty (variance): 37 |
![Accuracy vs. variance on MNIST](./figures/acc-var-MNIST.jpg) 38 | 39 |
40 | Above is the classification accuracy for different levels of variance (uncertainty). Note that ‘1’ on the x-axis means the variance is in the range [0, 0.04), ‘2’ means the variance is in the range [0.04, 0.08), etc. 41 | 42 | ## Accuracy: 43 | 44 | Using only 100 training samples from the MNIST training set: 45 | 46 | 47 | | Method | Accuracy | 48 | | -------|----------| 49 | | NPN (ours) | 74.58% | 50 | | MLP | 69.02% | 51 | | CNN+NPN (ours) | 86.87% | 52 | | CNN+MLP | 82.90% | 53 | 54 | ## RMSE: 55 | 56 | Regression task on Boston Housing: 57 | 58 | | Method | RMSE | 59 | | -------|----------| 60 | | NPN (ours) | 3.2197 | 61 | | MLP | 3.5748 | 62 | 63 | ## How to run the code: 64 | 65 | * In general, to train the model, run the command: `sh mlp-att.sh` 66 | * To train NPN (fully connected), run the command: `sh npn.sh` 67 | * To train MLP (fully connected), run the command: `sh mlp.sh` 68 | * To train CNN+NPN, run the command: `sh cnn_npn.sh` 69 | * To train CNN+MLP, run the command: `sh cnn_mlp.sh` 70 | * For regression tasks (Boston Housing) using NPN, run the command: `sh regress_npn.sh` 71 | * For regression tasks (Boston Housing) using MLP, run the command: `sh regress_mlp.sh` 72 | 73 | ## Short code example: 74 | This is *everything* needed to implement a three-layer NPN in PyTorch, assuming `torch.nn` is imported as `nn` (essentially you only need to replace `nn.Linear` with `NPNLinear`; a minimal inference sketch follows near the end of this README): 75 | ```python 76 | from npn import NPNLinear 77 | from npn import NPNSigmoid 78 | class NPNNet(nn.Module): 79 | def __init__(self): 80 | super(NPNNet, self).__init__() 81 | 82 | # Last parameter of NPNLinear 83 | # True: input contains both the mean and variance 84 | # False: input contains only the mean 85 | self.fc1 = NPNLinear(784, 800, False) 86 | self.sigmoid1 = NPNSigmoid() 87 | self.fc2 = NPNLinear(800, 800) 88 | self.sigmoid2 = NPNSigmoid() 89 | self.fc3 = NPNLinear(800, 10) 90 | self.sigmoid3 = NPNSigmoid() 91 | 92 | def forward(self, x): 93 | x = self.sigmoid1(self.fc1(x)) 94 | x = self.sigmoid2(self.fc2(x)) 95 | # output mean (x) and variance (s) of Gaussian NPN 96 | x, s = self.sigmoid3(self.fc3(x)) 97 | return x, s 98 | ``` 99 | 100 | ## Install: 101 | 102 | The code is tested under PyTorch 0.2.03 and Python 3.5.2. 103 | 104 | ## Official Matlab implementation: 105 | 106 | The official Matlab version (with GPU support) can be found [here](https://github.com/js05212/NPN). 107 | 108 | ## Other implementations (third-party): 109 | 110 | Another PyTorch/Python version (with an extension to GRU) by [sohamghosh121](https://github.com/sohamghosh121/natural-parameter-networks).
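## Inference sketch:

The snippet below is not part of the repo's scripts; it is a minimal sketch of how the three-layer `NPNNet` from the short code example might be queried at test time. It assumes `npn.py` is importable, `NPNNet` is defined as above, and a recent PyTorch (under PyTorch 0.2, inputs would need to be wrapped in `Variable`); the 0.04 threshold is illustrative only.

```python
import torch

model = NPNNet()  # the example network defined above
model.eval()

x = torch.rand(5, 784)   # a batch of 5 flattened 28x28 images
mean, var = model(x)     # Gaussian NPN outputs per-class mean and variance
_, pred = mean.max(1)    # predicted class = argmax of the output mean
pred_var = var.gather(1, pred.unsqueeze(1)).squeeze(1)  # variance of the predicted class

# flag high-variance predictions as 'uncertain', mirroring the
# accuracy-versus-variance analysis above
uncertain = pred_var > 0.04
```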
111 | 112 | ## Reference: 113 | [Natural-Parameter Networks: A Class of Probabilistic Neural Networks](http://wanghao.in/paper/NIPS16_NPN.pdf) 114 | ``` 115 | @inproceedings{DBLP:conf/nips/WangSY16, 116 | author = {Hao Wang and 117 | Xingjian Shi and 118 | Dit{-}Yan Yeung}, 119 | title = {Natural-Parameter Networks: {A} Class of Probabilistic Neural Networks}, 120 | booktitle = {Advances in Neural Information Processing Systems 29: Annual Conference 121 | on Neural Information Processing Systems 2016, December 5-10, 2016, 122 | Barcelona, Spain}, 123 | pages = {118--126}, 124 | year = {2016} 125 | } 126 | ``` 127 | 128 | -------------------------------------------------------------------------------- /npn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch.optim as optim 4 | import torch 5 | import torch.nn.functional as F 6 | import math 7 | import random 8 | 9 | class NPNLinear(nn.Module): 10 | def positive_s(self, x, use_sigmoid = 0): 11 | if use_sigmoid == 0: 12 | y = torch.log(torch.exp(x) + 1) 13 | else: 14 | y = F.sigmoid(x) 15 | return y 16 | 17 | def positive_s_inv(self, x, use_sigmoid = 0): 18 | if use_sigmoid == 0: 19 | y = torch.log(torch.exp(x) - 1) 20 | else: 21 | y = - torch.log(1 / x - 1) 22 | return y 23 | 24 | def __init__(self, in_channels, out_channels, dual_input = True, init_type = 0): 25 | # init_type 0: normal, 1: mixture of delta distr' 26 | super(NPNLinear, self).__init__() 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.dual_input = dual_input 30 | 31 | self.W_m = nn.Parameter(2 * math.sqrt(6) / math.sqrt(in_channels + out_channels) * (torch.rand(in_channels, out_channels) - 0.5)) 32 | if init_type == 0: 33 | #W_s_init = 1 * math.sqrt(6) / math.sqrt(in_channels + out_channels) * torch.rand(in_channels, out_channels) 34 | W_s_init = 0.01 * math.sqrt(6) / math.sqrt(in_channels + out_channels) * torch.rand(in_channels, out_channels) 35 | else: 36 | bern = torch.bernoulli(torch.ones(in_channels, out_channels) * 0.5) 37 | W_s_init = bern * math.exp(-2) + (1 - bern) * math.exp(-14) 38 | print(W_s_init[:4,:4]) 39 | #self.W_s_ = nn.Parameter(torch.log(torch.exp(W_s_init) - 1)) 40 | self.W_s_ = nn.Parameter(self.positive_s_inv(W_s_init, 0)) 41 | #self.W_m = nn.Parameter(torch.zeros(in_channels, out_channels)) 42 | #self.W_s_ = nn.Parameter(torch.zeros(in_channels, out_channels)) 43 | 44 | self.bias_m = nn.Parameter(torch.zeros(out_channels)) 45 | #self.bias_s_ = nn.Parameter(torch.zeros(out_channels)) 46 | if init_type == 0: 47 | self.bias_s_ = nn.Parameter(torch.ones(out_channels) * (-10)) 48 | #self.bias_s_ = nn.Parameter(torch.ones(out_channels) * (-1)) 49 | else: 50 | bern = torch.bernoulli(torch.ones(out_channels) * 0.5) 51 | bias_s_init = bern * math.exp(-2) + (1 - bern) * math.exp(-14) 52 | self.bias_s_ = nn.Parameter(self.positive_s_inv(bias_s_init, 0)) 53 | 54 | def forward(self, x): 55 | if self.dual_input: 56 | x_m, x_s = x 57 | else: 58 | x_m = x 59 | #x_s = Variable(torch.zeros(x_m.size())) 60 | x_s = x.clone() 61 | x_s = 0 * x_s 62 | 63 | o_m = torch.mm(x_m, self.W_m) 64 | o_m = o_m + self.bias_m.expand_as(o_m) 65 | 66 | #W_s = torch.log(torch.exp(self.W_s_) + 1) 67 | #bias_s = torch.log(torch.exp(self.bias_s_) + 1) 68 | W_s = self.positive_s(self.W_s_, 0) 69 | bias_s = self.positive_s(self.bias_s_, 0) 70 | 71 | o_s = torch.mm(x_s, W_s) + torch.mm(x_s, self.W_m * self.W_m) + torch.mm(x_m * x_m, W_s) 72 | o_s 
= o_s + bias_s.expand_as(o_s) 73 | 74 | #print('bingo om os') 75 | #print(o_m.data) 76 | #print(o_s.data) 77 | 78 | return o_m, o_s 79 | 80 | 81 | class NPNLinearLite(nn.Module): 82 | def __init__(self, in_channels, out_channels, dual_input = True, init_type = 0): 83 | super(NPNLinearLite, self).__init__() 84 | self.in_channels = in_channels 85 | self.out_channels = out_channels 86 | self.dual_input = dual_input 87 | 88 | self.W_m = nn.Parameter(2 * math.sqrt(6) / math.sqrt(in_channels + out_channels) * (torch.rand(in_channels, out_channels) - 0.5)) 89 | self.bias_m = nn.Parameter(torch.zeros(out_channels)) 90 | 91 | def forward(self, x): 92 | if self.dual_input: 93 | x_m, x_s = x 94 | else: 95 | x_m = x 96 | x_s = x.clone() 97 | x_s = 0 * x_s 98 | 99 | o_m = torch.mm(x_m, self.W_m) 100 | o_m = o_m + self.bias_m.expand_as(o_m) 101 | 102 | o_s = torch.mm(x_s, self.W_m * self.W_m) 103 | 104 | return o_m, o_s 105 | 106 | class Linear2Branch(nn.Module): 107 | def __init__(self, in_channels, out_channels, dual_input = True, init_type = 0): 108 | super(Linear2Branch, self).__init__() 109 | self.in_channels = in_channels 110 | self.out_channels = out_channels 111 | self.dual_input = dual_input 112 | 113 | self.W_m = nn.Parameter(2 * math.sqrt(6) / math.sqrt(in_channels + out_channels) * (torch.rand(in_channels, out_channels) - 0.5)) 114 | self.W_s = nn.Parameter(2 * math.sqrt(6) / math.sqrt(in_channels + out_channels) * (torch.rand(in_channels, out_channels) - 0.5)) 115 | 116 | self.bias_m = nn.Parameter(torch.zeros(out_channels)) 117 | self.bias_s = nn.Parameter(torch.zeros(out_channels)) 118 | 119 | def forward(self, x): 120 | if self.dual_input: 121 | x_m, x_s = x 122 | else: 123 | x_m = x 124 | x_s = x 125 | 126 | o_m = torch.mm(x_m, self.W_m) 127 | o_m = o_m + self.bias_m.expand_as(o_m) 128 | 129 | o_s = torch.mm(x_s, self.W_s) 130 | o_s = o_s + self.bias_s.expand_as(o_s) 131 | o_s = o_s.exp() 132 | 133 | return o_m, o_s 134 | 135 | class NPNRelu(nn.Module): 136 | def __init__(self): 137 | super(NPNRelu, self).__init__() 138 | self.scale = math.sqrt(8/math.pi) # sqrt(8/pi) 139 | 140 | def forward(self, x): 141 | assert(len(x) == 2) 142 | o_m, o_s = x 143 | a_m = F.sigmoid(self.scale * o_m * (o_s ** (-0.5))) * o_m + torch.sqrt(o_s) / math.sqrt(2 * math.pi) * torch.exp(-o_m ** 2 / (2 * o_s)) 144 | a_s = F.sigmoid(self.scale * o_m * (o_s ** (-0.5))) * (o_m ** 2 + o_s) + o_m * torch.sqrt(o_s) / math.sqrt(2 * math.pi) * torch.exp(-o_m ** 2 / (2 * o_s)) - a_m ** 2 # mbr 145 | return a_m, a_s 146 | 147 | class NPNSigmoid(nn.Module): 148 | def __init__(self): 149 | super(NPNSigmoid, self).__init__() 150 | self.xi_sq = math.pi / 8 151 | self.alpha = 4 - 2 * math.sqrt(2) 152 | self.beta = - math.log(math.sqrt(2) + 1) 153 | 154 | def forward(self, x): 155 | assert(len(x) == 2) 156 | o_m, o_s = x 157 | a_m = F.sigmoid(o_m / (1 + self.xi_sq * o_s) ** 0.5) 158 | a_s = F.sigmoid(self.alpha * (o_m + self.beta) / (1 + self.xi_sq * self.alpha ** 2 * o_s) ** 0.5) - a_m ** 2 159 | return a_m, a_s 160 | 161 | #class NPNDropout(nn.Module): 162 | # def __init__(self, rate): 163 | # super(NPNDropout, self).__init__() 164 | # self.dropout = nn.Dropout(p = rate) 165 | # def forward(self, x): 166 | # assert(len(x) == 2) 167 | # if self.training: 168 | # self.dropout.train() 169 | # else: 170 | # self.dropout.eval() 171 | # x_m, x_s = x 172 | # y_m = self.dropout(x_m) 173 | # y_s = self.dropout(x_s) 174 | # return y_m, y_s 175 | 176 | class NPNDropout(nn.Module): 177 | def __init__(self, rate): 178 | super(NPNDropout, 
self).__init__() 179 | self.dropout = nn.Dropout2d(p = rate) 180 | def forward(self, x): 181 | assert(len(x) == 2) 182 | if self.training: 183 | self.dropout.train() 184 | else: 185 | self.dropout.eval() 186 | x_m, x_s = x 187 | x_m = x_m.unsqueeze(2) 188 | x_s = x_s.unsqueeze(2) 189 | x_com = torch.cat((x_m, x_s), dim = 2) 190 | x_com = x_com.unsqueeze(3) 191 | x_com = self.dropout(x_com) 192 | y_m = x_com[:,:,0,0] 193 | y_s = x_com[:,:,1,0] 194 | return y_m, y_s 195 | 196 | def NPNBCELoss(pred_m, pred_s, label): 197 | loss = -torch.sum((torch.log(pred_m + 1e-10) * label + torch.log(1 - pred_m + 1e-10) * (1 - label))/ (pred_s + 1e-10) - torch.log(pred_s+ 1e-10)) 198 | #loss = -torch.sum((torch.log(pred_m) * label + torch.log(1 - pred_m) * (1 - label))/ (pred_s + 1e-10) - torch.log(pred_s+ 1e-10)) / pred_m.size()[0] 199 | #loss = -torch.sum((torch.log(torch.max(pred_m, 1e-10)) * label + torch.log(torch.max(1 - pred_m, 1e-10)) * (1 - label))/ (torch.max(pred_s, 1e-10)) - torch.log(torch.max(pred_s, 1e-10))) 200 | return loss 201 | 202 | def KL_BG(pred_m, pred_s, label): 203 | #loss = 0.5 * torch.sum((1 - label) * (pred_m ** 2 / pred_s + torch.log(math.pi * 2 * pred_s)) + label * ((pred_m - 1) ** 2 / pred_s + torch.log(math.pi * 2 * pred_s))) / pred_m.size()[0] 204 | loss = 0.5 * torch.sum((1 - label) * (pred_m ** 2 / pred_s + torch.log(torch.clamp(math.pi * 2 * pred_s, min=1e-6))) + label * ((pred_m - 1) ** 2 / pred_s + torch.log(torch.clamp(math.pi * 2 * pred_s, min=1e-6)))) / pred_m.size()[0] # min = 1e-6 205 | #loss = 0.5 * torch.sum((1 - label) * (pred_m ** 2 / (pred_s) + torch.log(math.pi * 2 * pred_s)) + label * ((pred_s - 1) ** 2 / pred_s + torch.log(math.pi * 2 * pred_s))) / pred_m.size()[0] 206 | return loss 207 | 208 | def L2_loss(pred, label): 209 | loss = torch.sum((pred - label) ** 2) 210 | return loss 211 | 212 | def KL_loss(pred, label): 213 | assert(len(pred) == 2) 214 | pred_m, pred_s = pred 215 | #loss = 0.5 * torch.sum(((pred_m - label) ** 2) / (pred_s + 1e-10) + torch.log(pred_s)) # may need epsilon 216 | loss = 0.5 * torch.sum(((pred_m - label) ** 2) / (pred_s) + torch.log(pred_s)) # may need epsilon 217 | return loss 218 | 219 | def multi_logistic_loss(pred, label): 220 | assert(len(label.size()) == 1) 221 | print('bingo type\n', label.data.type()) 222 | print('bingo label\n', pred[:, label]) 223 | log_prob = torch.sum(torch.log(1 - pred)) + torch.sum(torch.log(pred[:, label.data]) - torch.log(1 - pred[:, label.data])) 224 | return -log_prob 225 | 226 | def RMSE(pred, label): 227 | loss = torch.mean(torch.sum((pred - label) ** 2, 1), 0) ** 0.5 228 | return loss 229 | -------------------------------------------------------------------------------- /main_mlp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from torch.utils.data.sampler import SubsetRandomSampler 3 | import numpy as np 4 | import random 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | from utils import plain_log 13 | from npn import NPNLinear 14 | from npn import Linear2Branch 15 | from npn import NPNLinearLite 16 | from npn import NPNSigmoid 17 | from npn import NPNRelu 18 | from npn import NPNDropout 19 | from npn import multi_logistic_loss 20 | from npn import NPNBCELoss 21 | from npn import KL_BG 22 | from npn import KL_loss 23 | from npn import RMSE 24 |
from datasets_boston_housing import Dataset_boston_housing 25 | from torch.utils import data 26 | 27 | # Training settings 28 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 29 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 30 | help='input batch size for training (default: 64)') 31 | parser.add_argument('--num_workers', type=int, default=2, 32 | help='number of workers') 33 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 34 | help='input batch size for testing (default: 1000)') 35 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 36 | help='number of epochs to train (default: 10)') 37 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 38 | help='learning rate (default: 0.01)') 39 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 40 | help='SGD momentum (default: 0.5)') 41 | parser.add_argument('--output_s', type=float, default=1.0, 42 | help='lambda of output_s') 43 | parser.add_argument('--dropout', type=float, default=0.0, 44 | help='dropout rate') 45 | parser.add_argument('--no-cuda', action='store_true', default=False, 46 | help='disables CUDA training') 47 | parser.add_argument('--evaluate', action='store_true', default=False, 48 | help='evaluate only') 49 | parser.add_argument('--seed', type=int, default=1, metavar='S', 50 | help='random seed (default: 1)') 51 | parser.add_argument('--log-interval', type=int, default=1000, metavar='N', 52 | help='how many batches to wait before logging training status') 53 | parser.add_argument('--checkpoint', type=str, default='none', 54 | help='file name of checkpoint model') 55 | parser.add_argument('--save_interval', type=int, default=100, metavar='N', 56 | help='how many epochs to wait before saving model') 57 | parser.add_argument('--log_file', type=str, default='tmp', 58 | help='log file name') 59 | parser.add_argument('--save_head', type=str, default='tmp', 60 | help='file name head for saving') 61 | parser.add_argument('--type', type=str, default='mlp', 62 | help='mlp/npn') 63 | parser.add_argument('--loss', type=str, default='default', 64 | help='default/npnbce/kl') 65 | parser.add_argument('--num_train', type=int, default=60000, 66 | help='num train') 67 | args = parser.parse_args() 68 | args.cuda = not args.no_cuda and torch.cuda.is_available() 69 | 70 | torch.manual_seed(args.seed) 71 | if args.cuda: 72 | torch.cuda.manual_seed(args.seed) 73 | random.seed(args.seed) 74 | torch.manual_seed(int(args.seed)) 75 | 76 | if args.type.startswith('regress_'): 77 | bh_train_dataset = Dataset_boston_housing('./boston_housing_nor_train.pkl') 78 | bh_val_dataset = Dataset_boston_housing('./boston_housing_nor_val.pkl') 79 | 80 | train_loader = data.DataLoader( 81 | dataset = bh_train_dataset, 82 | batch_size = args.batch_size, 83 | shuffle = False, 84 | num_workers = args.num_workers, 85 | pin_memory = False 86 | ) 87 | 88 | test_loader = data.DataLoader( 89 | dataset = bh_val_dataset, 90 | batch_size = args.test_batch_size, 91 | shuffle = False, 92 | num_workers = args.num_workers, 93 | pin_memory = False 94 | ) 95 | else: 96 | mnist_train = datasets.MNIST('../data', train=True, download=True, 97 | transform=transforms.Compose([ 98 | transforms.ToTensor(), 99 | transforms.Normalize((0.1307,), (0.3081,)) 100 | ])) 101 | size_train = len(mnist_train) 102 | indices = list(range(size_train)) 103 | np.random.shuffle(indices) 104 | num_train = args.num_train 105 | train_ind = indices[:num_train] 106 | 
train_sampler = SubsetRandomSampler(train_ind) 107 | 108 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 109 | train_loader = torch.utils.data.DataLoader( 110 | datasets.MNIST('../data', train=True, download=True, 111 | transform=transforms.Compose([ 112 | transforms.ToTensor() 113 | ])), 114 | batch_size=args.batch_size, sampler=train_sampler, **kwargs) 115 | #batch_size=args.batch_size, shuffle=False, **kwargs) 116 | test_loader = torch.utils.data.DataLoader( 117 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 118 | transforms.ToTensor() 119 | ])), 120 | batch_size=args.test_batch_size, shuffle=False, **kwargs) 121 | 122 | 123 | class Net(nn.Module): 124 | def __init__(self): 125 | super(Net, self).__init__() 126 | self.fc1 = nn.Linear(784, 800) 127 | self.fc2 = nn.Linear(800, 800) 128 | self.fc3 = nn.Linear(800, 10) 129 | self.dropout = args.dropout 130 | 131 | self.drop1 = nn.Dropout(self.dropout) 132 | self.drop2 = nn.Dropout(self.dropout) 133 | 134 | def forward(self, x): 135 | x = x.view(-1, 784) 136 | x = F.sigmoid(self.fc1(x)) 137 | x = self.drop1(x) 138 | x = F.sigmoid(self.fc2(x)) 139 | x = self.drop2(x) 140 | #x = F.relu(self.fc1(x)) 141 | #x = F.relu(self.fc2(x)) 142 | x = self.fc3(x) 143 | return F.log_softmax(x) 144 | ##x = torch.log(F.sigmoid(x)) 145 | #x = torch.log(F.softmax(F.sigmoid(x))) 146 | #return x 147 | 148 | class NPNNet(nn.Module): 149 | def __init__(self): 150 | super(NPNNet, self).__init__() 151 | self.dropout = args.dropout 152 | 153 | self.fc1 = NPNLinear(784, 800, False) 154 | self.sigmoid1 = NPNSigmoid() 155 | #self.sigmoid1 = NPNRelu() 156 | self.fc2 = NPNLinear(800, 800) 157 | self.sigmoid2 = NPNSigmoid() 158 | #self.sigmoid2 = NPNRelu() 159 | self.fc3 = NPNLinear(800, 10) 160 | self.sigmoid3 = NPNSigmoid() 161 | 162 | self.drop1 = NPNDropout(self.dropout) 163 | self.drop2 = NPNDropout(self.dropout) 164 | self.drop3 = NPNDropout(self.dropout) 165 | 166 | def forward(self, x): 167 | x = x.view(-1, 784) 168 | x = self.fc1(x) 169 | x = self.sigmoid1(x) 170 | x = self.drop1(x) 171 | x = self.fc2(x) 172 | x = self.sigmoid2(x) 173 | x = self.drop2(x) 174 | x, s = self.sigmoid3(self.fc3(x)) 175 | return x, s 176 | 177 | class NPNNetLite(nn.Module): 178 | def __init__(self): 179 | super(NPNNetLite, self).__init__() 180 | self.dropout = args.dropout 181 | 182 | self.fc1 = Linear2Branch(784, 800, False) 183 | self.sigmoid1 = NPNSigmoid() 184 | #self.sigmoid1 = NPNRelu() 185 | self.fc2 = NPNLinearLite(800, 800) 186 | self.sigmoid2 = NPNSigmoid() 187 | #self.sigmoid2 = NPNRelu() 188 | self.fc3 = NPNLinearLite(800, 10) 189 | self.sigmoid3 = NPNSigmoid() 190 | 191 | self.drop1 = NPNDropout(self.dropout) 192 | self.drop2 = NPNDropout(self.dropout) 193 | self.drop3 = NPNDropout(self.dropout) 194 | 195 | def forward(self, x): 196 | x = x.view(-1, 784) 197 | x = self.fc1(x) 198 | x = self.sigmoid1(x) 199 | x = self.drop1(x) 200 | x = self.fc2(x) 201 | x = self.sigmoid2(x) 202 | x = self.drop2(x) 203 | x, s = self.sigmoid3(self.fc3(x)) 204 | return x, s 205 | 206 | class CNN(nn.Module): 207 | def __init__(self): 208 | super(CNN, self).__init__() 209 | self.dropout = args.dropout 210 | 211 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 212 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 213 | self.conv2_drop = nn.Dropout2d() 214 | self.fc1 = nn.Linear(320, 50) 215 | self.fc2 = nn.Linear(50, 10) 216 | 217 | self.drop1 = nn.Dropout(self.dropout) 218 | 219 | def forward(self, x): 220 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 
221 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 222 | x = x.view(-1, 320) 223 | x = F.relu(self.fc1(x)) 224 | #x = F.dropout(x, training=self.training) 225 | x = self.drop1(x) 226 | x = self.fc2(x) 227 | return F.log_softmax(x) 228 | 229 | class NPNCNN(nn.Module): 230 | def __init__(self): 231 | super(NPNCNN, self).__init__() 232 | self.dropout = args.dropout 233 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 234 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 235 | self.conv2_drop = nn.Dropout2d() 236 | self.fc1 = NPNLinear(320, 50, dual_input=False) 237 | self.relu1 = NPNRelu() 238 | self.drop1 = NPNDropout(self.dropout) 239 | self.fc2 = NPNLinear(50, 10) 240 | self.sigmoid1 = NPNSigmoid() 241 | 242 | def forward(self, x): 243 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 244 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 245 | x = x.view(-1, 320) 246 | x = self.relu1(self.fc1(x)) 247 | x = self.drop1(x) 248 | x = self.fc2(x) 249 | if args.loss == 'nll': 250 | x, _ = x 251 | return F.log_softmax(x), x 252 | else: 253 | x, s = self.sigmoid1(x) 254 | return x, s 255 | 256 | class ReNPN(nn.Module): 257 | def __init__(self): 258 | super(ReNPN, self).__init__() 259 | self.dropout = args.dropout 260 | 261 | self.fc1 = NPNLinear(13, 50, False) 262 | self.relu1 = NPNRelu() 263 | self.fc2 = NPNLinear(50, 1) 264 | 265 | def forward(self, x): 266 | x = self.fc1(x) 267 | x = self.relu1(x) 268 | x, s = self.fc2(x) 269 | return x, s 270 | 271 | class ReMLP(nn.Module): 272 | def __init__(self): 273 | super(ReMLP, self).__init__() 274 | self.dropout = args.dropout 275 | 276 | self.fc1 = nn.Linear(13, 50) 277 | self.fc2 = nn.Linear(50, 1) 278 | 279 | def forward(self, x): 280 | x = self.fc1(x) 281 | x = F.relu(x) 282 | x = self.fc2(x) 283 | return x 284 | 285 | if args.type == 'mlp': 286 | model = Net() 287 | elif args.type == 'npn': 288 | model = NPNNet() 289 | elif args.type == 'npn_lite': 290 | model = NPNNetLite() 291 | elif args.type == 'cnn': 292 | model = CNN() 293 | elif args.type == 'npncnn': 294 | model = NPNCNN() 295 | elif args.type == 'regress_npn': 296 | model = ReNPN() 297 | elif args.type == 'regress_mlp': 298 | model = ReMLP() 299 | if args.cuda: 300 | model.cuda() 301 | 302 | #optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 303 | optimizer = optim.Adadelta(model.parameters(), lr = args.lr, eps = 1e-7) # lr default 0.02 304 | #optimizer = optim.Adam(model.parameters(), lr = args.lr) # lr 305 | ind = list(range(args.batch_size)) 306 | ind_test = list(range(1000)) 307 | bce = nn.BCELoss() 308 | mse = nn.MSELoss() 309 | 310 | def train(epoch): 311 | model.train() 312 | sum_loss = 0 313 | for batch_idx, (data, target) in enumerate(train_loader): 314 | # TODO: expand label here 315 | if not args.type.startswith('regress_'): 316 | target_ex = torch.zeros(target.size()[0], 10) 317 | target_ex[ind[:min(args.batch_size, target.size()[0])], target] = 1 318 | 319 | if args.type == 'npn' or args.type == 'npncnn' or args.type == 'npn_lite': 320 | target = target_ex 321 | if args.cuda: 322 | data, target = data.cuda(), target.cuda() 323 | data, target = Variable(data), Variable(target) 324 | optimizer.zero_grad() 325 | output = model(data) 326 | if args.type == 'mlp' or args.type == 'cnn': 327 | loss = F.nll_loss(output, target) 328 | else: 329 | if args.type != 'regress_mlp': 330 | x, s = output 331 | #loss = F.nll_loss(torch.log(x+1e-10), target) + args.output_s * torch.sum(s) 332 | #loss = multi_logistic_loss(x, target) + args.output_s * 
torch.sum(s) 333 | if args.loss == 'default': 334 | loss = bce(x, target) + args.output_s * torch.sum(s ** 2) 335 | elif args.loss == 'npnbce': 336 | loss = NPNBCELoss(x, s, target) + args.output_s * torch.sum(s ** 2) 337 | elif args.loss == 'kl': 338 | loss = KL_BG(x, s, target) + args.output_s * torch.sum(s ** 2) 339 | elif args.loss == 'nll': 340 | loss = F.nll_loss(x, target) 341 | elif args.loss == 'gaussian': 342 | loss = KL_loss(output, target) + 0.5 * args.output_s * torch.sum(s ** 2) 343 | elif args.loss == 'mse': 344 | loss = mse(output, target) 345 | # TODO: use BCELoss 346 | #sum_loss += loss.data[0] 347 | sum_loss += loss.item() 348 | loss.backward() 349 | optimizer.step() 350 | if batch_idx % args.log_interval == 0 and batch_idx != 0: 351 | log_txt = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.7f}'.format( 352 | epoch, batch_idx * len(data), len(train_loader.dataset), 353 | 100. * batch_idx / len(train_loader), loss.item()) 354 | #100. * batch_idx / len(train_loader), loss.data[0]) 355 | print(log_txt) 356 | plain_log(args.log_file,log_txt+'\n') 357 | avg_loss = sum_loss / len(train_loader.dataset) * args.batch_size 358 | log_txt = 'Train Epoch {}: Average Loss = {:.7f}'.format(epoch, avg_loss) 359 | print(log_txt) 360 | plain_log(args.log_file,log_txt+'\n') 361 | if epoch % args.save_interval == 0 and epoch != 0: 362 | torch.save(model, '%s.model' % args.save_head) 363 | 364 | def test(): 365 | model.eval() 366 | test_loss = 0 367 | rmse_loss = 0 368 | correct = 0 369 | for data, target in test_loader: 370 | if not args.type.startswith('regress_'): 371 | target_ex = torch.zeros(target.size()[0], 10) 372 | target_ex[ind_test[:min(1000, target.size()[0])], target] = 1 373 | 374 | if args.cuda: 375 | if not args.type.startswith('regress_'): 376 | data, target_ex, target = data.cuda(), target_ex.cuda(), target.cuda() 377 | else: 378 | data, target = data.cuda(), target.cuda() 379 | if not args.type.startswith('regress_'): 380 | data, target_ex = Variable(data, volatile=True), Variable(target_ex) 381 | else: 382 | data, target = Variable(data, volatile=True), Variable(target) 383 | output = model(data) 384 | if args.type == 'npn' or args.type == 'npncnn' or args.type == 'npn_lite' or args.type.startswith('regress_'): 385 | if args.type != 'regress_mlp': 386 | output, s = output 387 | #test_loss += F.nll_loss(torch.log(output+1e-10), target, size_average=False).data[0] # sum up batch loss 388 | if args.loss == 'default': 389 | test_loss += (bce(output, target_ex) + args.output_s * torch.sum(s ** 2)).item() 390 | elif args.loss == 'npnbce': 391 | test_loss += (NPNBCELoss(output, s, target_ex) + args.output_s * torch.sum(s ** 2)).item() 392 | elif args.loss == 'kl': 393 | test_loss += (KL_BG(output, s, target_ex) + args.output_s * torch.sum(s ** 2)).item() 394 | elif args.loss == 'gaussian': 395 | test_loss += KL_loss((output, s), target).item() 396 | rmse_loss += RMSE(output, target).item() 397 | elif args.loss == 'mse': 398 | test_loss += mse(output, target).item() 399 | rmse_loss += RMSE(output, target).item() 400 | else: 401 | test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss 402 | if not args.type.startswith('regress_'): 403 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 404 | correct += pred.eq(target.view_as(pred)).cpu().sum() 405 | 406 | test_loss /= len(test_loader.dataset) 407 | if not args.type.startswith('regress_'): 408 | log_txt = 'Test set: Average loss: {:.7f}, Accuracy: {}/{} 
({:.0f}%)'.format( 409 | test_loss, correct, len(test_loader.dataset), 410 | 100. * correct / len(test_loader.dataset)) 411 | else: 412 | log_txt = 'Test set: Average loss: {:.7f}, RMSE: {:.6f}'.format(test_loss, rmse_loss) 413 | print(log_txt) 414 | plain_log(args.log_file,log_txt+'\n') 415 | 416 | if args.checkpoint != 'none': 417 | model = torch.load(args.checkpoint) 418 | print(str(model)) 419 | for key, module in model._modules.items(): 420 | print('key', key) 421 | print('module', module) 422 | if module.__class__.__name__ == 'NPNLinear': 423 | print('para\n', torch.log(torch.exp(module.W_s_[:8,:8])+1)) 424 | 425 | if not args.evaluate: 426 | for epoch in range(1, args.epochs + 1): 427 | train(epoch) 428 | if epoch % 1 == 0: 429 | test() 430 | test() 431 | --------------------------------------------------------------------------------