├── .gitignore
├── README.md
├── functions.py
├── generate.py
├── mnist_utils.py
├── network.py
├── plot.py
├── q_network.py
├── q_network_v2.py
├── q_network_v3.py
├── q_script.py
├── q_script_v2.py
├── q_script_v3.py
├── q_script_v4.py
└── script.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | .env
5 | .venv
6 | .ipynb_checkpoints
7 | MNIST_train/
8 | MNIST_test/
9 | imgs/
10 | data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Predictive Coding in Python
2 |
3 | ## logo
4 |
5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6 |
7 | A `Python` implementation of _An Approximation of the Error Backpropagation Algorithm in a Predictive Coding Network with Local Hebbian Synaptic Plasticity_
8 |
9 | [[Paper](https://www.mrcbndu.ox.ac.uk/sites/default/files/pdf_files/Whittington%20Bogacz%202017_Neural%20Comput.pdf)]
10 |
11 | Based on the `MATLAB` [implementation](https://github.com/djcrw/Supervised-Predictive-Coding) from [`@djcrw`](https://github.com/djcrw)
12 |
13 | ## Requirements
14 | - `numpy`
15 | - `torch`
16 | - `torchvision`
17 | - `matplotlib`
18 |
19 | ## Tasks
20 | - Include model from _A tutorial on the free-energy framework for modelling perception and learning_
21 | - Add additional optimisers
22 | - Measure number of iterations
23 | - The initial space of mu needs to be sufficiently large - ensembles of amortised weights or slow learning rate?
24 | - Test pure PC accuracy 25 | - Errors go down, but amortised asymptotes 26 | - Infinite iterations - amortised learning provides mechanism for setting number, remove free parameter (replaces with threshold) 27 | - Generative overtakes and this is inconsistent with discriminative - need some way to promote discrimination in network 28 | - Both trying to predict each other? Generative predicting discrimantive? 29 | - Add gradient clamping -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import torch 5 | import numpy as np 6 | 7 | LINEAR = "LINEAR" 8 | TANH = "TANH" 9 | LOGSIG = "LOGSIG" 10 | 11 | 12 | def f(x, act_fn): 13 | """ (activation_size, batch_size) """ 14 | if act_fn is LINEAR: 15 | m = x 16 | elif act_fn is TANH: 17 | m = torch.tanh(x) 18 | elif act_fn is LOGSIG: 19 | return 1. / (torch.ones_like(x) + torch.exp(-x)) 20 | else: 21 | raise ValueError(f"{act_fn} not supported") 22 | return m 23 | 24 | 25 | def f_deriv(x, act_fn): 26 | """ (activation_size, batch_size) """ 27 | if act_fn is LINEAR: 28 | deriv = np.ones(x.shape) 29 | elif act_fn is TANH: 30 | deriv = torch.ones_like(x) - torch.tanh(x) ** 2 31 | elif act_fn is LOGSIG: 32 | """ TODO """ 33 | f = 1. 
/ (torch.ones_like(x) + torch.exp(-x)) 34 | deriv = torch.mul(f, (torch.ones_like(x) - f)) 35 | else: 36 | raise ValueError(f"{act_fn} not supported") 37 | return deriv 38 | 39 | 40 | def f_inv(x, act_fn): 41 | """ (activation_size, batch_size) """ 42 | if act_fn is LINEAR: 43 | m = x 44 | elif act_fn is TANH: 45 | num = np.ones(x.shape) + x 46 | div = (np.ones(x.shape) - x) + 1e-7 47 | m = 0.5 * np.log(np.divide(num, div)) 48 | elif act_fn is LOGSIG: 49 | """ TODO """ 50 | div = (np.ones(x.shape) - x) + 1e-7 51 | m = np.log(np.divide(x, div) + 1e-7) 52 | else: 53 | raise ValueError(f"{act_fn} not supported") 54 | return m 55 | 56 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | from network import PredictiveCodingNetwork 10 | 11 | 12 | class AttrDict(dict): 13 | __setattr__ = dict.__setitem__ 14 | __getattr__ = dict.__getitem__ 15 | 16 | 17 | def main(cf): 18 | print(f"device [{cf.device}]") 19 | print("loading MNIST data...") 20 | train_set = mnist_utils.get_mnist_train_set() 21 | test_set = mnist_utils.get_mnist_test_set() 22 | 23 | img_train = mnist_utils.get_imgs(train_set) 24 | img_test = mnist_utils.get_imgs(test_set) 25 | label_train = mnist_utils.get_labels(train_set) 26 | label_test = mnist_utils.get_labels(test_set) 27 | 28 | if cf.data_size is not None: 29 | test_size = cf.data_size // 5 30 | img_train = img_train[:, 0 : cf.data_size] 31 | label_train = label_train[:, 0 : cf.data_size] 32 | img_test = img_test[:, 0:test_size] 33 | label_test = label_test[:, 0:test_size] 34 | 35 | msg = "img_train {} img_test {} label_train {} label_test {}" 36 | print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape)) 37 | 38 | print("performing 
preprocessing...") 39 | if cf.apply_scaling: 40 | img_train = mnist_utils.scale_imgs(img_train, cf.img_scale) 41 | img_test = mnist_utils.scale_imgs(img_test, cf.img_scale) 42 | label_train = mnist_utils.scale_labels(label_train, cf.label_scale) 43 | label_test = mnist_utils.scale_labels(label_test, cf.label_scale) 44 | 45 | if cf.apply_inv: 46 | img_train = F.f_inv(img_train, cf.act_fn) 47 | img_test = F.f_inv(img_test, cf.act_fn) 48 | 49 | model = PredictiveCodingNetwork(cf) 50 | 51 | with torch.no_grad(): 52 | for epoch in range(cf.n_epochs): 53 | print(f"\nepoch {epoch}") 54 | 55 | img_batches, label_batches = mnist_utils.get_batches(img_train, label_train, cf.batch_size) 56 | print(f"training on {len(img_batches)} batches of size {cf.batch_size}") 57 | model.train_epoch(label_batches, img_batches, epoch_num=epoch) 58 | 59 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 60 | print("generating images...") 61 | pred_imgs = model.generate_data(label_batches[0]) 62 | mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch)) 63 | 64 | perm = np.random.permutation(img_train.shape[1]) 65 | img_train = img_train[:, perm] 66 | label_train = label_train[:, perm] 67 | 68 | 69 | if __name__ == "__main__": 70 | cf = AttrDict() 71 | 72 | cf.img_path = "imgs/{}.png" 73 | 74 | cf.n_epochs = 100 75 | cf.data_size = None 76 | cf.batch_size = 128 77 | 78 | cf.apply_inv = True 79 | cf.apply_scaling = True 80 | cf.label_scale = 0.94 81 | cf.img_scale = 1.0 82 | 83 | cf.neurons = [10, 500, 500, 784] 84 | cf.n_layers = len(cf.neurons) 85 | cf.act_fn = F.TANH 86 | cf.var_out = 1 87 | cf.vars = torch.ones(cf.n_layers) 88 | 89 | cf.itr_max = 50 90 | cf.beta = 0.1 91 | cf.div = 2 92 | cf.condition = 1e-6 93 | cf.d_rate = 0 94 | 95 | # optim parameters 96 | cf.l_rate = 1e-3 97 | cf.optim = "ADAM" 98 | cf.eps = 1e-8 99 | cf.decay_r = 0.9 100 | cf.beta_1 = 0.9 101 | cf.beta_2 = 0.999 102 | 103 | cf.device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") 104 | main(cf) 105 | 106 | -------------------------------------------------------------------------------- /mnist_utils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import torchvision 8 | 9 | 10 | def get_mnist_train_set(): 11 | return torchvision.datasets.MNIST("MNIST_train", download=True, train=True) 12 | 13 | 14 | def get_mnist_test_set(): 15 | return torchvision.datasets.MNIST("MNIST_test", download=True, train=False) 16 | 17 | 18 | def onehot(label, n_classes=10): 19 | arr = np.zeros([10]) 20 | arr[int(label)] = 1.0 21 | return arr 22 | 23 | 24 | def img_to_np(img): 25 | return np.array(img).reshape([784]) / 255.0 26 | 27 | 28 | def get_imgs(dataset): 29 | imgs = np.array([img_to_np(dataset[i][0]) for i in range(len(dataset))]) 30 | return np.swapaxes(imgs, 0, 1) 31 | 32 | 33 | def get_labels(dataset): 34 | labels = np.array([onehot(dataset[i][1]) for i in range(len(dataset))]) 35 | return np.swapaxes(labels, 0, 1) 36 | 37 | 38 | def scale_imgs(imgs, scale_factor): 39 | return imgs * scale_factor + 0.5 * (1 - scale_factor) * np.ones(imgs.shape) 40 | 41 | 42 | def scale_labels(labels, scale_factor): 43 | return labels * scale_factor + 0.5 * (1 - scale_factor) * np.ones(labels.shape) 44 | 45 | 46 | def mnist_accuracy(pred_labels, labels): 47 | correct = 0 48 | batch_size = pred_labels.size(1) 49 | for b in range(batch_size): 50 | if torch.argmax(pred_labels[:, b]) == torch.argmax(labels[:, b]): 51 | correct += 1 52 | return correct / batch_size 53 | 54 | 55 | def get_batches(imgs, labels, batch_size): 56 | n_data = imgs.shape[1] 57 | n_batches = int(np.ceil(n_data / batch_size)) 58 | 59 | img_batches = [[] for _ in range(n_batches)] 60 | label_batches = [[] for _ in range(n_batches)] 61 | 62 | for batch in range(n_batches): 63 | if 
batch == n_batches - 1: 64 | start = batch * batch_size 65 | img_batches[batch] = imgs[:, start:] 66 | label_batches[batch] = labels[:, start:] 67 | else: 68 | start = batch * batch_size 69 | end = (batch + 1) * batch_size 70 | img_batches[batch] = imgs[:, start:end] 71 | label_batches[batch] = labels[:, start:end] 72 | 73 | return img_batches, label_batches 74 | 75 | 76 | def plot_imgs(img_batch, save_path): 77 | img_batch = img_batch.detach().cpu().numpy() 78 | batch_size = img_batch.shape[1] 79 | dim = nearest_square(batch_size) 80 | 81 | imgs = [np.reshape(img_batch[:, i], [28, 28]) for i in range(dim ** 2)] 82 | _, axes = plt.subplots(dim, dim) 83 | axes = axes.flatten() 84 | for i, img in enumerate(imgs): 85 | axes[i].imshow(img) 86 | axes[i].set_axis_off() 87 | plt.savefig(save_path) 88 | plt.close('all') 89 | 90 | 91 | def nearest_square(limit): 92 | answer = 0 93 | while (answer + 1) ** 2 < limit: 94 | answer += 1 95 | return answer 96 | 97 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | 10 | 11 | def set_tensor(arr, device): 12 | return torch.from_numpy(arr).float().to(device) 13 | 14 | 15 | class PredictiveCodingNetwork(object): 16 | def __init__(self, cf): 17 | self.device = cf.device 18 | self.n_layers = cf.n_layers 19 | self.act_fn = cf.act_fn 20 | self.neurons = cf.neurons 21 | self.vars = cf.vars.float().to(self.device) 22 | self.itr_max = cf.itr_max 23 | self.batch_size = cf.batch_size 24 | 25 | self.beta_1 = cf.beta_1 26 | self.beta_2 = cf.beta_2 27 | self.beta = cf.beta 28 | self.div = cf.div 29 | self.d_rate = cf.d_rate 30 | self.l_rate = cf.l_rate 31 | self.condition = cf.condition / (sum(cf.neurons) - cf.neurons[0]) 32 | 33 | self.optim = cf.optim 34 | 
self.eps = cf.eps 35 | self.decay_r = cf.decay_r 36 | self.c_b = [[] for _ in range(self.n_layers)] 37 | self.c_w = [[] for _ in range(self.n_layers)] 38 | self.v_b = [[] for _ in range(self.n_layers)] 39 | self.v_w = [[] for _ in range(self.n_layers)] 40 | 41 | self.W = None 42 | self.b = None 43 | self._init_params() 44 | 45 | def train_epoch(self, x_batches, y_batches, epoch_num=None): 46 | n_batches = len(x_batches) 47 | for batch_id, (x_batch, y_batch) in enumerate(zip(x_batches, y_batches)): 48 | 49 | if batch_id % 500 == 0 and batch_id > 0: 50 | print(f"batch {batch_id}") 51 | 52 | x_batch = set_tensor(x_batch, self.device) 53 | y_batch = set_tensor(y_batch, self.device) 54 | batch_size = x_batch.size(1) 55 | 56 | x = [[] for _ in range(self.n_layers)] 57 | x[0] = x_batch 58 | for l in range(1, self.n_layers): 59 | b = self.b[l - 1].repeat(1, batch_size) 60 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 61 | x[self.n_layers - 1] = y_batch 62 | 63 | x, errors, _ = self.infer(x, batch_size) 64 | self.update_params( 65 | x, errors, batch_size, epoch_num=epoch_num, n_batches=n_batches, curr_batch=batch_id 66 | ) 67 | 68 | def test_epoch(self, x_batches, y_batches): 69 | accs = [] 70 | for x_batch, y_batch in zip(x_batches, y_batches): 71 | x_batch = set_tensor(x_batch, self.device) 72 | y_batch = set_tensor(y_batch, self.device) 73 | batch_size = x_batch.size(1) 74 | 75 | x = [[] for _ in range(self.n_layers)] 76 | x[0] = x_batch 77 | for l in range(1, self.n_layers): 78 | b = self.b[l - 1].repeat(1, batch_size) 79 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 80 | pred_y = x[-1] 81 | 82 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 83 | accs.append(acc) 84 | return accs 85 | 86 | def generate_data(self, x_batch): 87 | x_batch = set_tensor(x_batch, self.device) 88 | batch_size = x_batch.size(1) 89 | 90 | x = [[] for _ in range(self.n_layers)] 91 | x[0] = x_batch 92 | for l in range(1, self.n_layers): 93 | b = self.b[l - 1].repeat(1, 
batch_size) 94 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 95 | pred_y = x[-1] 96 | return pred_y 97 | 98 | def infer(self, x, batch_size): 99 | errors = [[] for _ in range(self.n_layers)] 100 | f_x_arr = [[] for _ in range(self.n_layers)] 101 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 102 | f_0 = 0 103 | its = 0 104 | beta = self.beta 105 | 106 | for l in range(1, self.n_layers): 107 | f_x = F.f(x[l - 1], self.act_fn) 108 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 109 | f_x_arr[l - 1] = f_x 110 | f_x_deriv_arr[l - 1] = f_x_deriv 111 | 112 | # eq. 2.17 113 | b = self.b[l - 1].repeat(1, batch_size) 114 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 115 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 116 | 117 | for itr in range(self.itr_max): 118 | # update node activity 119 | for l in range(1, self.n_layers - 1): 120 | # eq. 2.18 121 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 122 | x[l] = x[l] + beta * (-errors[l] + g) 123 | 124 | # update errors 125 | f = 0 126 | for l in range(1, self.n_layers): 127 | f_x = F.f(x[l - 1], self.act_fn) 128 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 129 | f_x_arr[l - 1] = f_x 130 | f_x_deriv_arr[l - 1] = f_x_deriv 131 | 132 | # eq. 
2.17 133 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 134 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 135 | 136 | diff = f - f_0 137 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 138 | if torch.any(diff < 0): 139 | beta = beta / self.div 140 | elif torch.mean(diff) < threshold: 141 | break 142 | 143 | f_0 = f 144 | its = itr 145 | 146 | return x, errors, its 147 | 148 | def update_params(self, x, errors, batch_size, epoch_num=None, n_batches=None, curr_batch=None): 149 | grad_w = [[] for _ in range(self.n_layers - 1)] 150 | grad_b = [[] for _ in range(self.n_layers - 1)] 151 | 152 | for l in range(self.n_layers - 1): 153 | # eq. 2.19 (with weight decay) 154 | grad_w[l] = ( 155 | self.vars[-1] * (1 / batch_size) * errors[l + 1] @ F.f(x[l], self.act_fn).T 156 | - self.d_rate * self.W[l] 157 | ) 158 | grad_b[l] = self.vars[-1] * (1 / batch_size) * torch.sum(errors[l + 1], axis=1) 159 | 160 | self._apply_gradients(grad_w, grad_b, epoch_num=epoch_num, n_batches=n_batches, curr_batch=curr_batch) 161 | 162 | def _init_params(self): 163 | weights = [[] for _ in range(self.n_layers)] 164 | bias = [[] for _ in range(self.n_layers)] 165 | 166 | for l in range(self.n_layers - 1): 167 | norm_b = 0 168 | if self.act_fn is F.LINEAR: 169 | norm_w = np.sqrt(1 / (self.neurons[l + 1] + self.neurons[l])) 170 | elif self.act_fn is F.TANH: 171 | norm_w = np.sqrt(6 / (self.neurons[l + 1] + self.neurons[l])) 172 | elif self.act_fn is F.LOGSIG: 173 | norm_w = 4 * np.sqrt(6 / (self.neurons[l + 1] + self.neurons[l])) 174 | else: 175 | raise ValueError(f"{self.act_fn} not supported") 176 | 177 | layer_w = np.random.uniform(-1, 1, size=(self.neurons[l + 1], self.neurons[l])) * norm_w 178 | layer_b = np.zeros((self.neurons[l + 1], 1)) + norm_b * np.ones((self.neurons[l + 1], 1)) 179 | weights[l] = set_tensor(layer_w, self.device) 180 | bias[l] = set_tensor(layer_b, self.device) 181 | 182 | self.W = weights 183 
| self.b = bias 184 | 185 | for l in range(self.n_layers - 1): 186 | self.c_b[l] = torch.zeros_like(self.b[l]) 187 | self.c_w[l] = torch.zeros_like(self.W[l]) 188 | self.v_b[l] = torch.zeros_like(self.b[l]) 189 | self.v_w[l] = torch.zeros_like(self.W[l]) 190 | 191 | def _apply_gradients(self, grad_w, grad_b, epoch_num=None, n_batches=None, curr_batch=None): 192 | 193 | if self.optim is "RMSPROP": 194 | for l in range(self.n_layers - 1): 195 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 196 | self.c_w[l] = self.decay_r * self.c_w[l] + (1 - self.decay_r) * grad_w[l] ** 2 197 | self.c_b[l] = self.decay_r * self.c_b[l] + (1 - self.decay_r) * grad_b[l] ** 2 198 | 199 | self.W[l] = self.W[l] + self.l_rate * (grad_w[l] / (torch.sqrt(self.c_w[l]) + self.eps)) 200 | self.b[l] = self.b[l] + self.l_rate * (grad_b[l] / (torch.sqrt(self.c_b[l]) + self.eps)) 201 | 202 | elif self.optim is "ADAM": 203 | for l in range(self.n_layers - 1): 204 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 205 | self.c_b[l] = self.beta_1 * self.c_b[l] + (1 - self.beta_1) * grad_b[l] 206 | self.c_w[l] = self.beta_1 * self.c_w[l] + (1 - self.beta_1) * grad_w[l] 207 | 208 | self.v_b[l] = self.beta_2 * self.v_b[l] + (1 - self.beta_2) * grad_b[l] ** 2 209 | self.v_w[l] = self.beta_2 * self.v_w[l] + (1 - self.beta_2) * grad_w[l] ** 2 210 | 211 | t = (epoch_num) * n_batches + curr_batch 212 | self.W[l] = self.W[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w[l] / ( 213 | torch.sqrt(self.v_w[l]) + self.eps 214 | ) 215 | self.b[l] = self.b[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b[l] / ( 216 | torch.sqrt(self.v_b[l]) + self.eps 217 | ) 218 | 219 | elif self.optim is "SGD" or self.optim is None: 220 | for l in range(self.n_layers - 1): 221 | self.W[l] = self.W[l] + self.l_rate * grad_w[l] 222 | self.b[l] = self.b[l] + self.l_rate * grad_b[l].unsqueeze(dim=1) 223 | 224 | else: 225 | raise ValueError(f"{self.optim} not supported") 226 | 227 | 
-------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | matplotlib.rcParams["axes.linewidth"] = 1.6 6 | 7 | colors = ["#9403fc", "#fcba03", "#db2727"] 8 | 9 | h_accs_path = "data/h_accs_4.npy" 10 | q_accs_path = "data/q_accs_4.npy" 11 | pc_accs_path = "data/pc_accs_4.npy" 12 | fig_path = "imgs/accuracy_hybrid.png" 13 | n_steps = 50 14 | 15 | 16 | if __name__ == "__main__": 17 | h_accs = np.load(h_accs_path)[0:n_steps] 18 | q_accs = np.load(q_accs_path)[0:n_steps] 19 | pc_accs = np.load(pc_accs_path)[0:n_steps] 20 | 21 | plt.plot(h_accs, label="Hybrid", color=colors[0], lw=3) 22 | plt.plot(q_accs, label="Amortised", color=colors[1], lw=3) 23 | plt.plot(pc_accs, label="Predictive Coding", color=colors[2], lw=3) 24 | plt.xlabel("Number of epochs") 25 | plt.ylabel("MNIST accuracy (%)") 26 | plt.title("MNIST learning") 27 | plt.legend() 28 | plt.savefig(fig_path) 29 | plt.close() 30 | 31 | -------------------------------------------------------------------------------- /q_network.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | 10 | 11 | def set_tensor(arr, device): 12 | return torch.from_numpy(arr).float().to(device) 13 | 14 | 15 | class QCodingNetwork(object): 16 | def __init__(self, cf): 17 | self.device = cf.device 18 | self.amortised = cf.amortised 19 | self.n_layers = cf.n_layers 20 | self.act_fn = cf.act_fn 21 | self.neurons = cf.neurons 22 | self.vars = cf.vars.float().to(self.device) 23 | self.itr_max = cf.itr_max 24 | self.batch_size = cf.batch_size 25 | 26 | self.beta_1 = cf.beta_1 27 | self.beta_2 = cf.beta_2 28 | self.beta = cf.beta 29 | self.div = cf.div 
30 | self.d_rate = cf.d_rate 31 | self.l_rate = cf.l_rate 32 | self.q_l_rate = cf.q_l_rate 33 | self.condition = cf.condition / (sum(cf.neurons) - cf.neurons[0]) 34 | 35 | self.optim = cf.optim 36 | self.eps = cf.eps 37 | self.decay_r = cf.decay_r 38 | self.c_b = [[] for _ in range(self.n_layers)] 39 | self.c_w = [[] for _ in range(self.n_layers)] 40 | self.v_b = [[] for _ in range(self.n_layers)] 41 | self.v_w = [[] for _ in range(self.n_layers)] 42 | self.c_b_q = [[] for _ in range(self.n_layers)] 43 | self.c_w_q = [[] for _ in range(self.n_layers)] 44 | self.v_b_q = [[] for _ in range(self.n_layers)] 45 | self.v_w_q = [[] for _ in range(self.n_layers)] 46 | 47 | self.W = None 48 | self.b = None 49 | self.Wq = None 50 | self.bq = None 51 | self._init_params() 52 | 53 | def train_epoch(self, x_batches, y_batches, epoch_num=None): 54 | """ x_batch are images, y_batch are labels 55 | TODO 0 is highest and lowest layer, fix this 56 | """ 57 | init_err = 0 58 | end_err = 0 59 | avg_itr = 0 60 | n_batches = len(x_batches) 61 | 62 | for batch_id, (x_batch, y_batch) in enumerate(zip(x_batches, y_batches)): 63 | if batch_id % 500 == 0 and batch_id > 0: 64 | print(f"batch {batch_id}") 65 | 66 | x_batch = set_tensor(x_batch, self.device) 67 | y_batch = set_tensor(y_batch, self.device) 68 | batch_size = x_batch.size(1) 69 | 70 | x = [[] for _ in range(self.n_layers)] 71 | q = [[] for _ in range(self.n_layers)] 72 | 73 | if self.amortised is True: 74 | q[0] = x_batch 75 | for l in range(1, self.n_layers): 76 | b_q = self.b_q[l - 1].repeat(1, batch_size) 77 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 78 | 79 | x = q[::-1] 80 | x[0] = y_batch 81 | x[self.n_layers - 1] = x_batch 82 | 83 | else: 84 | x[0] = y_batch 85 | for l in range(1, self.n_layers): 86 | b = self.b[l - 1].repeat(1, batch_size) 87 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 88 | x[self.n_layers - 1] = x_batch 89 | 90 | init_err += self.get_errors(x, batch_size) 91 | x, errors, its = 
self.infer(x, batch_size) 92 | self.update_params( 93 | x, 94 | q, 95 | errors, 96 | batch_size, 97 | x_batch, 98 | y_batch, 99 | epoch_num=epoch_num, 100 | n_batches=n_batches, 101 | curr_batch=batch_id, 102 | ) 103 | end_err += self.get_errors(x, batch_size) 104 | avg_itr += its 105 | 106 | return end_err / n_batches, init_err / n_batches, avg_itr / n_batches 107 | 108 | def test_epoch(self, x_batches, y_batches, itr_max=None): 109 | accs = [] 110 | n_batches = len(x_batches) 111 | avg_itr = 0 112 | for x_batch, y_batch in zip(x_batches, y_batches): 113 | x_batch = set_tensor(x_batch, self.device) 114 | y_batch = set_tensor(y_batch, self.device) 115 | batch_size = x_batch.size(1) 116 | 117 | x = [[] for _ in range(self.n_layers)] 118 | q = [[] for _ in range(self.n_layers)] 119 | 120 | if self.amortised is True: 121 | q[0] = x_batch 122 | for l in range(1, self.n_layers): 123 | b_q = self.b_q[l - 1].repeat(1, batch_size) 124 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 125 | x = q[::-1] 126 | x[self.n_layers - 1] = x_batch 127 | else: 128 | x[0] = torch.empty_like(y_batch).normal_(mean=0.0, std=0.1) 129 | for l in range(1, self.n_layers): 130 | b = self.b[l - 1].repeat(1, batch_size) 131 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 132 | x[self.n_layers - 1] = x_batch 133 | 134 | x, errors, its = self.infer_v2(x, batch_size, x_batch, itr_max=itr_max) 135 | pred_y = x[0] 136 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 137 | accs.append(acc) 138 | avg_itr += its 139 | return accs, avg_itr / n_batches 140 | 141 | def test_pc_epoch(self, x_batches, y_batches, itr_max=None): 142 | accs = [] 143 | n_batches = len(x_batches) 144 | avg_itr = 0 145 | for x_batch, y_batch in zip(x_batches, y_batches): 146 | x_batch = set_tensor(x_batch, self.device) 147 | y_batch = set_tensor(y_batch, self.device) 148 | batch_size = x_batch.size(1) 149 | 150 | x = [[] for _ in range(self.n_layers)] 151 | q = [[] for _ in range(self.n_layers)] 152 | 153 | 
x[0] = torch.empty_like(y_batch).normal_(mean=0.0, std=0.1) 154 | for l in range(1, self.n_layers): 155 | b = self.b[l - 1].repeat(1, batch_size) 156 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 157 | x[self.n_layers - 1] = x_batch 158 | 159 | x, errors, its = self.infer_v2(x, batch_size, x_batch, itr_max=itr_max) 160 | pred_y = x[0] 161 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 162 | accs.append(acc) 163 | avg_itr += its 164 | return accs, avg_itr / n_batches 165 | 166 | def test_amortised_epoch(self, x_batches, y_batches): 167 | accs = [] 168 | for x_batch, y_batch in zip(x_batches, y_batches): 169 | x_batch = set_tensor(x_batch, self.device) 170 | y_batch = set_tensor(y_batch, self.device) 171 | batch_size = x_batch.size(1) 172 | 173 | q = [[] for _ in range(self.n_layers)] 174 | q[0] = x_batch 175 | for l in range(1, self.n_layers): 176 | b_q = self.b_q[l - 1].repeat(1, batch_size) 177 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 178 | pred_y = q[-1] 179 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 180 | accs.append(acc) 181 | return accs 182 | 183 | def generate_data(self, x_batch): 184 | x_batch = set_tensor(x_batch, self.device) 185 | batch_size = x_batch.size(1) 186 | 187 | x = [[] for _ in range(self.n_layers)] 188 | x[0] = x_batch 189 | for l in range(1, self.n_layers): 190 | b = self.b[l - 1].repeat(1, batch_size) 191 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 192 | pred_y = x[-1] 193 | return pred_y 194 | 195 | def infer(self, x, batch_size, itr_max=None): 196 | itr_max = self.itr_max if itr_max is None else itr_max 197 | errors = [[] for _ in range(self.n_layers)] 198 | f_x_arr = [[] for _ in range(self.n_layers)] 199 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 200 | f_0 = 0 201 | its = 0 202 | beta = self.beta 203 | 204 | for l in range(1, self.n_layers): 205 | f_x = F.f(x[l - 1], self.act_fn) 206 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 207 | f_x_arr[l - 1] = f_x 208 | 
f_x_deriv_arr[l - 1] = f_x_deriv 209 | 210 | # eq. 2.17 211 | b = self.b[l - 1].repeat(1, batch_size) 212 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 213 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 214 | 215 | for itr in range(itr_max): 216 | # update node activity 217 | for l in range(1, self.n_layers - 1): 218 | # eq. 2.18 219 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 220 | x[l] = x[l] + beta * (-errors[l] + g) 221 | 222 | # update errors 223 | f = 0 224 | for l in range(1, self.n_layers): 225 | f_x = F.f(x[l - 1], self.act_fn) 226 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 227 | f_x_arr[l - 1] = f_x 228 | f_x_deriv_arr[l - 1] = f_x_deriv 229 | 230 | # eq. 2.17 231 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 232 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 233 | 234 | diff = f - f_0 235 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 236 | if torch.any(diff < 0): 237 | beta = beta / self.div 238 | elif torch.mean(diff) < threshold: 239 | # print(f"broke @ {its} its") 240 | break 241 | 242 | f_0 = f 243 | its = itr 244 | 245 | return x, errors, its 246 | 247 | def infer_v2(self, x, batch_size, x_batch, itr_max=None): 248 | """ this version infers top layer, rather than keeping it fixed """ 249 | itr_max = self.itr_max if itr_max is None else itr_max 250 | errors = [[] for _ in range(self.n_layers)] 251 | f_x_arr = [[] for _ in range(self.n_layers)] 252 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 253 | f_0 = 0 254 | its = 0 255 | beta = self.beta 256 | 257 | x[self.n_layers - 1] = x_batch 258 | 259 | for l in range(1, self.n_layers): 260 | f_x = F.f(x[l - 1], self.act_fn) 261 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 262 | f_x_arr[l - 1] = f_x 263 | f_x_deriv_arr[l - 1] = f_x_deriv 264 | 265 | # eq. 
2.17 266 | b = self.b[l - 1].repeat(1, batch_size) 267 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 268 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 269 | 270 | for itr in range(itr_max): 271 | # TODO (updating top layer) 272 | g = torch.mul(self.W[0].T @ errors[1], f_x_deriv_arr[0]) 273 | x[0] = x[0] + beta * g 274 | 275 | # update node activity 276 | for l in range(1, self.n_layers - 1): 277 | # eq. 2.18 278 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 279 | x[l] = x[l] + beta * (-errors[l] + g) 280 | 281 | # update errors 282 | f = 0 283 | for l in range(1, self.n_layers): 284 | f_x = F.f(x[l - 1], self.act_fn) 285 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 286 | f_x_arr[l - 1] = f_x 287 | f_x_deriv_arr[l - 1] = f_x_deriv 288 | 289 | # eq. 2.17 290 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 291 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 292 | 293 | diff = f - f_0 294 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 295 | if torch.any(diff < 0): 296 | beta = beta / self.div 297 | elif torch.mean(diff) < threshold: 298 | break 299 | 300 | f_0 = f 301 | its = itr 302 | 303 | return x, errors, its 304 | 305 | def get_errors(self, x, batch_size): 306 | total_err = 0 307 | for l in range(1, self.n_layers - 1): 308 | b = self.b[l - 1].repeat(1, batch_size) 309 | err = (x[l] - self.W[l - 1] @ F.f(x[l - 1], self.act_fn) - b) / self.vars[l] 310 | total_err += torch.sum(torch.mul(err, err), dim=0) 311 | return torch.sum(total_err) 312 | 313 | 314 | def update_params( 315 | self, x, q, errors, batch_size, x_batch, y_batch, epoch_num=None, n_batches=None, curr_batch=None 316 | ): 317 | 318 | grad_w = [[] for _ in range(self.n_layers - 1)] 319 | grad_b = [[] for _ in range(self.n_layers - 1)] 320 | grad_w_q = [[] for _ in range(self.n_layers - 1)] 321 | grad_b_q = [[] for _ in range(self.n_layers - 1)] 322 | 323 | for l 
in range(self.n_layers - 1): 324 | # eq. 2.19 (with weight decay) 325 | grad_w[l] = ( 326 | self.vars[-1] * (1 / batch_size) * errors[l + 1] @ F.f(x[l], self.act_fn).T 327 | - self.d_rate * self.W[l] 328 | ) 329 | grad_b[l] = self.vars[-1] * (1 / batch_size) * torch.sum(errors[l + 1], axis=1) 330 | 331 | if self.amortised: 332 | q = q[::-1] 333 | 334 | q_errs = [[] for _ in range(self.n_layers - 1)] 335 | q_errs[0] = x[2] - q[2] 336 | fn_deriv = F.f_deriv(torch.matmul(x_batch.T, self.W_q[0].T), self.act_fn) 337 | grad_w_q[0] = torch.matmul(x[3], q_errs[0].T * fn_deriv) 338 | grad_b_q[0] = self.vars[-1] * (1 / batch_size) * torch.sum(q_errs[0], axis=1) 339 | 340 | q_errs[1] = x[1] - q[1] 341 | fn_deriv = F.f_deriv(torch.matmul(x[2].T, self.W_q[1].T), self.act_fn) 342 | grad_w_q[1] = torch.matmul(x[2], q_errs[1].T * fn_deriv) 343 | grad_b_q[1] = self.vars[-1] * (1 / batch_size) * torch.sum(q_errs[1], axis=1) 344 | 345 | # q_errs[2] = x[0] - q[0] 346 | q_errs[2] = y_batch - q[0] 347 | fn_deriv = F.f_deriv(torch.matmul(x[1].T, self.W_q[2].T), self.act_fn) 348 | grad_w_q[2] = torch.matmul(x[1], q_errs[2].T * fn_deriv) 349 | grad_b_q[2] = self.vars[-1] * (1 / batch_size) * torch.sum(q_errs[2], axis=1) 350 | 351 | self._apply_gradients( 352 | grad_w, 353 | grad_b, 354 | grad_w_q, 355 | grad_b_q, 356 | epoch_num=epoch_num, 357 | n_batches=n_batches, 358 | curr_batch=curr_batch, 359 | ) 360 | 361 | def _apply_gradients( 362 | self, grad_w, grad_b, grad_w_q, grad_b_q, epoch_num=None, n_batches=None, curr_batch=None 363 | ): 364 | 365 | if self.optim is "RMSPROP": 366 | for l in range(self.n_layers - 1): 367 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 368 | 369 | self.c_w[l] = self.decay_r * self.c_w[l] + (1 - self.decay_r) * grad_w[l] ** 2 370 | self.c_b[l] = self.decay_r * self.c_b[l] + (1 - self.decay_r) * grad_b[l] ** 2 371 | 372 | self.W[l] = self.W[l] + self.l_rate * (grad_w[l] / (torch.sqrt(self.c_w[l]) + self.eps)) 373 | self.b[l] = self.b[l] + self.l_rate * (grad_b[l] / 
(torch.sqrt(self.c_b[l]) + self.eps)) 374 | 375 | if self.amortised: 376 | grad_b_q[l] = grad_b_q[l].unsqueeze(dim=1) 377 | self.c_w_q[l] = self.decay_r * self.c_w_q[l] + (1 - self.decay_r) * grad_w_q[l].T ** 2 378 | self.c_b_q[l] = self.decay_r * self.c_b_q[l] + (1 - self.decay_r) * grad_b_q[l] ** 2 379 | 380 | self.W_q[l] = self.W_q[l] + self.q_l_rate * ( 381 | grad_w_q[l].T / (torch.sqrt(self.c_w_q[l]) + self.eps) 382 | ) 383 | self.b_q[l] = self.b_q[l] + self.q_l_rate * ( 384 | grad_b_q[l] / (torch.sqrt(self.c_b_q[l]) + self.eps) 385 | ) 386 | 387 | elif self.optim is "ADAM": 388 | for l in range(self.n_layers - 1): 389 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 390 | self.c_b[l] = self.beta_1 * self.c_b[l] + (1 - self.beta_1) * grad_b[l] 391 | self.c_w[l] = self.beta_1 * self.c_w[l] + (1 - self.beta_1) * grad_w[l] 392 | 393 | self.v_b[l] = self.beta_2 * self.v_b[l] + (1 - self.beta_2) * grad_b[l] ** 2 394 | self.v_w[l] = self.beta_2 * self.v_w[l] + (1 - self.beta_2) * grad_w[l] ** 2 395 | 396 | t = (epoch_num) * n_batches + curr_batch 397 | self.W[l] = self.W[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w[l] / ( 398 | torch.sqrt(self.v_w[l]) + self.eps 399 | ) 400 | self.b[l] = self.b[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b[l] / ( 401 | torch.sqrt(self.v_b[l]) + self.eps 402 | ) 403 | 404 | if self.amortised: 405 | grad_b_q[l] = grad_b_q[l].unsqueeze(dim=1) 406 | 407 | self.c_b_q[l] = self.beta_1 * self.c_b_q[l] + (1 - self.beta_1) * grad_b_q[l] 408 | self.c_w_q[l] = self.beta_1 * self.c_w_q[l] + (1 - self.beta_1) * grad_w_q[l].T 409 | 410 | self.v_b_q[l] = self.beta_2 * self.v_b_q[l] + (1 - self.beta_2) * grad_b_q[l] ** 2 411 | self.v_w_q[l] = self.beta_2 * self.v_w_q[l] + (1 - self.beta_2) * grad_w_q[l].T ** 2 412 | 413 | t = (epoch_num) * n_batches + curr_batch 414 | self.W_q[l] = self.W_q[l] + self.q_l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w_q[ 415 | l 416 | ] / (torch.sqrt(self.v_w_q[l]) + self.eps) 417 | self.b_q[l] 
= self.b_q[l] + self.q_l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b_q[ 418 | l 419 | ] / (torch.sqrt(self.v_b_q[l]) + self.eps) 420 | 421 | elif self.optim is "SGD" or self.optim is None: 422 | for l in range(self.n_layers - 1): 423 | self.W[l] = self.W[l] + self.l_rate * grad_w[l] 424 | self.b[l] = self.b[l] + self.l_rate * grad_b[l].unsqueeze(dim=1) 425 | 426 | if self.amortised: 427 | self.W_q[l] = self.W_q[l] + self.q_l_rate * grad_w_q[l].T 428 | self.b_q[l] = self.b_q[l] + self.q_l_rate * grad_b_q[l].unsqueeze(dim=1) 429 | 430 | else: 431 | raise ValueError(f"{self.optim} not supported") 432 | 433 | 434 | def _init_params(self): 435 | weights = [[] for _ in range(self.n_layers - 1)] 436 | bias = [[] for _ in range(self.n_layers - 1)] 437 | 438 | for l in range(self.n_layers - 1): 439 | norm_b = 0 440 | if self.act_fn is F.LINEAR: 441 | norm_w = np.sqrt(1 / (self.neurons[l + 1] + self.neurons[l])) 442 | elif self.act_fn is F.TANH: 443 | norm_w = np.sqrt(6 / (self.neurons[l + 1] + self.neurons[l])) 444 | elif self.act_fn is F.LOGSIG: 445 | norm_w = 4 * np.sqrt(6 / (self.neurons[l + 1] + self.neurons[l])) 446 | else: 447 | raise ValueError(f"{self.act_fn} not supported") 448 | 449 | layer_w = np.random.uniform(-1, 1, size=(self.neurons[l + 1], self.neurons[l])) * norm_w 450 | layer_b = np.zeros((self.neurons[l + 1], 1)) + norm_b * np.ones((self.neurons[l + 1], 1)) 451 | weights[l] = set_tensor(layer_w, self.device) 452 | bias[l] = set_tensor(layer_b, self.device) 453 | 454 | self.W = weights 455 | self.b = bias 456 | 457 | for l in range(self.n_layers - 1): 458 | self.c_b[l] = torch.zeros_like(self.b[l]) 459 | self.c_w[l] = torch.zeros_like(self.W[l]) 460 | self.v_b[l] = torch.zeros_like(self.b[l]) 461 | self.v_w[l] = torch.zeros_like(self.W[l]) 462 | 463 | if self.amortised: 464 | q_weights = [[] for _ in range(self.n_layers - 1)] 465 | q_bias = [[] for _ in range(self.n_layers - 1)] 466 | q_neurons = self.neurons[::-1] 467 | 468 | for l in 
range(self.n_layers - 1): 469 | norm_b = 0 470 | if self.act_fn is F.LINEAR: 471 | norm_w = np.sqrt(1 / (q_neurons[l + 1] + q_neurons[l])) 472 | elif self.act_fn is F.TANH: 473 | norm_w = np.sqrt(6 / (q_neurons[l + 1] + q_neurons[l])) 474 | elif self.act_fn is F.LOGSIG: 475 | norm_w = 4 * np.sqrt(6 / (q_neurons[l + 1] + q_neurons[l])) 476 | else: 477 | raise ValueError(f"{self.act_fn} not supported") 478 | 479 | q_layer_w = np.random.uniform(-1, 1, size=(q_neurons[l + 1], q_neurons[l])) * norm_w 480 | q_layer_b = np.zeros((q_neurons[l + 1], 1)) + norm_b * np.ones((q_neurons[l + 1], 1)) 481 | q_weights[l] = set_tensor(q_layer_w, self.device) 482 | q_bias[l] = set_tensor(q_layer_b, self.device) 483 | 484 | self.W_q = q_weights 485 | self.b_q = q_bias 486 | 487 | for l in range(self.n_layers - 1): 488 | self.c_b_q[l] = torch.zeros_like(self.b_q[l]) 489 | self.c_w_q[l] = torch.zeros_like(self.W_q[l]) 490 | self.v_b_q[l] = torch.zeros_like(self.b_q[l]) 491 | self.v_w_q[l] = torch.zeros_like(self.W_q[l]) 492 | 493 | -------------------------------------------------------------------------------- /q_network_v2.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | 10 | 11 | def set_tensor(arr, device): 12 | return torch.from_numpy(arr).float().to(device) 13 | 14 | 15 | class QCodingNetwork(object): 16 | def __init__(self, cf): 17 | self.device = cf.device 18 | self.n_layers = cf.n_layers 19 | self.act_fn = cf.act_fn 20 | self.td_neurons = cf.td_neurons 21 | self.bu_neurons = cf.bu_neurons 22 | self.vars = cf.vars.float().to(self.device) 23 | self.itr_max = cf.itr_max 24 | self.batch_size = cf.batch_size 25 | 26 | self.beta_1 = cf.beta_1 27 | self.beta_2 = cf.beta_2 28 | self.beta = cf.beta 29 | self.div = cf.div 30 | self.d_rate = cf.d_rate 31 | self.l_rate = cf.l_rate 32 | 
self.q_l_rate = cf.q_l_rate 33 | self.condition = cf.condition / (sum(cf.td_neurons) - cf.td_neurons[0]) 34 | 35 | self.optim = cf.optim 36 | self.eps = cf.eps 37 | self.decay_r = cf.decay_r 38 | self.c_b = [[] for _ in range(self.n_layers)] 39 | self.c_w = [[] for _ in range(self.n_layers)] 40 | self.v_b = [[] for _ in range(self.n_layers)] 41 | self.v_w = [[] for _ in range(self.n_layers)] 42 | self.c_b_q = [[] for _ in range(self.n_layers)] 43 | self.c_w_q = [[] for _ in range(self.n_layers)] 44 | self.v_b_q = [[] for _ in range(self.n_layers)] 45 | self.v_w_q = [[] for _ in range(self.n_layers)] 46 | 47 | self.W = None 48 | self.b = None 49 | self.Wq = None 50 | self.bq = None 51 | self._init_params() 52 | 53 | def train_epoch(self, img_batches, label_batches, epoch_num=None): 54 | init_err = 0 55 | end_err = 0 56 | avg_itr = 0 57 | n_batches = len(img_batches) 58 | 59 | for batch_id, (img_batch, label_batch) in enumerate(zip(img_batches, label_batches)): 60 | img_batch = set_tensor(img_batch, self.device) 61 | label_batch = set_tensor(label_batch, self.device) 62 | batch_size = img_batch.size(1) 63 | 64 | x = [[] for _ in range(self.n_layers)] 65 | q = [[] for _ in range(self.n_layers)] 66 | 67 | # bottom-up inference 68 | # self.W_q.reverse() 69 | # self.b_q.reverse() 70 | q[0] = img_batch 71 | for l in range(1, self.n_layers): 72 | b_q = self.b_q[l - 1].repeat(1, batch_size) 73 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 74 | 75 | 76 | x = q[::-1] 77 | x[0] = label_batch 78 | x[self.n_layers - 1] = img_batch 79 | init_err += self.get_errors(x, batch_size) 80 | 81 | # top-down inference 82 | x, errors, its = self.infer(x, batch_size) 83 | end_err += self.get_errors(x, batch_size) 84 | avg_itr += its 85 | 86 | # bottom-up inference 87 | # self.W_q.reverse() 88 | # self.b_q.reverse() 89 | q = x[::-1] 90 | q[0] = img_batch 91 | q[self.n_layers-1] = label_batch 92 | q, q_errors, its = self.amortised_infer(q, batch_size) 93 | # self.W_q.reverse() 
94 | # self.b_q.reverse() 95 | 96 | # update top-down parameters 97 | self.update_params( 98 | x, 99 | q, 100 | errors, 101 | q_errors, 102 | batch_size, 103 | img_batch, 104 | label_batch, 105 | epoch_num=epoch_num, 106 | n_batches=n_batches, 107 | curr_batch=batch_id, 108 | ) 109 | 110 | return end_err / n_batches, init_err / n_batches, avg_itr / n_batches 111 | 112 | def test_epoch(self, x_batches, y_batches, itr_max=None): 113 | accs = [] 114 | n_batches = len(x_batches) 115 | avg_itr = 0 116 | for x_batch, y_batch in zip(x_batches, y_batches): 117 | x_batch = set_tensor(x_batch, self.device) 118 | y_batch = set_tensor(y_batch, self.device) 119 | batch_size = x_batch.size(1) 120 | 121 | x = [[] for _ in range(self.n_layers)] 122 | q = [[] for _ in range(self.n_layers)] 123 | 124 | q[0] = x_batch 125 | for l in range(1, self.n_layers): 126 | b_q = self.b_q[l - 1].repeat(1, batch_size) 127 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 128 | x = q[::-1] 129 | x[self.n_layers - 1] = x_batch 130 | 131 | x, errors, its = self.infer_and_classify(x, batch_size, x_batch, itr_max=itr_max) 132 | pred_y = x[0] 133 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 134 | accs.append(acc) 135 | avg_itr += its 136 | return accs, avg_itr / n_batches 137 | 138 | def test_pc_epoch(self, x_batches, y_batches, itr_max=None): 139 | accs = [] 140 | n_batches = len(x_batches) 141 | avg_itr = 0 142 | for x_batch, y_batch in zip(x_batches, y_batches): 143 | x_batch = set_tensor(x_batch, self.device) 144 | y_batch = set_tensor(y_batch, self.device) 145 | batch_size = x_batch.size(1) 146 | 147 | x = [[] for _ in range(self.n_layers)] 148 | q = [[] for _ in range(self.n_layers)] 149 | 150 | x[0] = torch.empty_like(y_batch).normal_(mean=0.0, std=0.1) 151 | for l in range(1, self.n_layers): 152 | b = self.b[l - 1].repeat(1, batch_size) 153 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 154 | x[self.n_layers - 1] = x_batch 155 | 156 | x, errors, its = 
self.infer_and_classify(x, batch_size, x_batch, itr_max=itr_max) 157 | pred_y = x[0] 158 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 159 | accs.append(acc) 160 | avg_itr += its 161 | return accs, avg_itr / n_batches 162 | 163 | def test_amortised_epoch(self, x_batches, y_batches): 164 | # self.W_q.reverse() 165 | # self.b_q.reverse() 166 | 167 | accs = [] 168 | for x_batch, y_batch in zip(x_batches, y_batches): 169 | x_batch = set_tensor(x_batch, self.device) 170 | y_batch = set_tensor(y_batch, self.device) 171 | batch_size = x_batch.size(1) 172 | 173 | q = [[] for _ in range(self.n_layers)] 174 | q[0] = x_batch 175 | for l in range(1, self.n_layers): 176 | b_q = self.b_q[l - 1].repeat(1, batch_size) 177 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 178 | pred_y = q[-1] 179 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 180 | accs.append(acc) 181 | 182 | # self.W_q.reverse() 183 | # self.b_q.reverse() 184 | return accs 185 | 186 | def generate_data(self, x_batch): 187 | x_batch = set_tensor(x_batch, self.device) 188 | batch_size = x_batch.size(1) 189 | 190 | x = [[] for _ in range(self.n_layers)] 191 | x[0] = x_batch 192 | for l in range(1, self.n_layers): 193 | b = self.b[l - 1].repeat(1, batch_size) 194 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 195 | pred_y = x[-1] 196 | return pred_y 197 | 198 | def amortised_infer(self, x, batch_size, itr_max=None): 199 | itr_max = self.itr_max if itr_max is None else itr_max 200 | errors = [[] for _ in range(self.n_layers)] 201 | f_x_arr = [[] for _ in range(self.n_layers)] 202 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 203 | f_0 = 0 204 | its = 0 205 | beta = self.beta 206 | 207 | for l in range(1, self.n_layers): 208 | f_x = F.f(x[l - 1], self.act_fn) 209 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 210 | f_x_arr[l - 1] = f_x 211 | f_x_deriv_arr[l - 1] = f_x_deriv 212 | 213 | # eq. 
2.17 214 | b = self.b_q[l - 1].repeat(1, batch_size) 215 | errors[l] = (x[l] - self.W_q[l - 1] @ f_x - b) / self.vars[l] 216 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 217 | 218 | for itr in range(itr_max): 219 | # update node activity 220 | for l in range(1, self.n_layers - 1): 221 | # eq. 2.18 222 | g = torch.mul(self.W_q[l].T @ errors[l + 1], f_x_deriv_arr[l]) 223 | x[l] = x[l] + beta * (-errors[l] + g) 224 | 225 | # update errors 226 | f = 0 227 | for l in range(1, self.n_layers): 228 | f_x = F.f(x[l - 1], self.act_fn) 229 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 230 | f_x_arr[l - 1] = f_x 231 | f_x_deriv_arr[l - 1] = f_x_deriv 232 | 233 | # eq. 2.17 234 | errors[l] = (x[l] - self.W_q[l - 1] @ f_x - self.b_q[l - 1]) / self.vars[l] 235 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 236 | 237 | diff = f - f_0 238 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 239 | if torch.any(diff < 0): 240 | beta = beta / self.div 241 | elif torch.mean(diff) < threshold: 242 | # print(f"broke @ {its} its") 243 | break 244 | 245 | f_0 = f 246 | its = itr 247 | 248 | return x, errors, its 249 | 250 | def infer(self, x, batch_size, itr_max=None): 251 | itr_max = self.itr_max if itr_max is None else itr_max 252 | errors = [[] for _ in range(self.n_layers)] 253 | f_x_arr = [[] for _ in range(self.n_layers)] 254 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 255 | f_0 = 0 256 | its = 0 257 | beta = self.beta 258 | 259 | for l in range(1, self.n_layers): 260 | f_x = F.f(x[l - 1], self.act_fn) 261 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 262 | f_x_arr[l - 1] = f_x 263 | f_x_deriv_arr[l - 1] = f_x_deriv 264 | 265 | # eq. 
2.17 266 | b = self.b[l - 1].repeat(1, batch_size) 267 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 268 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 269 | 270 | for itr in range(itr_max): 271 | # update node activity 272 | for l in range(1, self.n_layers - 1): 273 | # eq. 2.18 274 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 275 | x[l] = x[l] + beta * (-errors[l] + g) 276 | 277 | # update errors 278 | f = 0 279 | for l in range(1, self.n_layers): 280 | f_x = F.f(x[l - 1], self.act_fn) 281 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 282 | f_x_arr[l - 1] = f_x 283 | f_x_deriv_arr[l - 1] = f_x_deriv 284 | 285 | # eq. 2.17 286 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 287 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 288 | 289 | diff = f - f_0 290 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 291 | if torch.any(diff < 0): 292 | beta = beta / self.div 293 | elif torch.mean(diff) < threshold: 294 | # print(f"broke @ {its} its") 295 | break 296 | 297 | f_0 = f 298 | its = itr 299 | 300 | return x, errors, its 301 | 302 | def infer_and_classify(self, x, batch_size, x_batch, itr_max=None): 303 | """ this version infers top layer, rather than keeping it fixed """ 304 | itr_max = self.itr_max if itr_max is None else itr_max 305 | errors = [[] for _ in range(self.n_layers)] 306 | f_x_arr = [[] for _ in range(self.n_layers)] 307 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 308 | f_0 = 0 309 | its = 0 310 | beta = self.beta 311 | 312 | x[self.n_layers - 1] = x_batch 313 | 314 | for l in range(1, self.n_layers): 315 | f_x = F.f(x[l - 1], self.act_fn) 316 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 317 | f_x_arr[l - 1] = f_x 318 | f_x_deriv_arr[l - 1] = f_x_deriv 319 | 320 | # eq. 
2.17 321 | b = self.b[l - 1].repeat(1, batch_size) 322 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 323 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 324 | 325 | for itr in range(itr_max): 326 | # TODO (updating top layer) 327 | g = torch.mul(self.W[0].T @ errors[1], f_x_deriv_arr[0]) 328 | x[0] = x[0] + beta * g 329 | 330 | # update node activity 331 | for l in range(1, self.n_layers - 1): 332 | # eq. 2.18 333 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 334 | x[l] = x[l] + beta * (-errors[l] + g) 335 | 336 | # update errors 337 | f = 0 338 | for l in range(1, self.n_layers): 339 | f_x = F.f(x[l - 1], self.act_fn) 340 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 341 | f_x_arr[l - 1] = f_x 342 | f_x_deriv_arr[l - 1] = f_x_deriv 343 | 344 | # eq. 2.17 345 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 346 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 347 | 348 | diff = f - f_0 349 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 350 | if torch.any(diff < 0): 351 | beta = beta / self.div 352 | elif torch.mean(diff) < threshold: 353 | break 354 | 355 | f_0 = f 356 | its = itr 357 | 358 | return x, errors, its 359 | 360 | def get_errors(self, x, batch_size): 361 | total_err = 0 362 | for l in range(1, self.n_layers - 1): 363 | b = self.b[l - 1].repeat(1, batch_size) 364 | err = (x[l] - self.W[l - 1] @ F.f(x[l - 1], self.act_fn) - b) / self.vars[l] 365 | total_err += torch.sum(torch.mul(err, err), dim=0) 366 | return torch.sum(total_err) 367 | 368 | def update_params( 369 | self, x, q, errors, q_errors, batch_size, x_batch, y_batch, epoch_num=None, n_batches=None, curr_batch=None 370 | ): 371 | 372 | grad_w = [[] for _ in range(self.n_layers - 1)] 373 | grad_b = [[] for _ in range(self.n_layers - 1)] 374 | grad_w_q = [[] for _ in range(self.n_layers - 1)] 375 | grad_b_q = [[] for _ in range(self.n_layers - 1)] 376 | 377 | 
for l in range(self.n_layers - 1): 378 | grad_w[l] = ( 379 | self.vars[-1] * (1 / batch_size) * errors[l + 1] @ F.f(x[l], self.act_fn).T 380 | - self.d_rate * self.W[l] 381 | ) 382 | grad_b[l] = self.vars[-1] * (1 / batch_size) * torch.sum(errors[l + 1], axis=1) 383 | 384 | for l in range(self.n_layers - 1): 385 | grad_w_q[l] = ( 386 | self.vars[-1] * (1 / batch_size) * q_errors[l + 1] @ F.f(q[l], self.act_fn).T 387 | - self.d_rate * self.W_q[l] 388 | ) 389 | grad_b_q[l] = self.vars[-1] * (1 / batch_size) * torch.sum(q_errors[l + 1], axis=1) 390 | 391 | self._apply_gradients( 392 | grad_w, 393 | grad_b, 394 | grad_w_q, 395 | grad_b_q, 396 | epoch_num=epoch_num, 397 | n_batches=n_batches, 398 | curr_batch=curr_batch, 399 | ) 400 | 401 | def _apply_gradients( 402 | self, grad_w, grad_b, grad_w_q, grad_b_q, epoch_num=None, n_batches=None, curr_batch=None 403 | ): 404 | 405 | if self.optim is "RMSPROP": 406 | for l in range(self.n_layers - 1): 407 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 408 | 409 | self.c_w[l] = self.decay_r * self.c_w[l] + (1 - self.decay_r) * grad_w[l] ** 2 410 | self.c_b[l] = self.decay_r * self.c_b[l] + (1 - self.decay_r) * grad_b[l] ** 2 411 | 412 | self.W[l] = self.W[l] + self.l_rate * (grad_w[l] / (torch.sqrt(self.c_w[l]) + self.eps)) 413 | self.b[l] = self.b[l] + self.l_rate * (grad_b[l] / (torch.sqrt(self.c_b[l]) + self.eps)) 414 | 415 | self.c_w_q[l] = self.decay_r * self.c_w[l] + (1 - self.decay_r) * grad_w[l] ** 2 416 | self.c_b_q[l] = self.decay_r * self.c_b[l] + (1 - self.decay_r) * grad_b[l] ** 2 417 | 418 | self.W_q[l] = self.W_q[l] + self.q_l_rate * (grad_w_1[l] / (torch.sqrt(self.c_w_q[l]) + self.eps)) 419 | self.b_q[l] = self.b_q[l] + self.q_l_rate * (grad_b_q[l] / (torch.sqrt(self.c_b_q[l]) + self.eps)) 420 | 421 | 422 | elif self.optim is "ADAM": 423 | for l in range(self.n_layers - 1): 424 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 425 | self.c_b[l] = self.beta_1 * self.c_b[l] + (1 - self.beta_1) * grad_b[l] 426 | self.c_w[l] = 
self.beta_1 * self.c_w[l] + (1 - self.beta_1) * grad_w[l] 427 | 428 | self.v_b[l] = self.beta_2 * self.v_b[l] + (1 - self.beta_2) * grad_b[l] ** 2 429 | self.v_w[l] = self.beta_2 * self.v_w[l] + (1 - self.beta_2) * grad_w[l] ** 2 430 | 431 | t = (epoch_num) * n_batches + curr_batch 432 | self.W[l] = self.W[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w[l] / ( 433 | torch.sqrt(self.v_w[l]) + self.eps 434 | ) 435 | self.b[l] = self.b[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b[l] / ( 436 | torch.sqrt(self.v_b[l]) + self.eps 437 | ) 438 | 439 | grad_b_q[l] = grad_b_q[l].unsqueeze(dim=1) 440 | self.c_b_q[l] = self.beta_1 * self.c_b_q[l] + (1 - self.beta_1) * grad_b_q[l] 441 | self.c_w_q[l] = self.beta_1 * self.c_w_q[l] + (1 - self.beta_1) * grad_w_q[l] 442 | 443 | self.v_b_q[l] = self.beta_2 * self.v_b_q[l] + (1 - self.beta_2) * grad_b_q[l] ** 2 444 | self.v_w_q[l] = self.beta_2 * self.v_w_q[l] + (1 - self.beta_2) * grad_w_q[l] ** 2 445 | 446 | t = (epoch_num) * n_batches + curr_batch 447 | self.W_q[l] = self.W_q[l] + self.q_l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w_q[l] / ( 448 | torch.sqrt(self.v_w_q[l]) + self.eps 449 | ) 450 | self.b_q[l] = self.b_q[l] + self.q_l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b_q[l] / ( 451 | torch.sqrt(self.v_b_q[l]) + self.eps 452 | ) 453 | 454 | elif self.optim is "SGD" or self.optim is None: 455 | for l in range(self.n_layers - 1): 456 | self.W[l] = self.W[l] + self.l_rate * grad_w[l] 457 | self.b[l] = self.b[l] + self.l_rate * grad_b[l].unsqueeze(dim=1) 458 | 459 | self.W_q[l] = self.W_q[l] + self.q_l_rate * grad_w_q[l] 460 | self.b_q[l] = self.b_q[l] + self.q_l_rate * grad_b_q[l].unsqueeze(dim=1) 461 | 462 | else: 463 | raise ValueError(f"{self.optim} not supported") 464 | 465 | 466 | def _init_params(self): 467 | """ 468 | Top down 469 | """ 470 | 471 | weights = [[] for _ in range(self.n_layers - 1)] 472 | bias = [[] for _ in range(self.n_layers - 1)] 473 | 474 | for l in 
range(self.n_layers - 1): 475 | norm_b = 0 476 | if self.act_fn is F.LINEAR: 477 | norm_w = np.sqrt(1 / (self.td_neurons[l + 1] + self.td_neurons[l])) 478 | elif self.act_fn is F.TANH: 479 | norm_w = np.sqrt(6 / (self.td_neurons[l + 1] + self.td_neurons[l])) 480 | elif self.act_fn is F.LOGSIG: 481 | norm_w = 4 * np.sqrt(6 / (self.td_neurons[l + 1] + self.td_neurons[l])) 482 | else: 483 | raise ValueError(f"{self.act_fn} not supported") 484 | 485 | layer_w = np.random.uniform(-1, 1, size=(self.td_neurons[l + 1], self.td_neurons[l])) * norm_w 486 | layer_b = np.zeros((self.td_neurons[l + 1], 1)) + norm_b * np.ones((self.td_neurons[l + 1], 1)) 487 | weights[l] = set_tensor(layer_w, self.device) 488 | bias[l] = set_tensor(layer_b, self.device) 489 | 490 | self.W = weights 491 | self.b = bias 492 | 493 | for l in range(self.n_layers - 1): 494 | self.c_b[l] = torch.zeros_like(self.b[l]) 495 | self.c_w[l] = torch.zeros_like(self.W[l]) 496 | self.v_b[l] = torch.zeros_like(self.b[l]) 497 | self.v_w[l] = torch.zeros_like(self.W[l]) 498 | 499 | """ 500 | Bottom up 501 | """ 502 | 503 | q_weights = [[] for _ in range(self.n_layers - 1)] 504 | q_bias = [[] for _ in range(self.n_layers - 1)] 505 | 506 | for l in range(self.n_layers - 1): 507 | norm_b = 0 508 | if self.act_fn is F.LINEAR: 509 | norm_w = np.sqrt(1 / (self.bu_neurons[l + 1] + self.bu_neurons[l])) 510 | elif self.act_fn is F.TANH: 511 | norm_w = np.sqrt(6 / (self.bu_neurons[l + 1] + self.bu_neurons[l])) 512 | elif self.act_fn is F.LOGSIG: 513 | norm_w = 4 * np.sqrt(6 / (self.bu_neurons[l + 1] + self.bu_neurons[l])) 514 | else: 515 | raise ValueError(f"{self.act_fn} not supported") 516 | 517 | layer_w = np.random.uniform(-1, 1, size=(self.bu_neurons[l + 1], self.bu_neurons[l])) * norm_w 518 | layer_b = np.zeros((self.bu_neurons[l + 1], 1)) + norm_b * np.ones((self.bu_neurons[l + 1], 1)) 519 | q_weights[l] = set_tensor(layer_w, self.device) 520 | q_bias[l] = set_tensor(layer_b, self.device) 521 | 522 | self.W_q = 
q_weights 523 | self.b_q = q_bias 524 | 525 | for l in range(self.n_layers - 1): 526 | self.c_b_q[l] = torch.zeros_like(self.b_q[l]) 527 | self.c_w_q[l] = torch.zeros_like(self.W_q[l]) 528 | self.v_b_q[l] = torch.zeros_like(self.b_q[l]) 529 | self.v_w_q[l] = torch.zeros_like(self.W_q[l]) 530 | -------------------------------------------------------------------------------- /q_network_v3.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | 10 | 11 | def set_tensor(arr, device): 12 | return torch.from_numpy(arr).float().to(device) 13 | 14 | 15 | class QCodingNetwork(object): 16 | def __init__(self, cf): 17 | self.device = cf.device 18 | self.n_layers = cf.n_layers 19 | self.act_fn = cf.act_fn 20 | self.td_neurons = cf.td_neurons 21 | self.bu_neurons = cf.bu_neurons 22 | self.vars = cf.vars.float().to(self.device) 23 | self.itr_max = cf.itr_max 24 | self.batch_size = cf.batch_size 25 | 26 | self.amortised_prec = cf.amortised_prec 27 | self.generative_prec = cf.generative_prec 28 | 29 | self.beta_1 = cf.beta_1 30 | self.beta_2 = cf.beta_2 31 | self.beta = cf.beta 32 | self.div = cf.div 33 | self.d_rate = cf.d_rate 34 | self.l_rate = cf.l_rate 35 | self.q_l_rate = cf.q_l_rate 36 | self.condition = cf.condition / (sum(cf.td_neurons) - cf.td_neurons[0]) 37 | 38 | self.optim = cf.optim 39 | self.eps = cf.eps 40 | self.decay_r = cf.decay_r 41 | self.c_b = [[] for _ in range(self.n_layers)] 42 | self.c_w = [[] for _ in range(self.n_layers)] 43 | self.v_b = [[] for _ in range(self.n_layers)] 44 | self.v_w = [[] for _ in range(self.n_layers)] 45 | self.c_b_q = [[] for _ in range(self.n_layers)] 46 | self.c_w_q = [[] for _ in range(self.n_layers)] 47 | self.v_b_q = [[] for _ in range(self.n_layers)] 48 | self.v_w_q = [[] for _ in range(self.n_layers)] 49 | 50 | self.W = None 51 | 
self.b = None 52 | self.Wq = None 53 | self.bq = None 54 | self._init_params() 55 | 56 | def train_epoch(self, img_batches, label_batches, epoch_num=None): 57 | init_err = 0 58 | end_err = 0 59 | avg_itr = 0 60 | n_batches = len(img_batches) 61 | 62 | for batch_id, (img_batch, label_batch) in enumerate(zip(img_batches, label_batches)): 63 | img_batch = set_tensor(img_batch, self.device) 64 | label_batch = set_tensor(label_batch, self.device) 65 | batch_size = img_batch.size(1) 66 | 67 | # activations / mu 68 | x = [[] for _ in range(self.n_layers)] 69 | 70 | # amortised forward 71 | x[0] = img_batch 72 | for l in range(1, self.n_layers): 73 | b_q = self.b_q[l - 1].repeat(1, batch_size) 74 | x[l] = self.W_q[l - 1] @ F.f(x[l - 1], self.act_fn) + b_q 75 | 76 | # reverse order 77 | x = x[::-1] 78 | x[0] = label_batch 79 | x[self.n_layers - 1] = img_batch 80 | init_err += self.get_errors(x, batch_size) 81 | 82 | # inference 83 | x, errors, q_errors, its = self.hybrid_infer(x, batch_size) 84 | 85 | end_err += self.get_errors(x, batch_size) 86 | avg_itr += its 87 | 88 | self.update_params( 89 | x, 90 | errors, 91 | q_errors, 92 | batch_size, 93 | img_batch, 94 | label_batch, 95 | epoch_num=epoch_num, 96 | n_batches=n_batches, 97 | curr_batch=batch_id, 98 | ) 99 | 100 | return end_err / n_batches, init_err / n_batches, avg_itr / n_batches 101 | 102 | def test_epoch(self, img_batches, label_batches, itr_max=None): 103 | accs = [] 104 | n_batches = len(img_batches) 105 | avg_itr = 0 106 | for img_batch, label_batch in zip(img_batches, label_batches): 107 | img_batch = set_tensor(img_batch, self.device) 108 | label_batch = set_tensor(label_batch, self.device) 109 | batch_size = img_batch.size(1) 110 | 111 | x = [[] for _ in range(self.n_layers)] 112 | x[0] = img_batch 113 | for l in range(1, self.n_layers): 114 | b_q = self.b_q[l - 1].repeat(1, batch_size) 115 | x[l] = self.W_q[l - 1] @ F.f(x[l - 1], self.act_fn) + b_q 116 | 117 | x = x[::-1] 118 | x[self.n_layers - 1] = 
img_batch 119 | 120 | x, errors, q_errors, its = self.hybrid_infer(x, batch_size, itr_max=itr_max, test=True) 121 | pred_labels = x[0] 122 | acc = mnist_utils.mnist_accuracy(pred_labels, label_batch) 123 | accs.append(acc) 124 | avg_itr += its 125 | return accs, avg_itr / n_batches 126 | 127 | def test_pc_epoch(self, x_batches, y_batches, itr_max=None): 128 | accs = [] 129 | n_batches = len(x_batches) 130 | avg_itr = 0 131 | for x_batch, y_batch in zip(x_batches, y_batches): 132 | x_batch = set_tensor(x_batch, self.device) 133 | y_batch = set_tensor(y_batch, self.device) 134 | batch_size = x_batch.size(1) 135 | 136 | x = [[] for _ in range(self.n_layers)] 137 | 138 | x[0] = torch.empty_like(y_batch).normal_(mean=0.0, std=0.1) 139 | for l in range(1, self.n_layers): 140 | b = self.b[l - 1].repeat(1, batch_size) 141 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 142 | x[self.n_layers - 1] = x_batch 143 | 144 | x, errors, its = self.infer(x, batch_size, itr_max=itr_max, test=True) 145 | pred_y = x[0] 146 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 147 | accs.append(acc) 148 | avg_itr += its 149 | return accs, avg_itr / n_batches 150 | 151 | def test_amortised_epoch(self, x_batches, y_batches): 152 | accs = [] 153 | for x_batch, y_batch in zip(x_batches, y_batches): 154 | x_batch = set_tensor(x_batch, self.device) 155 | y_batch = set_tensor(y_batch, self.device) 156 | batch_size = x_batch.size(1) 157 | 158 | q = [[] for _ in range(self.n_layers)] 159 | q[0] = x_batch 160 | for l in range(1, self.n_layers): 161 | b_q = self.b_q[l - 1].repeat(1, batch_size) 162 | q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q 163 | pred_y = q[-1] 164 | acc = mnist_utils.mnist_accuracy(pred_y, y_batch) 165 | accs.append(acc) 166 | return accs 167 | 168 | def generate_data(self, x_batch): 169 | x_batch = set_tensor(x_batch, self.device) 170 | batch_size = x_batch.size(1) 171 | 172 | x = [[] for _ in range(self.n_layers)] 173 | x[0] = x_batch 174 | for l in 
range(1, self.n_layers): 175 | b = self.b[l - 1].repeat(1, batch_size) 176 | x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b 177 | pred_y = x[-1] 178 | return pred_y 179 | 180 | def hybrid_infer(self, x, batch_size, itr_max=None, test=False): 181 | itr_max = self.itr_max if itr_max is None else itr_max 182 | 183 | errors = [[] for _ in range(self.n_layers)] 184 | f_x_arr = [[] for _ in range(self.n_layers)] 185 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 186 | 187 | q_errors = [[] for _ in range(self.n_layers)] 188 | q_f_x_arr = [[] for _ in range(self.n_layers)] 189 | q_f_x_deriv_arr = [[] for _ in range(self.n_layers)] 190 | 191 | f_0 = 0 192 | q_f_0 = 0 193 | its = 0 194 | beta = self.beta 195 | 196 | for l in range(1, self.n_layers): 197 | # bottom up 198 | x = x[::-1] 199 | q_f_x = F.f(x[l - 1], self.act_fn) 200 | q_f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 201 | q_f_x_arr[l - 1] = q_f_x 202 | q_f_x_deriv_arr[l - 1] = q_f_x_deriv 203 | 204 | q_b = self.b_q[l - 1].repeat(1, batch_size) 205 | q_errors[l] = (x[l] - self.W_q[l - 1] @ q_f_x - q_b) / self.vars[l] 206 | q_f_0 = q_f_0 - self.vars[l] * torch.sum(torch.mul(q_errors[l], q_errors[l]), dim=0) 207 | 208 | # top down 209 | x = x[::-1] 210 | f_x = F.f(x[l - 1], self.act_fn) 211 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 212 | f_x_arr[l - 1] = f_x 213 | f_x_deriv_arr[l - 1] = f_x_deriv 214 | 215 | b = self.b[l - 1].repeat(1, batch_size) 216 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 217 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 218 | 219 | for itr in range(itr_max): 220 | # bottom up 221 | x = x[::-1] 222 | if test: 223 | g = torch.mul(self.W_q[0].T @ q_errors[1], q_f_x_deriv_arr[0]) 224 | x[0] = x[0] + self.amortised_prec * g 225 | 226 | for l in range(1, self.n_layers - 1): 227 | g = torch.mul(self.W_q[l].T @ q_errors[l + 1], q_f_x_deriv_arr[l]) 228 | x[l] = x[l] + self.amortised_prec * (-q_errors[l] + g) 229 | 230 | # top down 231 
| x = x[::-1] 232 | if test: 233 | g = torch.mul(self.W[0].T @ errors[1], f_x_deriv_arr[0]) 234 | x[0] = x[0] + self.generative_prec * g 235 | for l in range(1, self.n_layers - 1): 236 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 237 | x[l] = x[l] + beta * (-errors[l] + g) 238 | 239 | # update errors 240 | f = 0 241 | q_f = 0 242 | for l in range(1, self.n_layers): 243 | # bottom up 244 | x = x[::-1] 245 | q_f_x = F.f(x[l - 1], self.act_fn) 246 | q_f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 247 | q_f_x_arr[l - 1] = q_f_x 248 | q_f_x_deriv_arr[l - 1] = q_f_x_deriv 249 | 250 | q_errors[l] = (x[l] - self.W_q[l - 1] @ q_f_x - self.b_q[l - 1]) / self.vars[l] 251 | q_f = q_f - self.vars[l] * torch.sum(torch.mul(q_errors[l], q_errors[l]), dim=0) 252 | 253 | # top down 254 | x = x[::-1] 255 | f_x = F.f(x[l - 1], self.act_fn) 256 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 257 | f_x_arr[l - 1] = f_x 258 | f_x_deriv_arr[l - 1] = f_x_deriv 259 | 260 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 261 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 262 | 263 | # TODO 264 | diff = f - f_0 265 | q_diff = q_f - q_f_0 266 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 267 | if torch.any(diff < 0) or torch.any(q_diff < 0): 268 | beta = beta / self.div 269 | # TODO update relative betas 270 | elif torch.mean(diff) < threshold or torch.mean(q_diff) < threshold: 271 | # print(f"broke @ {its} its") 272 | break 273 | 274 | f_0 = f 275 | q_f_0 = q_f 276 | its = itr 277 | 278 | return x, errors, q_errors, its 279 | 280 | def infer(self, x, batch_size, itr_max=None, test=False): 281 | itr_max = self.itr_max if itr_max is None else itr_max 282 | errors = [[] for _ in range(self.n_layers)] 283 | f_x_arr = [[] for _ in range(self.n_layers)] 284 | f_x_deriv_arr = [[] for _ in range(self.n_layers)] 285 | f_0 = 0 286 | its = 0 287 | beta = self.beta 288 | 289 | 290 | for l in range(1, 
self.n_layers): 291 | f_x = F.f(x[l - 1], self.act_fn) 292 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 293 | f_x_arr[l - 1] = f_x 294 | f_x_deriv_arr[l - 1] = f_x_deriv 295 | 296 | # eq. 2.17 297 | b = self.b[l - 1].repeat(1, batch_size) 298 | errors[l] = (x[l] - self.W[l - 1] @ f_x - b) / self.vars[l] 299 | f_0 = f_0 - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 300 | 301 | for itr in range(itr_max): 302 | 303 | if test: 304 | g = torch.mul(self.W[0].T @ errors[1], f_x_deriv_arr[0]) 305 | x[0] = x[0] + beta * g 306 | 307 | # update node activity 308 | for l in range(1, self.n_layers - 1): 309 | # eq. 2.18 310 | g = torch.mul(self.W[l].T @ errors[l + 1], f_x_deriv_arr[l]) 311 | x[l] = x[l] + beta * (-errors[l] + g) 312 | 313 | # update errors 314 | f = 0 315 | for l in range(1, self.n_layers): 316 | f_x = F.f(x[l - 1], self.act_fn) 317 | f_x_deriv = F.f_deriv(x[l - 1], self.act_fn) 318 | f_x_arr[l - 1] = f_x 319 | f_x_deriv_arr[l - 1] = f_x_deriv 320 | 321 | # eq. 
2.17 322 | errors[l] = (x[l] - self.W[l - 1] @ f_x - self.b[l - 1]) / self.vars[l] 323 | f = f - self.vars[l] * torch.sum(torch.mul(errors[l], errors[l]), dim=0) 324 | 325 | diff = f - f_0 326 | threshold = self.condition * self.beta / self.vars[self.n_layers - 1] 327 | if torch.any(diff < 0): 328 | beta = beta / self.div 329 | elif torch.mean(diff) < threshold: 330 | print(f"broke @ {its} its") 331 | break 332 | 333 | f_0 = f 334 | its = itr 335 | 336 | return x, errors, its 337 | 338 | def get_errors(self, x, batch_size): 339 | total_err = 0 340 | for l in range(1, self.n_layers - 1): 341 | b = self.b[l - 1].repeat(1, batch_size) 342 | err = (x[l] - self.W[l - 1] @ F.f(x[l - 1], self.act_fn) - b) / self.vars[l] 343 | total_err += torch.sum(torch.mul(err, err), dim=0) 344 | return torch.sum(total_err) 345 | 346 | def update_params( 347 | self, x, errors, q_errors, batch_size, x_batch, y_batch, epoch_num=None, n_batches=None, curr_batch=None 348 | ): 349 | 350 | grad_w = [[] for _ in range(self.n_layers - 1)] 351 | grad_b = [[] for _ in range(self.n_layers - 1)] 352 | grad_w_q = [[] for _ in range(self.n_layers - 1)] 353 | grad_b_q = [[] for _ in range(self.n_layers - 1)] 354 | 355 | for l in range(self.n_layers - 1): 356 | grad_w[l] = ( 357 | self.vars[-1] * (1 / batch_size) * errors[l + 1] @ F.f(x[l], self.act_fn).T 358 | - self.d_rate * self.W[l] 359 | ) 360 | grad_b[l] = self.vars[-1] * (1 / batch_size) * torch.sum(errors[l + 1], axis=1) 361 | 362 | x = x[::-1] 363 | for l in range(self.n_layers - 1): 364 | grad_w_q[l] = ( 365 | self.vars[-1] * (1 / batch_size) * q_errors[l + 1] @ F.f(x[l], self.act_fn).T 366 | - self.d_rate * self.W_q[l] 367 | ) 368 | grad_b_q[l] = self.vars[-1] * (1 / batch_size) * torch.sum(q_errors[l + 1], axis=1) 369 | 370 | self._apply_gradients( 371 | grad_w, 372 | grad_b, 373 | grad_w_q, 374 | grad_b_q, 375 | epoch_num=epoch_num, 376 | n_batches=n_batches, 377 | curr_batch=curr_batch, 378 | ) 379 | 380 | def _apply_gradients( 381 | 
self, grad_w, grad_b, grad_w_q, grad_b_q, epoch_num=None, n_batches=None, curr_batch=None 382 | ): 383 | 384 | if self.optim is "RMSPROP": 385 | for l in range(self.n_layers - 1): 386 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 387 | 388 | self.c_w[l] = self.decay_r * self.c_w[l] + (1 - self.decay_r) * grad_w[l] ** 2 389 | self.c_b[l] = self.decay_r * self.c_b[l] + (1 - self.decay_r) * grad_b[l] ** 2 390 | 391 | self.W[l] = self.W[l] + self.l_rate * (grad_w[l] / (torch.sqrt(self.c_w[l]) + self.eps)) 392 | self.b[l] = self.b[l] + self.l_rate * (grad_b[l] / (torch.sqrt(self.c_b[l]) + self.eps)) 393 | 394 | self.c_w_q[l] = self.decay_r * self.c_w[l] + (1 - self.decay_r) * grad_w[l] ** 2 395 | self.c_b_q[l] = self.decay_r * self.c_b[l] + (1 - self.decay_r) * grad_b[l] ** 2 396 | 397 | self.W_q[l] = self.W_q[l] + self.q_l_rate * (grad_w_1[l] / (torch.sqrt(self.c_w_q[l]) + self.eps)) 398 | self.b_q[l] = self.b_q[l] + self.q_l_rate * (grad_b_q[l] / (torch.sqrt(self.c_b_q[l]) + self.eps)) 399 | 400 | 401 | elif self.optim is "ADAM": 402 | for l in range(self.n_layers - 1): 403 | grad_b[l] = grad_b[l].unsqueeze(dim=1) 404 | self.c_b[l] = self.beta_1 * self.c_b[l] + (1 - self.beta_1) * grad_b[l] 405 | self.c_w[l] = self.beta_1 * self.c_w[l] + (1 - self.beta_1) * grad_w[l] 406 | 407 | self.v_b[l] = self.beta_2 * self.v_b[l] + (1 - self.beta_2) * grad_b[l] ** 2 408 | self.v_w[l] = self.beta_2 * self.v_w[l] + (1 - self.beta_2) * grad_w[l] ** 2 409 | 410 | t = (epoch_num) * n_batches + curr_batch 411 | self.W[l] = self.W[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w[l] / ( 412 | torch.sqrt(self.v_w[l]) + self.eps 413 | ) 414 | self.b[l] = self.b[l] + self.l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b[l] / ( 415 | torch.sqrt(self.v_b[l]) + self.eps 416 | ) 417 | 418 | grad_b_q[l] = grad_b_q[l].unsqueeze(dim=1) 419 | self.c_b_q[l] = self.beta_1 * self.c_b_q[l] + (1 - self.beta_1) * grad_b_q[l] 420 | self.c_w_q[l] = self.beta_1 * self.c_w_q[l] + (1 - self.beta_1) * 
grad_w_q[l] 421 | 422 | self.v_b_q[l] = self.beta_2 * self.v_b_q[l] + (1 - self.beta_2) * grad_b_q[l] ** 2 423 | self.v_w_q[l] = self.beta_2 * self.v_w_q[l] + (1 - self.beta_2) * grad_w_q[l] ** 2 424 | 425 | t = (epoch_num) * n_batches + curr_batch 426 | self.W_q[l] = self.W_q[l] + self.q_l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_w_q[l] / ( 427 | torch.sqrt(self.v_w_q[l]) + self.eps 428 | ) 429 | self.b_q[l] = self.b_q[l] + self.q_l_rate * np.sqrt(1 - self.beta_2 ** t) * self.c_b_q[l] / ( 430 | torch.sqrt(self.v_b_q[l]) + self.eps 431 | ) 432 | 433 | elif self.optim is "SGD" or self.optim is None: 434 | for l in range(self.n_layers - 1): 435 | self.W[l] = self.W[l] + self.l_rate * grad_w[l] 436 | self.b[l] = self.b[l] + self.l_rate * grad_b[l].unsqueeze(dim=1) 437 | 438 | self.W_q[l] = self.W_q[l] + self.q_l_rate * grad_w_q[l] 439 | self.b_q[l] = self.b_q[l] + self.q_l_rate * grad_b_q[l].unsqueeze(dim=1) 440 | 441 | else: 442 | raise ValueError(f"{self.optim} not supported") 443 | 444 | 445 | def _init_params(self): 446 | """ 447 | Top down 448 | """ 449 | 450 | weights = [[] for _ in range(self.n_layers - 1)] 451 | bias = [[] for _ in range(self.n_layers - 1)] 452 | 453 | for l in range(self.n_layers - 1): 454 | norm_b = 0 455 | if self.act_fn is F.LINEAR: 456 | norm_w = np.sqrt(1 / (self.td_neurons[l + 1] + self.td_neurons[l])) 457 | elif self.act_fn is F.TANH: 458 | norm_w = np.sqrt(6 / (self.td_neurons[l + 1] + self.td_neurons[l])) 459 | elif self.act_fn is F.LOGSIG: 460 | norm_w = 4 * np.sqrt(6 / (self.td_neurons[l + 1] + self.td_neurons[l])) 461 | else: 462 | raise ValueError(f"{self.act_fn} not supported") 463 | 464 | layer_w = np.random.uniform(-1, 1, size=(self.td_neurons[l + 1], self.td_neurons[l])) * norm_w 465 | layer_b = np.zeros((self.td_neurons[l + 1], 1)) + norm_b * np.ones((self.td_neurons[l + 1], 1)) 466 | weights[l] = set_tensor(layer_w, self.device) 467 | bias[l] = set_tensor(layer_b, self.device) 468 | 469 | self.W = weights 470 | self.b 
= bias 471 | 472 | for l in range(self.n_layers - 1): 473 | self.c_b[l] = torch.zeros_like(self.b[l]) 474 | self.c_w[l] = torch.zeros_like(self.W[l]) 475 | self.v_b[l] = torch.zeros_like(self.b[l]) 476 | self.v_w[l] = torch.zeros_like(self.W[l]) 477 | 478 | """ 479 | Bottom up 480 | """ 481 | 482 | q_weights = [[] for _ in range(self.n_layers - 1)] 483 | q_bias = [[] for _ in range(self.n_layers - 1)] 484 | 485 | for l in range(self.n_layers - 1): 486 | norm_b = 0 487 | if self.act_fn is F.LINEAR: 488 | norm_w = np.sqrt(1 / (self.bu_neurons[l + 1] + self.bu_neurons[l])) 489 | elif self.act_fn is F.TANH: 490 | norm_w = np.sqrt(6 / (self.bu_neurons[l + 1] + self.bu_neurons[l])) 491 | elif self.act_fn is F.LOGSIG: 492 | norm_w = 4 * np.sqrt(6 / (self.bu_neurons[l + 1] + self.bu_neurons[l])) 493 | else: 494 | raise ValueError(f"{self.act_fn} not supported") 495 | 496 | layer_w = np.random.uniform(-1, 1, size=(self.bu_neurons[l + 1], self.bu_neurons[l])) * norm_w 497 | layer_b = np.zeros((self.bu_neurons[l + 1], 1)) + norm_b * np.ones((self.bu_neurons[l + 1], 1)) 498 | q_weights[l] = set_tensor(layer_w, self.device) 499 | q_bias[l] = set_tensor(layer_b, self.device) 500 | 501 | self.W_q = q_weights 502 | self.b_q = q_bias 503 | 504 | for l in range(self.n_layers - 1): 505 | self.c_b_q[l] = torch.zeros_like(self.b_q[l]) 506 | self.c_w_q[l] = torch.zeros_like(self.W_q[l]) 507 | self.v_b_q[l] = torch.zeros_like(self.b_q[l]) 508 | self.v_w_q[l] = torch.zeros_like(self.W_q[l]) 509 | -------------------------------------------------------------------------------- /q_script.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | from q_network import QCodingNetwork 10 | 11 | """ Update amortised in iterations? 
""" 12 | 13 | 14 | class AttrDict(dict): 15 | __setattr__ = dict.__setitem__ 16 | __getattr__ = dict.__getitem__ 17 | 18 | 19 | def main(cf): 20 | print(f"device [{cf.device}]") 21 | print("loading MNIST data...") 22 | train_set = mnist_utils.get_mnist_train_set() 23 | test_set = mnist_utils.get_mnist_test_set() 24 | 25 | img_train = mnist_utils.get_imgs(train_set) 26 | img_test = mnist_utils.get_imgs(test_set) 27 | label_train = mnist_utils.get_labels(train_set) 28 | label_test = mnist_utils.get_labels(test_set) 29 | 30 | if cf.data_size is not None: 31 | test_size = cf.data_size // 5 32 | img_train = img_train[:, 0 : cf.data_size] 33 | label_train = label_train[:, 0 : cf.data_size] 34 | img_test = img_test[:, 0:test_size] 35 | label_test = label_test[:, 0:test_size] 36 | 37 | msg = "img_train {} img_test {} label_train {} label_test {}" 38 | print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape)) 39 | 40 | print("performing preprocessing...") 41 | if cf.apply_scaling: 42 | img_train = mnist_utils.scale_imgs(img_train, cf.img_scale) 43 | img_test = mnist_utils.scale_imgs(img_test, cf.img_scale) 44 | label_train = mnist_utils.scale_labels(label_train, cf.label_scale) 45 | label_test = mnist_utils.scale_labels(label_test, cf.label_scale) 46 | 47 | if cf.apply_inv: 48 | img_train = F.f_inv(img_train, cf.act_fn) 49 | img_test = F.f_inv(img_test, cf.act_fn) 50 | 51 | model = QCodingNetwork(cf) 52 | 53 | q_accs = [] 54 | h_accs = [] 55 | p_accs = [] 56 | 57 | with torch.no_grad(): 58 | for epoch in range(cf.n_epochs): 59 | print(f"\nepoch {epoch}") 60 | 61 | img_batches, label_batches = mnist_utils.get_batches(img_train, label_train, cf.batch_size) 62 | print(f"> training on {len(img_batches)} batches of size {cf.batch_size}") 63 | end_err, init_err, its = model.train_epoch(img_batches, label_batches, epoch_num=epoch) 64 | print("end_err {} / init_err {} / its {}".format(end_err, init_err, its)) 65 | 66 | if epoch % cf.test_every == 0: 67 
| img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 68 | print("> generating images...") 69 | pred_imgs = model.generate_data(label_batches[0]) 70 | mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch)) 71 | 72 | if cf.amortised: 73 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 74 | print(f"> testing amortised acc {len(img_batches)} batches of size {cf.batch_size}") 75 | accs = model.test_amortised_epoch(img_batches, label_batches) 76 | mean_q_acc = np.mean(np.array(accs)) 77 | q_accs.append(mean_q_acc) 78 | print(f"average amortised accuracy {mean_q_acc}") 79 | 80 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 81 | print(f"> testing hybrid acc on {len(img_batches)} batches of size {cf.batch_size}") 82 | accs, its = model.test_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 83 | mean_h_acc = np.mean(np.array(accs)) 84 | h_accs.append(mean_h_acc) 85 | print(f"average hybrid accuracy {mean_h_acc} / its {its}") 86 | 87 | 88 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 89 | print(f"> testing PC acc on {len(img_batches)} batches of size {cf.batch_size}") 90 | accs, its = model.test_pc_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 91 | mean_p_acc = np.mean(np.array(accs)) 92 | p_accs.append(mean_p_acc) 93 | print(f"average PC accuracy {mean_p_acc} / its {its}") 94 | 95 | np.save(cf.hybird_path, h_accs) 96 | np.save(cf.amortised_path, q_accs) 97 | np.save(cf.pc_path, p_accs) 98 | 99 | perm = np.random.permutation(img_train.shape[1]) 100 | img_train = img_train[:, perm] 101 | label_train = label_train[:, perm] 102 | 103 | 104 | 105 | 106 | if __name__ == "__main__": 107 | cf = AttrDict() 108 | 109 | cf.amortised = True 110 | 111 | cf.img_path = "imgs/epoch_{}.png" 112 | cf.hybird_path = "data/h_accs_3" 113 | cf.amortised_path = "data/q_accs_3" 114 | cf.pc_path = 
"data/pc_accs_2" 115 | cf.test_every = 1 116 | 117 | cf.n_epochs = 100 118 | cf.data_size = None 119 | cf.batch_size = 128 120 | 121 | cf.apply_inv = True 122 | cf.apply_scaling = True 123 | cf.label_scale = 0.94 124 | cf.img_scale = 1.0 125 | 126 | cf.neurons = [10, 500, 500, 784] 127 | cf.n_layers = len(cf.neurons) 128 | cf.act_fn = F.TANH 129 | cf.var_out = 1 130 | cf.vars = torch.ones(cf.n_layers) 131 | 132 | # TODO 133 | cf.itr_max = 50 134 | cf.test_itr_max = 50 135 | # TODO 136 | cf.beta = 0.1 137 | cf.div = 2 138 | cf.condition = 1e-6 139 | cf.d_rate = 0 140 | 141 | # optim parameters 142 | cf.l_rate = 1e-5 143 | cf.q_l_rate = 1e-5 144 | # TODO 145 | cf.optim = "ADAM" 146 | cf.eps = 1e-8 147 | cf.decay_r = 0.9 148 | cf.beta_1 = 0.9 149 | cf.beta_2 = 0.999 150 | 151 | cf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 152 | main(cf) 153 | 154 | -------------------------------------------------------------------------------- /q_script_v2.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | from q_network_v2 import QCodingNetwork 10 | 11 | """ Update amortised in iterations? 
""" 12 | 13 | 14 | class AttrDict(dict): 15 | __setattr__ = dict.__setitem__ 16 | __getattr__ = dict.__getitem__ 17 | 18 | 19 | def main(cf): 20 | print(f"device [{cf.device}]") 21 | print("loading MNIST data...") 22 | train_set = mnist_utils.get_mnist_train_set() 23 | test_set = mnist_utils.get_mnist_test_set() 24 | 25 | img_train = mnist_utils.get_imgs(train_set) 26 | img_test = mnist_utils.get_imgs(test_set) 27 | label_train = mnist_utils.get_labels(train_set) 28 | label_test = mnist_utils.get_labels(test_set) 29 | 30 | if cf.data_size is not None: 31 | test_size = cf.data_size // 5 32 | img_train = img_train[:, 0 : cf.data_size] 33 | label_train = label_train[:, 0 : cf.data_size] 34 | img_test = img_test[:, 0:test_size] 35 | label_test = label_test[:, 0:test_size] 36 | 37 | msg = "img_train {} img_test {} label_train {} label_test {}" 38 | print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape)) 39 | 40 | print("performing preprocessing...") 41 | if cf.apply_scaling: 42 | img_train = mnist_utils.scale_imgs(img_train, cf.img_scale) 43 | img_test = mnist_utils.scale_imgs(img_test, cf.img_scale) 44 | label_train = mnist_utils.scale_labels(label_train, cf.label_scale) 45 | label_test = mnist_utils.scale_labels(label_test, cf.label_scale) 46 | 47 | if cf.apply_inv: 48 | img_train = F.f_inv(img_train, cf.act_fn) 49 | img_test = F.f_inv(img_test, cf.act_fn) 50 | 51 | model = QCodingNetwork(cf) 52 | 53 | q_accs = [] 54 | h_accs = [] 55 | p_accs = [] 56 | 57 | with torch.no_grad(): 58 | for epoch in range(cf.n_epochs): 59 | print(f"\nepoch {epoch}") 60 | 61 | img_batches, label_batches = mnist_utils.get_batches(img_train, label_train, cf.batch_size) 62 | print(f"> training on {len(img_batches)} batches of size {cf.batch_size}") 63 | end_err, init_err, its = model.train_epoch(img_batches, label_batches, epoch_num=epoch) 64 | print("end_err {} / init_err {} / its {}".format(end_err, init_err, its)) 65 | 66 | if epoch % cf.test_every == 0: 67 
| img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 68 | print("> generating images...") 69 | pred_imgs = model.generate_data(label_batches[0]) 70 | mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch)) 71 | 72 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 73 | print(f"> testing amortised acc {len(img_batches)} batches of size {cf.batch_size}") 74 | accs = model.test_amortised_epoch(img_batches, label_batches) 75 | mean_q_acc = np.mean(np.array(accs)) 76 | q_accs.append(mean_q_acc) 77 | print(f"average amortised accuracy {mean_q_acc}") 78 | 79 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 80 | print(f"> testing hybrid acc on {len(img_batches)} batches of size {cf.batch_size}") 81 | accs, its = model.test_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 82 | mean_h_acc = np.mean(np.array(accs)) 83 | h_accs.append(mean_h_acc) 84 | print(f"average hybrid accuracy {mean_h_acc} / its {its}") 85 | 86 | 87 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 88 | print(f"> testing PC acc on {len(img_batches)} batches of size {cf.batch_size}") 89 | accs, its = model.test_pc_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 90 | mean_p_acc = np.mean(np.array(accs)) 91 | p_accs.append(mean_p_acc) 92 | print(f"average PC accuracy {mean_p_acc} / its {its}") 93 | 94 | np.save(cf.hybird_path, h_accs) 95 | np.save(cf.amortised_path, q_accs) 96 | np.save(cf.pc_path, p_accs) 97 | 98 | perm = np.random.permutation(img_train.shape[1]) 99 | img_train = img_train[:, perm] 100 | label_train = label_train[:, perm] 101 | 102 | 103 | 104 | 105 | if __name__ == "__main__": 106 | cf = AttrDict() 107 | 108 | cf.img_path = "imgs/epoch_{}.png" 109 | cf.hybird_path = "data/h_accs_3" 110 | cf.amortised_path = "data/q_accs_3" 111 | cf.pc_path = "data/pc_accs_2" 112 | cf.test_every = 1 113 | 114 | 
cf.n_epochs = 100 115 | cf.data_size = None 116 | cf.batch_size = 128 117 | 118 | cf.apply_inv = True 119 | cf.apply_scaling = True 120 | cf.label_scale = 0.94 121 | cf.img_scale = 1.0 122 | 123 | cf.td_neurons = [10, 500, 500, 784] 124 | cf.bu_neurons = [784, 500, 500, 10] 125 | cf.n_layers = len(cf.td_neurons) 126 | cf.act_fn = F.TANH 127 | cf.var_out = 1 128 | cf.vars = torch.ones(cf.n_layers) 129 | 130 | # TODO 131 | cf.itr_max = 50 132 | cf.test_itr_max = 50 133 | # TODO 134 | cf.beta = 0.1 135 | cf.div = 2 136 | cf.condition = 1e-6 137 | cf.d_rate = 0 138 | 139 | # optim parameters 140 | cf.l_rate = 1e-4 141 | cf.q_l_rate = 1e-4 142 | # TODO 143 | cf.optim = "ADAM" 144 | cf.eps = 1e-8 145 | cf.decay_r = 0.9 146 | cf.beta_1 = 0.9 147 | cf.beta_2 = 0.999 148 | 149 | cf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 150 | main(cf) 151 | 152 | -------------------------------------------------------------------------------- /q_script_v3.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | from q_network_v3 import QCodingNetwork 10 | 11 | """ 12 | Precision term which weights influence - in theory a learnable parameter? 
13 | """ 14 | 15 | class AttrDict(dict): 16 | __setattr__ = dict.__setitem__ 17 | __getattr__ = dict.__getitem__ 18 | 19 | 20 | def main(cf): 21 | print(f"device [{cf.device}]") 22 | print("loading MNIST data...") 23 | train_set = mnist_utils.get_mnist_train_set() 24 | test_set = mnist_utils.get_mnist_test_set() 25 | 26 | img_train = mnist_utils.get_imgs(train_set) 27 | img_test = mnist_utils.get_imgs(test_set) 28 | label_train = mnist_utils.get_labels(train_set) 29 | label_test = mnist_utils.get_labels(test_set) 30 | 31 | if cf.data_size is not None: 32 | test_size = cf.data_size // 5 33 | img_train = img_train[:, 0 : cf.data_size] 34 | label_train = label_train[:, 0 : cf.data_size] 35 | img_test = img_test[:, 0:test_size] 36 | label_test = label_test[:, 0:test_size] 37 | 38 | msg = "img_train {} img_test {} label_train {} label_test {}" 39 | print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape)) 40 | 41 | print("performing preprocessing...") 42 | if cf.apply_scaling: 43 | img_train = mnist_utils.scale_imgs(img_train, cf.img_scale) 44 | img_test = mnist_utils.scale_imgs(img_test, cf.img_scale) 45 | label_train = mnist_utils.scale_labels(label_train, cf.label_scale) 46 | label_test = mnist_utils.scale_labels(label_test, cf.label_scale) 47 | 48 | if cf.apply_inv: 49 | img_train = F.f_inv(img_train, cf.act_fn) 50 | img_test = F.f_inv(img_test, cf.act_fn) 51 | 52 | model = QCodingNetwork(cf) 53 | 54 | q_accs = [] 55 | h_accs = [] 56 | p_accs = [] 57 | init_errs = [] 58 | end_errs = [] 59 | 60 | with torch.no_grad(): 61 | for epoch in range(cf.n_epochs): 62 | print(f"\nepoch {epoch}") 63 | 64 | img_batches, label_batches = mnist_utils.get_batches(img_train, label_train, cf.batch_size) 65 | print(f"> training on {len(img_batches)} batches of size {cf.batch_size}") 66 | end_err, init_err, its = model.train_epoch(img_batches, label_batches, epoch_num=epoch) 67 | print("end_err {} / init_err {} / its {}".format(end_err, init_err, its)) 68 | 
init_errs.append(init_err) 69 | end_errs.append(end_err) 70 | 71 | if epoch % cf.test_every == 0: 72 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 73 | print("> generating images...") 74 | pred_imgs = model.generate_data(label_batches[0]) 75 | mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch)) 76 | 77 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 78 | print(f"> testing hybrid acc on {len(img_batches)} batches of size {cf.batch_size}") 79 | accs, its = model.test_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 80 | mean_h_acc = np.mean(np.array(accs)) 81 | h_accs.append(mean_h_acc) 82 | print(f"average hybrid accuracy {mean_h_acc} / its {its}") 83 | 84 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 85 | print(f"> testing amortised acc {len(img_batches)} batches of size {cf.batch_size}") 86 | accs = model.test_amortised_epoch(img_batches, label_batches) 87 | mean_q_acc = np.mean(np.array(accs)) 88 | q_accs.append(mean_q_acc) 89 | print(f"average amortised accuracy {mean_q_acc}") 90 | 91 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 92 | print(f"> testing PC acc on {len(img_batches)} batches of size {cf.batch_size}") 93 | accs, its = model.test_pc_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 94 | mean_p_acc = np.mean(np.array(accs)) 95 | p_accs.append(mean_p_acc) 96 | print(f"average PC accuracy {mean_p_acc} / its {its}") 97 | 98 | np.save(cf.hybird_path, h_accs) 99 | np.save(cf.amortised_path, q_accs) 100 | np.save(cf.pc_path, p_accs) 101 | np.save(cf.init_errs_path, init_errs) 102 | np.save(cf.end_errs_path, end_errs) 103 | 104 | perm = np.random.permutation(img_train.shape[1]) 105 | img_train = img_train[:, perm] 106 | label_train = label_train[:, perm] 107 | 108 | 109 | 110 | 111 | if __name__ == "__main__": 112 | cf = AttrDict() 113 | 114 | 
cf.img_path = "imgs/epoch_{}.png" 115 | cf.hybird_path = "data/h_accs_5" 116 | cf.amortised_path = "data/q_accs_5" 117 | cf.pc_path = "data/pc_accs_5" 118 | cf.init_errs_path = "data/init_errs_5" 119 | cf.end_errs_path = "data/end_errs_5" 120 | cf.test_every = 1 121 | 122 | cf.n_epochs = 100 123 | cf.data_size = None 124 | cf.batch_size = 128 125 | 126 | cf.apply_inv = True 127 | cf.apply_scaling = True 128 | cf.label_scale = 0.94 129 | cf.img_scale = 1.0 130 | 131 | cf.td_neurons = [10, 500, 500, 784] 132 | cf.bu_neurons = [784, 500, 500, 10] 133 | cf.n_layers = len(cf.td_neurons) 134 | cf.act_fn = F.TANH 135 | cf.var_out = 1 136 | cf.vars = torch.ones(cf.n_layers) 137 | 138 | # TODO 139 | cf.itr_max = 50 140 | cf.test_itr_max = 50 141 | # TODO change stuff here 142 | cf.amortised_prec = 0.1 143 | cf.generative_prec = 0.1 144 | cf.beta = 0.1 145 | cf.div = 2 146 | # TODO TODO 147 | cf.condition = 1e-6 148 | cf.d_rate = 0 149 | 150 | # optim parameters 151 | cf.l_rate = 1e-5 152 | # TODO q_l_rate low? 153 | cf.q_l_rate = 1e-5 154 | # TODO 155 | cf.optim = "ADAM" 156 | cf.eps = 1e-8 157 | cf.decay_r = 0.9 158 | cf.beta_1 = 0.9 159 | cf.beta_2 = 0.999 160 | 161 | cf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 162 | main(cf) 163 | 164 | -------------------------------------------------------------------------------- /q_script_v4.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | from q_network_v3 import QCodingNetwork 10 | 11 | """ 12 | Precision term which weights influence - in theory a learnable parameter? 
13 | """ 14 | 15 | class AttrDict(dict): 16 | __setattr__ = dict.__setitem__ 17 | __getattr__ = dict.__getitem__ 18 | 19 | 20 | def main(cf): 21 | print(f"device [{cf.device}]") 22 | print("loading MNIST data...") 23 | train_set = mnist_utils.get_mnist_train_set() 24 | test_set = mnist_utils.get_mnist_test_set() 25 | 26 | img_train = mnist_utils.get_imgs(train_set) 27 | img_test = mnist_utils.get_imgs(test_set) 28 | label_train = mnist_utils.get_labels(train_set) 29 | label_test = mnist_utils.get_labels(test_set) 30 | 31 | if cf.data_size is not None: 32 | test_size = cf.data_size // 5 33 | img_train = img_train[:, 0 : cf.data_size] 34 | label_train = label_train[:, 0 : cf.data_size] 35 | img_test = img_test[:, 0:test_size] 36 | label_test = label_test[:, 0:test_size] 37 | 38 | msg = "img_train {} img_test {} label_train {} label_test {}" 39 | print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape)) 40 | 41 | print("performing preprocessing...") 42 | if cf.apply_scaling: 43 | img_train = mnist_utils.scale_imgs(img_train, cf.img_scale) 44 | img_test = mnist_utils.scale_imgs(img_test, cf.img_scale) 45 | label_train = mnist_utils.scale_labels(label_train, cf.label_scale) 46 | label_test = mnist_utils.scale_labels(label_test, cf.label_scale) 47 | 48 | if cf.apply_inv: 49 | img_train = F.f_inv(img_train, cf.act_fn) 50 | img_test = F.f_inv(img_test, cf.act_fn) 51 | 52 | model = QCodingNetwork(cf) 53 | 54 | q_accs = [] 55 | h_accs = [] 56 | p_accs = [] 57 | init_errs = [] 58 | end_errs = [] 59 | 60 | with torch.no_grad(): 61 | for epoch in range(cf.n_epochs): 62 | print(f"\nepoch {epoch}") 63 | 64 | img_batches, label_batches = mnist_utils.get_batches(img_train, label_train, cf.batch_size) 65 | print(f"> training on {len(img_batches)} batches of size {cf.batch_size}") 66 | end_err, init_err, its = model.train_epoch(img_batches, label_batches, epoch_num=epoch) 67 | print("end_err {} / init_err {} / its {}".format(end_err, init_err, its)) 68 | 
init_errs.append(init_err) 69 | end_errs.append(end_err) 70 | 71 | if epoch % cf.test_every == 0: 72 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 73 | print("> generating images...") 74 | pred_imgs = model.generate_data(label_batches[0]) 75 | mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch)) 76 | 77 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 78 | print(f"> testing hybrid acc on {len(img_batches)} batches of size {cf.batch_size}") 79 | accs, its = model.test_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 80 | mean_h_acc = np.mean(np.array(accs)) 81 | h_accs.append(mean_h_acc) 82 | print(f"average hybrid accuracy {mean_h_acc} / its {its}") 83 | 84 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 85 | print(f"> testing amortised acc {len(img_batches)} batches of size {cf.batch_size}") 86 | accs = model.test_amortised_epoch(img_batches, label_batches) 87 | mean_q_acc = np.mean(np.array(accs)) 88 | q_accs.append(mean_q_acc) 89 | print(f"average amortised accuracy {mean_q_acc}") 90 | 91 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 92 | print(f"> testing PC acc on {len(img_batches)} batches of size {cf.batch_size}") 93 | accs, its = model.test_pc_epoch(img_batches, label_batches, itr_max=cf.test_itr_max) 94 | mean_p_acc = np.mean(np.array(accs)) 95 | p_accs.append(mean_p_acc) 96 | print(f"average PC accuracy {mean_p_acc} / its {its}") 97 | 98 | np.save(cf.hybird_path, h_accs) 99 | np.save(cf.amortised_path, q_accs) 100 | np.save(cf.pc_path, p_accs) 101 | np.save(cf.init_errs_path, init_errs) 102 | np.save(cf.end_errs_path, end_errs) 103 | 104 | perm = np.random.permutation(img_train.shape[1]) 105 | img_train = img_train[:, perm] 106 | label_train = label_train[:, perm] 107 | 108 | 109 | 110 | 111 | if __name__ == "__main__": 112 | cf = AttrDict() 113 | 114 | 
cf.img_path = "imgs/epoch_{}.png" 115 | cf.hybird_path = "data/h_accs_6" 116 | cf.amortised_path = "data/q_accs_6" 117 | cf.pc_path = "data/pc_accs_6" 118 | cf.init_errs_path = "data/init_errs_6" 119 | cf.end_errs_path = "data/end_errs_6" 120 | cf.test_every = 1 121 | 122 | cf.n_epochs = 100 123 | cf.data_size = None 124 | cf.batch_size = 128 125 | 126 | cf.apply_inv = True 127 | cf.apply_scaling = True 128 | cf.label_scale = 0.94 129 | cf.img_scale = 1.0 130 | 131 | cf.td_neurons = [10, 500, 500, 784] 132 | cf.bu_neurons = [784, 500, 500, 10] 133 | cf.n_layers = len(cf.td_neurons) 134 | cf.act_fn = F.TANH 135 | cf.var_out = 1 136 | cf.vars = torch.ones(cf.n_layers) 137 | 138 | # TODO 139 | cf.itr_max = 1000 140 | cf.test_itr_max = 1000 141 | # TODO change stuff here 142 | cf.amortised_prec = 0.1 143 | cf.generative_prec = 0.1 144 | cf.beta = 0.1 145 | cf.div = 2 146 | # TODO TODO 147 | cf.condition = 1e-6 148 | cf.d_rate = 0 149 | 150 | # optim parameters 151 | cf.l_rate = 1e-5 152 | # TODO q_l_rate low? 
153 | cf.q_l_rate = 1e-5 154 | # TODO 155 | cf.optim = "ADAM" 156 | cf.eps = 1e-8 157 | cf.decay_r = 0.9 158 | cf.beta_1 = 0.9 159 | cf.beta_2 = 0.999 160 | 161 | cf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 162 | main(cf) 163 | 164 | -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=not-callable 2 | # pylint: disable=no-member 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import mnist_utils 8 | import functions as F 9 | from network import PredictiveCodingNetwork 10 | 11 | 12 | class AttrDict(dict): 13 | __setattr__ = dict.__setitem__ 14 | __getattr__ = dict.__getitem__ 15 | 16 | 17 | def main(cf): 18 | print(f"device [{cf.device}]") 19 | print("loading MNIST data...") 20 | train_set = mnist_utils.get_mnist_train_set() 21 | test_set = mnist_utils.get_mnist_test_set() 22 | 23 | img_train = mnist_utils.get_imgs(train_set) 24 | img_test = mnist_utils.get_imgs(test_set) 25 | label_train = mnist_utils.get_labels(train_set) 26 | label_test = mnist_utils.get_labels(test_set) 27 | 28 | if cf.data_size is not None: 29 | test_size = cf.data_size // 5 30 | img_train = img_train[:, 0 : cf.data_size] 31 | label_train = label_train[:, 0 : cf.data_size] 32 | img_test = img_test[:, 0:test_size] 33 | label_test = label_test[:, 0:test_size] 34 | 35 | msg = "img_train {} img_test {} label_train {} label_test {}" 36 | print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape)) 37 | 38 | print("performing preprocessing...") 39 | if cf.apply_scaling: 40 | img_train = mnist_utils.scale_imgs(img_train, cf.img_scale) 41 | img_test = mnist_utils.scale_imgs(img_test, cf.img_scale) 42 | label_train = mnist_utils.scale_labels(label_train, cf.label_scale) 43 | label_test = mnist_utils.scale_labels(label_test, cf.label_scale) 44 | 45 | if cf.apply_inv: 46 | img_train = F.f_inv(img_train, 
cf.act_fn) 47 | img_test = F.f_inv(img_test, cf.act_fn) 48 | 49 | model = PredictiveCodingNetwork(cf) 50 | 51 | with torch.no_grad(): 52 | for epoch in range(cf.n_epochs): 53 | print(f"\nepoch {epoch}") 54 | 55 | img_batches, label_batches = mnist_utils.get_batches(img_train, label_train, cf.batch_size) 56 | print(f"training on {len(img_batches)} batches of size {cf.batch_size}") 57 | model.train_epoch(img_batches, label_batches, epoch_num=epoch) 58 | 59 | img_batches, label_batches = mnist_utils.get_batches(img_test, label_test, cf.batch_size) 60 | print(f"testing on {len(img_batches)} batches of size {cf.batch_size}") 61 | accs = model.test_epoch(img_batches, label_batches) 62 | print(f"average accuracy {np.mean(np.array(accs))}") 63 | 64 | perm = np.random.permutation(img_train.shape[1]) 65 | img_train = img_train[:, perm] 66 | label_train = label_train[:, perm] 67 | 68 | 69 | if __name__ == "__main__": 70 | cf = AttrDict() 71 | 72 | cf.n_epochs = 100 73 | cf.data_size = None 74 | cf.batch_size = 128 75 | 76 | cf.apply_inv = True 77 | cf.apply_scaling = True 78 | cf.label_scale = 0.94 79 | cf.img_scale = 1.0 80 | 81 | cf.neurons = [784, 500, 500, 10] 82 | cf.n_layers = len(cf.neurons) 83 | cf.act_fn = F.TANH 84 | cf.var_out = 1 85 | cf.vars = torch.ones(cf.n_layers) 86 | 87 | cf.itr_max = 50 88 | cf.beta = 0.1 89 | cf.div = 2 90 | cf.condition = 1e-6 91 | cf.d_rate = 0 92 | 93 | # optim parameters 94 | cf.l_rate = 1e-3 95 | cf.optim = "ADAM" 96 | cf.eps = 1e-8 97 | cf.decay_r = 0.9 98 | cf.beta_1 = 0.9 99 | cf.beta_2 = 0.999 100 | 101 | cf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 102 | main(cf) 103 | 104 | --------------------------------------------------------------------------------