├── LICENSE
├── README.md
├── data.tar.xz
├── exp
│   └── pff
│       ├── kmnist
│       │   └── tmp.txt
│       └── mnist
│           └── tmp.txt
├── fig
│   └── pff_config.png
└── src
    ├── analyze.sh
    ├── data_utils.py
    ├── eval_model.py
    ├── fit_gmm.py
    ├── pff_rnn.py
    ├── plot_tsne.py
    ├── run.sh
    ├── sample_model.py
    └── sim_train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Alex Ororbia and Ankur Mali

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# The Predictive Forward-Forward Algorithm
## Bio-plausible Forward-Only Learning for Training Neural Networks

Implementation of the proposed predictive forward-forward (PFF) learning algorithm for training a neurobiologically-plausible recurrent neural system. This work combines elements of predictive coding with the recently proposed forward-forward algorithm to create a novel online learning process that dynamically adapts two neural circuits: a representation circuit and a generative circuit. Notably, the system injects noise into the latent activity updates and uses learnable lateral synapses that induce competition across neural units (emulating the cross-inhibitory and self-excitation effects inherent to neural computation).

# Requirements
Our implementation is easy to follow: with knowledge of basic linear algebra, one can decode the inner workings of the PFF algorithm. Please see Algorithm 1 in our paper (in the Appendix) to better understand the overall mechanics of the inference and learning processes. The framework is built from simple modules, which should make it convenient to extend.

To run the code, you should only need the following basic packages:
1. TensorFlow (version >= 2.0)
2. NumPy
3. Matplotlib
4. Python (version >= 3.5)
5. [ngc-learn](https://github.com/ago109/ngc-learn). Some modules responsible for generating image samples depend on ngc-learn. Without it, you will not be able to use `fit_gmm.py`, which uses the mixture model in that package to retro-fit the latent prior for the PFF model's generative circuit; this means that `sample_model.py` and `plot_tsne.py` will have no prior distribution model to access. If you do not install ngc-learn, simply comment out the lines in the `analyze.sh` script that involve `sample_model.py` and `plot_tsne.py` (the `fit_gmm.py` lines will also fail, since that script imports ngc-learn at the top of the file).
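The forward-forward-style objective at the core of PFF can be illustrated outside TensorFlow. The following is a NumPy sketch that re-expresses `calc_goodness` and `calc_loss` from `src/pff_rnn.py`; it is an illustration, not the repository's code. A layer's goodness logit is the threshold `thr` minus the sum of its squared activities. As the comments in `calc_goodness` note, this drives positive samples toward *low* activity and negative samples toward high activity. The loss is the numerically stable binary cross-entropy on that logit.

```python
import numpy as np


def calc_goodness(z, thr):
    """NumPy sketch of PFF_RNN.calc_goodness: returns (P(pos), logit) per row."""
    energy = np.sum(np.square(z), axis=1, keepdims=True)  # sum of squared activities
    logit = thr - energy  # low activity -> large logit -> judged "positive"
    p = 1.0 / (1.0 + np.exp(-logit))  # sigmoid
    eps = 1e-5
    return np.clip(p, eps, 1.0 - eps), logit


def calc_loss(z, lab, thr, keep_batch=False):
    """Stable binary cross-entropy on the goodness logit (lab: 1 = pos, 0 = neg)."""
    _, logit = calc_goodness(z, thr)
    # max(l, 0) - l * y + log(1 + exp(-|l|)) == softplus(-l) * y + softplus(l) * (1 - y)
    ce = np.maximum(logit, 0.0) - logit * lab + np.log1p(np.exp(-np.abs(logit)))
    per_sample = np.sum(ce, axis=1, keepdims=True)
    return per_sample if keep_batch else float(np.mean(per_sample))


if __name__ == "__main__":
    z = np.zeros((4, 8))  # hypothetical batch of 4 samples with zero activity
    print(calc_loss(z, np.ones((4, 1)), thr=3.0))   # small loss for positives
    print(calc_loss(z, np.zeros((4, 1)), thr=3.0))  # large loss for negatives
```

With `thr = 3.0`, a silent layer gets a logit of 3. The loss is therefore small when the label is positive and large when it is negative. The commented-out line in `calc_goodness` shows the opposite sign convention, which is the one used by the original forward-forward formulation.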
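The lateral competition described above is initialized by `create_competiion_matrix` in `src/pff_rnn.py`. That function builds a block-diagonal pattern over groups of `n_group` units:

* `alpha_scale` on the diagonal,
* `beta_scale` between distinct units of the same group,
* zero across groups.

These presumably correspond to the self-excitation and cross-inhibition effects mentioned in the introduction. The NumPy sketch below is a hypothetical re-expression that uses `np.kron` instead of the original row-by-row loop. How the resulting matrices (`M1`, `M2`) enter the latent update is not shown in this excerpt.

```python
import numpy as np


def competition_matrix(z_dim, n_group, beta_scale=1.0, alpha_scale=1.0):
    """Hypothetical NumPy analogue of create_competiion_matrix in src/pff_rnn.py."""
    if z_dim % n_group != 0:
        # The original loop also assumes z_dim is a multiple of n_group
        # (otherwise its final element-wise product has mismatched shapes).
        raise ValueError("z_dim must be a multiple of n_group")
    n_blocks = z_dim // n_group
    # 1 where two units belong to the same group, 0 otherwise (block-diagonal mask).
    same_group = np.kron(np.eye(n_blocks), np.ones((n_group, n_group)))
    diag = np.eye(z_dim)
    # Same-group off-diagonal entries -> beta_scale; diagonal entries -> alpha_scale.
    return same_group * (1.0 - diag) * beta_scale + diag * alpha_scale


if __name__ == "__main__":
    M = competition_matrix(6, n_group=3, beta_scale=0.5, alpha_scale=2.0)
    print(M)  # two 3x3 blocks: 2.0 on the diagonal, 0.5 inside each block
```

The Kronecker product expresses "same group" as a single mask. This avoids the repository version's growing sequence of `tf.concat` calls, which for `n_units = 2000` and `n_group = 10` unrolls into thousands of concatenations when traced inside `@tf.function`.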
# Execution

To reproduce the results from our paper, simply run the following provided Bash scripts:
1. `bash src/run.sh` (This will train the model for `E=60` epochs.)
2. `bash src/analyze.sh` (This will evaluate a trained model and produce plots/visuals.)

Note that `analyze.sh` calls `eval_model.py` and uses data/model paths such as `../data/mnist/` relative to the `src/` directory, so it should be launched from inside `src/`.

After running the above two scripts, you can find the simulation outputs in the pre-created example experimental results directory tree:
* `exp/pff/mnist/` contains the results for the MNIST model (over 2 trials).
* `exp/pff/kmnist/` contains the results for the K-MNIST model (over 2 trials).

Each directory stores the following:
* `post_train_results.txt` - contains development/training cross-trial accuracy values
* `test_results.txt` - contains test cross-trial accuracy values
* `trial0/` - contains model data for trial 0, as well as any visuals produced by `analyze.sh`
* `trial1/` - contains model data for trial 1, as well as any visuals produced by `analyze.sh`

(Note that you should modify `MODEL_DIR` in `analyze.sh` to point to a particular trial's sub-directory. The default points to trial 0, so images are only placed inside the `trial0/` sub-directory.)

Model-specific hyper-parameter defaults can be set/adjusted in `pff_rnn.py`. Training-specific hyper-parameters are available in `sim_train.py`. In particular, one can create/edit the arguments dictionary inside `sim_train.py`, which the `PFF_RNN()` constructor takes as input to construct the simulation of the dual-circuit system.

Tips for using this algorithm/model on your own datasets:
1. Track your local losses and adjust the model's hyper-parameters accordingly.
2. Experiment with small, non-zero values for the weight decay coefficients `reg_lambda` (for the representation circuit) and `g_reg_lambda` (for the generative circuit). In our experience, a small value of `reg_lambda` (as indicated in the comments) slightly improved generalization performance on K-MNIST.

# Citation

If you use or adapt (portions of) this code/algorithm in any form in your project(s), or find the PFF algorithm helpful in your own work, please cite this code's source paper:

```bibtex
@article{ororbia2023predictive,
  title={The Predictive Forward-Forward Algorithm},
  author={Ororbia, Alexander and Mali, Ankur},
  journal={arXiv preprint arXiv:2301.01452},
  year={2023}
}
```

--------------------------------------------------------------------------------
/data.tar.xz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/data.tar.xz
--------------------------------------------------------------------------------
/exp/pff/kmnist/tmp.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/exp/pff/kmnist/tmp.txt
--------------------------------------------------------------------------------
/exp/pff/mnist/tmp.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/exp/pff/mnist/tmp.txt
--------------------------------------------------------------------------------
/fig/pff_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/fig/pff_config.png
--------------------------------------------------------------------------------
/src/analyze.sh:
--------------------------------------------------------------------------------
#!/bin/bash
GPU_ID=0
N_TRIALS=2

## Analyze MNIST
MODEL_TOPDIR="../exp/pff/mnist/"
MODEL_DIR="../exp/pff/mnist/trial0/" ## <-- choose a specific trial idx to analyze
DATA_DIR="../data/mnist/"

echo " ---------- Evaluating MNIST Test Performance ---------- "
python eval_model.py --data_dir=$DATA_DIR --split=test --model_topdir=$MODEL_TOPDIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$MODEL_TOPDIR

# for prior distribution fitting / sampling
echo " ---------- Fitting MNIST Prior ---------- "
python fit_gmm.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --split=train --model_dir=$MODEL_DIR
echo " ---------- Sampling MNIST Model ---------- "
python sample_model.py --gpu_id=$GPU_ID --model_dir=$MODEL_DIR

# for latent code visualization
echo " ---------- Extracting MNIST Test Latents ---------- "
python fit_gmm.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test --disable_prior=1
echo " ---------- Visualizing MNIST Test Latents ---------- "
python plot_tsne.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test


## Analyze K-MNIST
MODEL_TOPDIR="../exp/pff/kmnist/"
MODEL_DIR="../exp/pff/kmnist/trial0/" ## <-- choose a specific trial idx to analyze
DATA_DIR="../data/kmnist/"

echo " ---------- Evaluating K-MNIST Test Performance ---------- "
python eval_model.py --data_dir=$DATA_DIR --split=test --model_topdir=$MODEL_TOPDIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$MODEL_TOPDIR

# for prior distribution fitting / sampling
echo " ---------- Fitting K-MNIST Prior ---------- "
python fit_gmm.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --split=train --model_dir=$MODEL_DIR
echo " ---------- Sampling K-MNIST Model ---------- "
python sample_model.py --gpu_id=$GPU_ID --model_dir=$MODEL_DIR

# for latent code visualization
echo " ---------- Extracting K-MNIST Test Latents ---------- "
python fit_gmm.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test --disable_prior=1
echo " ---------- Visualizing K-MNIST Test Latents ---------- "
python plot_tsne.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test

--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
"""
Data functions and utilities.

Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
"""
import random
import numpy as np
import io
import sys
import math

seed = 69
# set your random seed - very important to reproduce any Machine/Deep learning work
np.random.seed(seed)

class DataLoader(object):
    """
    A data loader object, meant to allow sampling w/o replacement of one or
    more named design matrices. Note that this object is iterable (and
    implements an __iter__() method).

    Args:
        design_matrices: list of named data design matrices - [("name", matrix), ...]

        batch_size: number of samples to place inside a mini-batch

        disable_shuffle: if True, turns off sample shuffling (thus no sampling w/o replacement)

        ensure_equal_batches: if True, ensures sampled batches are equal in size (Default = True).
            Note that this means the very last batch, if it's not the same size as the rest, will
            reuse random samples from previously seen batches (yielding a batch with a mix of
            vectors sampled with and without replacement).
    """
    def __init__(self, design_matrices, batch_size, disable_shuffle=False,
                 ensure_equal_batches=True):
        self.batch_size = batch_size
        self.ensure_equal_batches = ensure_equal_batches
        self.disable_shuffle = disable_shuffle
        self.design_matrices = design_matrices
        if len(design_matrices) < 1:
            print(" ERROR: design_matrices must contain at least one design matrix!")
            sys.exit(1)
        self.data_len = len( self.design_matrices[0][1] )
        self.ptrs = np.arange(0, self.data_len, 1)
        if self.data_len < self.batch_size:
            print("ERROR: batch size {} is > total number data samples {}".format(
                  self.batch_size, self.data_len))
            sys.exit(1)

    def __iter__(self):
        """
        Yields a mini-batch of the form: [("name", batch),("name",batch),...]
        """
        if self.disable_shuffle == False:
            self.ptrs = np.random.permutation(self.data_len)
        idx = 0
        while idx < len(self.ptrs): # go through each sample via the sampling pointer
            e_idx = idx + self.batch_size
            if e_idx > len(self.ptrs): # prevents reaching beyond length of dataset
                e_idx = len(self.ptrs)
            # extract sampling integer pointers
            indices = self.ptrs[idx:e_idx]
            if self.ensure_equal_batches == True:
                if indices.shape[0] < self.batch_size:
                    diff = self.batch_size - indices.shape[0]
                    indices = np.concatenate((indices, self.ptrs[0:diff]))
            # create the actual pattern vector batch block matrices
            data_batch = []
            for dname, dmatrix in self.design_matrices:
                x_batch = dmatrix[indices]
                data_batch.append((dname, x_batch))
            yield data_batch
            idx = e_idx

--------------------------------------------------------------------------------
/src/eval_model.py:
--------------------------------------------------------------------------------
"""
Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)

################################################################################
Evaluates a trained recurrent neural system composed of a representation and
a generative circuit, trained via the predictive forward-forward process.
Note that this code focuses on datasets of gray-scale images/patterns.
################################################################################
"""

import os
import sys, getopt, optparse
import pickle
#import dill as pickle
sys.path.insert(0, '../')
import tensorflow as tf
import numpy as np
import time
import copy

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
#cmap = plt.cm.jet

# import general simulation utilities
from data_utils import DataLoader
from pff_rnn import PFF_RNN

################################################################################

seed = 69
tf.random.set_seed(seed=seed)
np.random.seed(seed)

model_topdir = "../exp/"
out_dir = "../exp/"
data_dir = "../data/mnist/"
split = "test"
# read in configuration file and extract necessary variables/constants
options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","model_topdir=",
                                                      "gpu_id=","n_trials=",
                                                      "out_dir=","split="])
# Collect arguments from argv
n_trials = 1
gpu_id = -1
for opt, arg in options:
    if opt in ("--data_dir"):
        data_dir = arg.strip()
    elif opt in ("--model_topdir"):
        model_topdir = arg.strip()
    elif opt in ("--split"):
        split = arg.strip()
    elif opt in ("--out_dir"):
        out_dir = arg.strip()
    elif opt in ("--gpu_id"):
        gpu_id = int(arg.strip())
    elif opt in ("--n_trials"):
        n_trials = int(arg.strip())
print(" Exp out dir: ",out_dir)

mid = gpu_id # 0
if mid >= 0:
    print(" > Using GPU ID {0}".format(mid))
    os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
    #gpu_tag = '/GPU:0'
    gpu_tag = '/GPU:0'
else:
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    gpu_tag = '/CPU:0'

print(" >>> Run sim on {} w/ GPU {}".format(data_dir, mid))

dev_batch_size = 500 # dev batch size
print(" > Loading dev set")
xfname = "{}/{}X.npy".format(data_dir, split)
yfname = "{}/{}Y.npy".format(data_dir, split)
Xdev = ( tf.cast(np.load(xfname, allow_pickle=True),dtype=tf.float32) )

if len(Xdev.shape) > 2:
    Xdev = tf.reshape(Xdev, [Xdev.shape[0], Xdev.shape[1] * Xdev.shape[2]])
print("Xdev.shape = ",Xdev.shape)
max_val = float(tf.reduce_max(Xdev))
if max_val > 1.0:
    Xdev = Xdev/max_val

Ydev = ( tf.cast(np.load(yfname, allow_pickle=True),dtype=tf.float32) )
if len(Ydev.shape) == 1:
    nC = 10 # number of classes (10 for MNIST and K-MNIST)
    Ydev = tf.one_hot(Ydev.numpy().astype(np.int32), depth=nC)
elif Ydev.shape[1] == 1:
    nC = 10
    Ydev = tf.one_hot(tf.squeeze(Ydev).numpy().astype(np.int32), depth=nC)
print("Ydev.shape = ",Ydev.shape)

devset = DataLoader(design_matrices=[("x",Xdev.numpy()), ("y",Ydev.numpy())],
                    batch_size=dev_batch_size, disable_shuffle=True)

def classify(agent, x):
    K_low = int(agent.K/2) - 1 # 3
    K_high = int(agent.K/2) + 1 # 5
    x_ = x
    Ey = None
    z_lat = agent.forward(x_) # do forward init pass
    for i in range(agent.y_dim):
        z_lat_ = []
        for ii in range(len(z_lat)):
            z_lat_.append(z_lat[ii] + 0)

        yi = tf.ones([x.shape[0],agent.y_dim]) * tf.expand_dims(tf.one_hot(i,depth=agent.y_dim),axis=0)

        gi = 0.0
        for k in range(K_high):
            z_lat_, p_g = agent._step(x_, yi, z_lat_, thr=0.0)
            if k >= K_low and k <= K_high: # only keep goodness in middle iterations
                gi = ((p_g[0] + p_g[1])*0.5) + gi

        if i > 0:
            Ey = tf.concat([Ey,gi],axis=1)
        else:
            Ey = gi

    Ey = Ey / (3.0)
    y_hat = tf.nn.softmax(Ey)
    return y_hat, Ey

def eval(agent, dataset, debug=False):
    '''
    Evaluates the current state of the agent given a dataset (data-loader).
    '''
    N = 0.0
    Ny = 0.0
    Acc = 0.0
    Ly = 0.0
    Lx = 0.0
    tt = 0
    #debug = True
    for batch in dataset:
        _, x = batch[0]
        #_, y_tag = batch[1]
        _, y = batch[1]
        N += x.shape[0]
        Ny += float(tf.reduce_sum(y))

        y_hat, Ey = classify(agent, x)
        Ly += tf.reduce_sum(-tf.reduce_sum(y * tf.math.log(y_hat), axis=1, keepdims=True))

        z_lat = agent.forward(x)
        x_hat = agent.sample(z=z_lat[len(z_lat)-2])

        ex = x_hat - x
        Lx += tf.reduce_sum(tf.reduce_sum(tf.math.square(ex),axis=1,keepdims=True))

        if debug == True:
            print("------------------------")
            print(Ey[0:4,:])
            print(y_hat[0:4,:])
            print("------------------------")

        #y_m = tf.squeeze(y_m)
        y_ind = tf.cast(tf.argmax(y,1),dtype=tf.int32)
        y_pred = tf.cast(tf.argmax(Ey,1),dtype=tf.int32)
        comp = tf.cast(tf.equal(y_pred,y_ind),dtype=tf.float32) #* y_m
        Acc += tf.reduce_sum( comp )
        print("\r Lx = {} Ly = {} Acc = {} ({} samples)".format(Lx/Ny,Ly/Ny,Acc/Ny,Ny),end="")
    print()
    Ly = Ly/Ny
    Acc = Acc/Ny
    Lx = Lx/Ny

    return Ly, Acc, Lx

print("----")
with tf.device(gpu_tag):

    best_acc_list = []
    acc_list = []
    for tr in range(n_trials):
        ########################################################################
        ## load model
        ########################################################################
        agent = PFF_RNN(model_dir="{}trial{}/".format(model_topdir, tr))
        K = agent.K
        ########################################################################

        ########################################################################
        Ly, Acc, Lx = eval(agent, devset)
        print("{}: L {} Acc = {} Lx {}".format(-1, Ly, Acc, Lx))

        acc_list.append(1.0 - Acc)

    ############################################################################
    ## calc post-trial statistics
    n_dec = 4
    mu = round(np.mean(np.asarray(acc_list)), n_dec)
    sd = round(np.std(np.asarray(acc_list)), n_dec)
    print(" Test.Error = {:.4f} \\pm {:.4f}".format(mu, sd))

    ## store result to disk just in case...
    results_fname = "{}/test_results.txt".format(out_dir)
    log_t = open(results_fname,"a")
    log_t.write("Generalization Results:\n")
    log_t.write(" Test.Error = {:.4f} \\pm {:.4f}\n".format(mu, sd))
    log_t.close()

--------------------------------------------------------------------------------
/src/fit_gmm.py:
--------------------------------------------------------------------------------
"""
Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)

################################################################################
Fits a multi-modal/mixture prior to the latent space of a trained PFF-RNN
################################################################################
"""

import os
import sys, getopt, optparse
import pickle
sys.path.insert(0, '../')
import tensorflow as tf
import numpy as np
import time
import copy

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
#cmap = plt.cm.jet

from data_utils import DataLoader
from pff_rnn import PFF_RNN

# import general simulation utilities
from ngclearn.density.gmm import GMM

################################################################################

seed = 69
tf.random.set_seed(seed=seed)
np.random.seed(seed)

disable_prior = 0
data_dir = "../data/"
model_dir = "../exp/"
split = "train"
# read in configuration file and extract necessary variables/constants
options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","model_dir=","gpu_id=","split=",
                                                      "disable_prior="])
# Collect arguments from argv
n_trials = 1
gpu_id = -1
for opt, arg in options:
    if opt in ("--data_dir"):
        data_dir = arg.strip()
    elif opt in ("--model_dir"):
        model_dir = arg.strip()
    elif opt in ("--split"):
        split = arg.strip()
    elif opt in ("--disable_prior"):
        disable_prior = int(arg.strip())
    elif opt in ("--gpu_id"):
        gpu_id = int(arg.strip())

gmm_fname = "{}/prior.gmm".format(model_dir)
latent_fname = "{}/latents.npy".format(model_dir)
batch_size = 400 #200 #100 #50 #1000 #500

mid = gpu_id
if mid >= 0:
    print(" > Using GPU ID {0}".format(mid))
    os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
    #gpu_tag = '/GPU:0'
    gpu_tag = '/GPU:0'
else:
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    gpu_tag = '/CPU:0'


xfname = "{}{}X.npy".format(data_dir,split)
X = ( tf.cast(np.load(xfname),dtype=tf.float32) )
if len(X.shape) > 2:
    X = tf.reshape(X, [X.shape[0], X.shape[1] * X.shape[2]])
x_dim = X.shape[1]
max_val = float(tf.reduce_max(X))
if max_val > 1.0:
    X = X/max_val
print("X.shape = ",X.shape)

yfname = "{}{}Y.npy".format(data_dir,split)
Y = ( tf.cast(np.load(yfname),dtype=tf.float32) )
if len(Y.shape) == 1:
    print("y_init.shape = ",Y.shape)
    nC = 10 #Y.shape[1]
    Y = tf.one_hot(Y.numpy().astype(np.int32), depth=nC)
    print("y_post.shape = ",Y.shape)
elif Y.shape[1] == 1:
    print("y_init.shape = ",Y.shape)
    nC = 10 #Y.shape[1]
    Y = tf.one_hot(tf.squeeze(Y).numpy().astype(np.int32), depth=nC)
    print("y_post.shape = ",Y.shape)
y_dim = Y.shape[1]

dataset = DataLoader(design_matrices=[("x",X.numpy()),("y",Y.numpy())],
                     batch_size=batch_size, disable_shuffle=True)

def calc_latent_map(agent, dataset, debug=False):
    '''
    Calculates a latent "map" or matrix containing the latent encoding of
    a dataset/loader.
    '''
    z = None
    N = 0.0
    L_r = 0.0
    for batch in dataset:
        _, x = batch[0]
        _, y = batch[1]

        z2 = agent.get_latent(x, y, K, use_y_hat=True)
        e = agent.z0_hat - x
        Li = tf.reduce_sum(tf.math.square(e) * 0.5)
        L_r = Li + L_r

        if z is not None:
            z = tf.concat([z,z2],axis=0)
        else:
            z = z2
        N += x.shape[0]
    L_r = L_r / N
    return z, L_r


print("----")
with tf.device(gpu_tag):
    print(" >> Loading model from: ",model_dir)
    agent = PFF_RNN(model_dir=model_dir)
    agent.z1 = None
    agent.z0_hat = None
    K = agent.K

    z_map, Lr = calc_latent_map(agent, dataset)
    print("MSE: ",float(Lr))

    # save latent space map to disk
    print(" > Saving latents to disk: lat.shape = ",z_map.shape)
    np.save(latent_fname, z_map)

    max_w = -10000.0
    min_w = 10000.0
    max_w = max(max_w, float(tf.reduce_max(z_map)))
    min_w = min(min_w, float(tf.reduce_min(z_map)))
    print("max_z = ", max_w)
    print("min_z = ", min_w)

    if disable_prior == 0:
        print(" > Estimating latent density...")
        n_comp = 10 # number of components that will define the prior P(z)
        lat_density = GMM(k=n_comp)
        lat_density.fit(z_map)

        print(" > Saving density estimator to: {0}".format(gmm_fname))
        fd = open("{0}".format(gmm_fname), 'wb')
        pickle.dump(lat_density, fd)
        fd.close()

--------------------------------------------------------------------------------
/src/pff_rnn.py:
--------------------------------------------------------------------------------
"""
Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)

This file contains the model constructor and its credit assignment code.
"""

import os
import sys
import copy
import pickle
#import dill as pickle
import tensorflow as tf
import numpy as np

### generic routines/functions

def serialize(fname, object): ## object "saving" routine
    fd = open(fname, 'wb')
    pickle.dump(object, fd)
    fd.close()

def deserialize(fname): ## object "loading" routine
    fd = open(fname, 'rb')
    object = pickle.load( fd )
    fd.close()
    return object

@tf.function
def create_competiion_matrix(z_dim, n_group, beta_scale=1.0, alpha_scale=1.0):
    """
    Competition matrix initialization function, adapted from
    (Ororbia & Kifer 2022; Nature Communications).
    """
    diag = tf.eye(z_dim)
    V_l = None
    g_shift = 0
    while (z_dim - (n_group + g_shift)) >= 0:
        if g_shift > 0:
            left = tf.zeros([1,g_shift])
            middle = tf.ones([1,n_group])
            right = tf.zeros([1,z_dim - (n_group + g_shift)])
            slice = tf.concat([left,middle,right],axis=1)
            for n in range(n_group):
                V_l = tf.concat([V_l,slice],axis=0)
        else:
            middle = tf.ones([1,n_group])
            right = tf.zeros([1,z_dim - n_group])
            slice = tf.concat([middle,right],axis=1)
            for n in range(n_group):
                if V_l is not None:
                    V_l = tf.concat([V_l,slice],axis=0)
                else:
                    V_l = slice
        g_shift += n_group
    V_l = V_l * (1.0 - diag) * beta_scale + diag * alpha_scale
    return V_l

@tf.function
def softmax(x, tau=0.0): ## temperature-controlled softmax activation
    if tau > 0.0:
        x = x / tau
    max_x = tf.expand_dims( tf.reduce_max(x, axis=1), axis=1)
    exp_x = tf.exp(tf.subtract(x, max_x))
    return exp_x / tf.expand_dims( tf.reduce_sum(exp_x, axis=1), axis=1)

@tf.custom_gradient
def _relu(x): ## FF/PFF relu variant activation
    ## modified relu
    out = tf.nn.relu(x)
    def grad(upstream):
        #dx = tf.cast(tf.math.greater_equal(x, 0.0), dtype=tf.float32) # d_relu/d_x
        dx = tf.ones(x.shape) # pass-through gradient: acts as if the derivative is 1 everywhere, even where x <= 0
        return upstream * dx
    return out, grad
@tf.function
def clip_fx(x): ## hard-clip activation
    return tf.clip_by_value(x, 0.0, 1.0)

### begin constructor definition

class PFF_RNN:
    """
    Basic implementation of the predictive forward-forward (PFF) algorithm for a recurrent
    neural network from (Ororbia & Mali 2022).

    A "model" in this framework is defined as a pair containing an arguments
    dictionary and a theta construct, i.e., (args, theta).
    """
    def __init__(self, args=None, model_dir=None):
        theta_r = None
        theta_g = None
        if model_dir is not None:
            args = deserialize("{}config.args".format(model_dir))
            theta_r = deserialize("{}rep_params.theta".format(model_dir))
            theta_g = deserialize("{}gen_params.theta".format(model_dir))
        if args is None:
            print("ERROR: no model arguments provided...")
            sys.exit(1)
        self.args = args
        ## collect hyper-parameters
        self.seed = 69
        if args.get("seed") is not None:
            self.seed = args["seed"]
        self.x_dim = args["x_dim"]
        self.y_dim = args["y_dim"]
        self.n_units = 2000
        if args.get("n_units") is not None:
            self.n_units = args["n_units"]

        self.g_units = 20 # number of top-most latent variables for generative circuit
        if args.get("g_units") is not None:
            self.g_units = args["g_units"]
        self.K = 10
        if args.get("K") is not None:
            self.K = args["K"]
        self.beta = 0.025 #0.05 #0.1
        if args.get("beta") is not None:
            self.beta = args["beta"]

        self.gen_gamma = 1.0
        self.rec_gamma = 1.0
        self.thr = 3.0 # goodness threshold
        if args.get("thr") is not None:
            self.thr = args["thr"]
        self.alpha = 0.3 # dampening factor
        if args.get("alpha") is not None:
            self.alpha = args["alpha"]

        self.y_scale = 5.0 # 1.0
        # stats for peer normalization (if used)
        self.eps_r = 0.01 # noise factor for representation circuit
        if args.get("eps_r") is not None:
            self.eps_r = args["eps_r"]
        self.eps_g = 0.025 # noise factor for generative circuit
        if args.get("eps_g") is not None:
            self.eps_g = args["eps_g"]

        ## Set up parameter construct - theta_r and theta_g
        if theta_r is None: ## representation circuit params
            initializer = tf.compat.v1.keras.initializers.Orthogonal()
            self.b1 = tf.Variable(initializer([1, self.n_units]))
            self.b2 = tf.Variable(initializer([1, self.n_units]))
            self.W1 = tf.Variable(initializer([self.x_dim, self.n_units]))
            self.V2 = tf.Variable(initializer([self.n_units, self.n_units])) # inner feedback
            self.W2 = tf.Variable(initializer([self.n_units, self.n_units]))
            self.V = tf.Variable(initializer([self.y_dim, self.n_units])) # top to inner feedback
            self.W = tf.Variable(initializer([self.n_units, self.y_dim])) # softmax output
            self.b = tf.Variable(tf.zeros([1, self.y_dim]))
            initializer = tf.compat.v1.keras.initializers.RandomUniform(minval=0.0, maxval=0.05)

            self.M1 = create_competiion_matrix(self.n_units, n_group=10)
            self.M2 = create_competiion_matrix(self.n_units, n_group=10)
            self.L1 = tf.Variable(initializer([self.n_units, self.n_units])) # lateral lyr 1
            self.L2 = tf.Variable(initializer([self.n_units, self.n_units])) # lateral lyr 2
            theta_r = [self.L1,self.L2,self.b1,self.b2,self.W1,self.W2,self.V2,self.V,self.W,self.b] ## theta_r
        else:
            self.M1 = create_competiion_matrix(self.n_units, n_group=10)
            self.M2 = create_competiion_matrix(self.n_units, n_group=10)
            self.L1 = theta_r[0]
            self.L2 = theta_r[1]
            self.b1 = theta_r[2]
            self.b2 = theta_r[3]
            self.W1 = theta_r[4]
            self.W2 = theta_r[5]
            self.V2 = theta_r[6]
            self.V = theta_r[7]
            self.W = theta_r[8]
            self.b = theta_r[9]
        self.theta = theta_r

        if theta_g is None: ## generative circuit params
            self.Gy = tf.Variable(initializer([self.g_units, self.n_units]))
            self.G2 = tf.Variable(initializer([self.n_units, self.n_units]))
            self.G1 = tf.Variable(initializer([self.n_units, self.x_dim]))
            theta_g = [self.Gy,self.G1,self.G2] ## theta_g
        else:
            self.Gy = theta_g[0]
            self.G1 = theta_g[1]
            self.G2 = theta_g[2]
        self.theta_g = theta_g

        ## activation functions and latent states/stat variables
        self.z_g = None # top-most generative latent state
        self.fx = _relu ## internal activation function
        self.ofx = clip_fx ## predictive activation output function
        self.gfx = _relu ## generative activation output function
        self.z1 = None
        self.z0_hat = None

    def save_model(self, model_dir):
        """
        Save current model configuration and synaptic parameters (of both
        representation & generative circuits) to disk.

        Args:
            model_dir: directory to save model config
        """
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        serialize("{}config.args".format(model_dir), self.args)
        serialize("{}rep_params.theta".format(model_dir), self.theta)
        serialize("{}gen_params.theta".format(model_dir), self.theta_g)

    def calc_goodness(self, z, thr):
        """
        Calculates the "goodness" of an activation vector.
207 | 208 | Args: 209 | z: activation vector/matrix 210 | thr: goodness threshold 211 | 212 | Returns: 213 | (p, delta): probability P(pos) = sigmoid(delta) and the logit delta = thr - sum(z^2), per row 214 | """ 215 | z_sqr = tf.math.square(z) 216 | delta = tf.reduce_sum(z_sqr, axis=1, keepdims=True) 217 | #delta = delta - thr # maximize for positive samps, minimize for negative samps 218 | delta = -delta + thr # minimize for positive samps, maximize for negative samps 219 | # gets the probability P(pos) 220 | p = tf.nn.sigmoid(delta) 221 | eps = 1e-5 222 | p = tf.clip_by_value(p, eps, 1.0 - eps) 223 | return p, delta 224 | 225 | def calc_loss(self, z, lab, thr, keep_batch=False): 226 | """ 227 | Calculates the local loss of an activation vector. 228 | 229 | Args: 230 | z: activation vector/matrix 231 | lab: data "type" binary label (1 for pos, 0 for neg) (vector/matrix) 232 | thr: goodness threshold 233 | 234 | Returns: 235 | local loss: scalar mean over the batch, or per-sample values if keep_batch is True 236 | """ 237 | p, logit = self.calc_goodness(z, thr) 238 | ## the loss below is what the original PFF paper used & adheres to Eqn 3 239 | CE = tf.math.maximum(logit, 0) - logit * lab + tf.math.log(1.
+ tf.math.exp(-tf.math.abs(logit))) 240 | ## the commented-out loss below, however, also works just fine 241 | #CE = tf.nn.softplus(-logit) * lab + tf.nn.softplus(logit) * (1.0 - lab) 242 | L = tf.reduce_sum(CE, axis=1, keepdims=True) 243 | if keep_batch == True: 244 | return L 245 | L = tf.reduce_mean(L) 246 | return L 247 | 248 | def forward(self, x): 249 | """ 250 | Forward propagates x thru rep circuit 251 | 252 | Args: 253 | x: sensory input (vector/matrix) 254 | 255 | Returns: 256 | list of layer-wise activities 257 | """ 258 | z1 = self.fx(tf.matmul(self.normalize(x), self.W1) + self.b1) 259 | z2 = self.fx(tf.matmul(self.normalize(z1), self.W2) + self.b2) 260 | z3 = softmax(tf.matmul(self.normalize(z2), self.W) + self.b) 261 | return [z1,z2,z3] # return all latents 262 | 263 | def classify(self, x): 264 | """ 265 | Categorizes sensory input x 266 | 267 | Args: 268 | x: sensory input (vector/matrix) 269 | 270 | Returns: 271 | y_hat, probability distribution over labels (vector/matrix) 272 | """ 273 | z = self.forward(x) 274 | y_hat = z[len(z)-1] 275 | return y_hat 276 | 277 | def infer(self, x, y, lab, z_lat, K, opt=None, g_opt=None, reg_lambda=0.0, 278 | zero_y=False, g_reg_lambda=0.0): 279 | """ 280 | Simulates the PFF intertwined inference-and-learning process for 281 | the underlying dual-circuit system. 
282 | 283 | Args: 284 | x: sensory input (vector/matrix) 285 | y: sensory class label (vector/matrix) 286 | lab: data "type" binary label (1 for pos, 0 for neg) (vector/matrix) 287 | z_lat: list of initial conditions for model's representation activities 288 | (usually provided by an initial forward pass with .forward(x) ) 289 | K: number of simulation steps 290 | opt: rep circuit optimizer 291 | g_opt: gen circuit optimizer 292 | reg_lambda: regularization coefficient (for representation synapses) 293 | zero_y: "zero out" the y-vector top-down context 294 | g_reg_lambda: regularization coefficient (for generative synapses) 295 | 296 | Returns: 297 | (global energy value (goodness + regularization), label distribution matrix, 298 | generative loss, x reconstruction) 299 | """ 300 | if self.rec_gamma > 0.0: 301 | self.theta = [self.L1,self.L2,self.b1,self.b2,self.W1,self.W2,self.V2,self.V,self.W,self.b] 302 | else: 303 | self.theta = [self.b1,self.b2,self.W1,self.W2,self.V2,self.V,self.W,self.b] 304 | 305 | calc_grad = False 306 | if opt is not None: 307 | calc_grad = True 308 | # update generative model 309 | self.z_g = tf.Variable(tf.zeros([x.shape[0], self.g_units])) 310 | for k in range(K): 311 | # update representation model 312 | z_lat, L, delta = self.step(x,y,lab,z_lat,calc_grad=calc_grad, 313 | zero_y=zero_y, reg_lambda=reg_lambda) 314 | if opt is not None: ## update synapses 315 | bound = 1.0 #5.0 # 1.0 316 | for l in range(len(delta)): 317 | delta[l] = tf.clip_by_value(delta[l], -bound, bound) # clip update by projection 318 | opt.apply_gradients(zip(delta, self.theta)) 319 | if self.gen_gamma > 0.0: ## update generative model 320 | grad_f = True #False 321 | Lg, x_hat = self.update_generator(x,y,z_lat,g_opt,reg_lambda=g_reg_lambda,grad_f=grad_f) 322 | else: 323 | Lg = 0.0 324 | x_hat = x * 0 325 | 326 | y_hat = z_lat[len(z_lat)-1] 327 | return L, y_hat, Lg, x_hat 328 | 329 | def sample(self, n_s=0, z=None, y=None): # samples generative circuit 330 | """ 
331 | Samples the generative circuit within this current neural system. 332 | 333 | Args: 334 | n_s: number of samples to synthesize/confabulate 335 | z: top-most externally produced input sample (from a prior); (vector/matrix) 336 | y: sensory class label (vector/matrix) 337 | 338 | Returns: 339 | samples of the bottom-most sensory layer 340 | """ 341 | if z is None: 342 | eps_sigma = 0.05 343 | #eps = tf.random.normal([y.shape[0],self.n_units], 0.0, eps_sigma) #* 0 344 | z_in = self.normalize(self.gfx(y)) 345 | z2 = self.gfx(tf.matmul(z_in,self.Gy))# + self.c) 346 | else: 347 | z2 = self.gfx(z) 348 | #z2 = self.gfx(z2) 349 | z1 = self.gfx(tf.matmul(self.normalize(z2),self.G2))# + self.c2) 350 | #z0 = tf.matmul(self.normalize(z1),self.G1) # 351 | z0 = self.ofx(tf.matmul(self.normalize(z1),self.G1))# + self.c1) 352 | return z0 353 | 354 | def update_generator(self, x, y, z_lat, opt, reg_lambda=0.0001, grad_f=True): 355 | ''' 356 | Internal routine for adjusting synapses of generative circuit 357 | ''' 358 | z0 = x 359 | z1 = z_lat[0] 360 | z2 = z_lat[1] 361 | eps_sigma = self.eps_g #0.025 #0.055 # 0.05 #0.02 #0.1 362 | eps1 = tf.random.normal([x.shape[0],self.n_units], 0.0, eps_sigma) #* 0 363 | eps2 = tf.random.normal([x.shape[0],self.n_units], 0.0, eps_sigma) #* 0 364 | with tf.GradientTape(persistent=True) as tape: 365 | z3_bar = self.gfx(self.z_g) 366 | z2_hat = tf.matmul(self.normalize(z3_bar),self.Gy) #+ self.c 367 | #z2_hat = self.fx(tf.matmul(self.z_g,self.Gy)) 368 | z2_bar = self.gfx(z2_hat) #z2 #self.fx(z2 + eps2) 369 | z1_hat = tf.matmul(tf.stop_gradient(self.normalize(self.gfx(z2 + eps2))),self.G2) #+ self.c2 370 | #z1_hat = self.fx(tf.matmul(z2_bar,self.G2)) 371 | z1_bar = self.gfx(z1_hat) #z1 #self.fx(z1 + eps1) 372 | #z0_hat = tf.matmul(z1_bar,self.G1) # 373 | z0_hat = self.ofx(tf.matmul(tf.stop_gradient(self.normalize(self.gfx(z1 + eps1))),self.G1)) #+ self.c1 374 | 375 | e2 = z2_bar - z2#_bar 376 | L2 = 
tf.reduce_mean(tf.reduce_sum(tf.math.square(e2),axis=1,keepdims=True)) 377 | e1 = z1_bar - z1#_bar 378 | L1 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e1),axis=1,keepdims=True)) 379 | e0 = z0_hat - z0 380 | L0 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e0),axis=1,keepdims=True)) 381 | 382 | reg = 0.0 383 | if reg_lambda > 0.0: # weight decay 384 | reg = (tf.norm(self.G1) + tf.norm(self.G2) + tf.norm(self.Gy)) * reg_lambda 385 | L = L2 + L1 + L0 + reg 386 | Lg = L2 + L1 + L0 387 | if grad_f == True: 388 | delta = tape.gradient(L, self.theta_g) 389 | bound = 1.0 #5.0 # 1.0 390 | for l in range(len(delta)): 391 | delta[l] = tf.clip_by_value(delta[l], -bound, bound) # clip update by projection 392 | opt.apply_gradients(zip(delta, self.theta_g)) 393 | # for l in range(len(self.theta_g)): 394 | # self.theta_g[l].assign(tf.clip_by_norm(self.theta_g[l], 5.0, axes=1)) 395 | 396 | d_z = tape.gradient(L2, self.z_g) 397 | self.z_g.assign(self.z_g - d_z * self.beta) 398 | 399 | return Lg, z0_hat # return generative loss 400 | 401 | def step(self, x, y, lab, z_lat, calc_grad=False, zero_y=False, reg_lambda=0.0): 402 | ''' 403 | Internal full simulation step routine (inference and credit assignment) 404 | ''' 405 | y_ = y * self.y_scale 406 | Npos = tf.reduce_sum(lab) 407 | 408 | eps_sigma = self.eps_r # 0.01 #0.05 # kmnist 409 | #eps_sigma = 0.025 # mnist 410 | eps1 = tf.random.normal([x.shape[0],self.n_units], 0.0, 1.0) * eps_sigma 411 | eps2 = tf.random.normal([x.shape[0],self.n_units], 0.0, 1.0) * eps_sigma 412 | 413 | with tf.GradientTape() as tape: 414 | z1_tm1 = tf.stop_gradient(z_lat[0]) 415 | z2_tm1 = tf.stop_gradient(z_lat[1]) 416 | 417 | z1 = tf.matmul(self.normalize(x),self.W1) + tf.matmul(self.normalize(z2_tm1),self.V2) + self.b1 + eps1 418 | if self.rec_gamma > 0.0: 419 | L1 = tf.nn.relu(self.L1) 420 | L1 = L1 * self.M1 * (1. 
- tf.eye(self.L1.shape[0])) - L1 * tf.eye(self.L1.shape[0]) 421 | z1 = z1 - tf.matmul(z1_tm1, L1) * self.rec_gamma 422 | z1 = self.fx(z1) * (1.0 - self.alpha) + z1_tm1 * self.alpha 423 | 424 | z2 = tf.matmul(self.normalize(z1_tm1),self.W2) + tf.matmul(y_,self.V) + self.b2 + eps2 425 | if self.rec_gamma > 0.0: 426 | L2 = tf.nn.relu(self.L2) 427 | L2 = L2 * self.M2 * (1. - tf.eye(self.L2.shape[0])) - L2 * tf.eye(self.L2.shape[0]) 428 | z2 = z2 - tf.matmul(z2_tm1, L2) * self.rec_gamma 429 | z2 = self.fx(z2) * (1.0 - self.alpha) + z2_tm1 * self.alpha 430 | 431 | z3 = softmax(tf.matmul(self.normalize(z2_tm1),self.W) + self.b) #tf.nn.softmax(tf.matmul(self.normalize(z2),self.W) + self.b) 432 | 433 | ## calc loss 434 | L1 = self.calc_loss(z1, lab, thr=self.thr) 435 | L2 = self.calc_loss(z2, lab, thr=self.thr) 436 | if zero_y == True: 437 | L3 = tf.reduce_sum(-tf.reduce_sum(y_ * tf.math.log(z3), axis=1, keepdims=True)) * 0 438 | else: 439 | L3 = tf.reduce_sum(-tf.reduce_sum(y_ * tf.math.log(z3) * lab, axis=1, keepdims=True) * lab) 440 | L3 = L3 / Npos 441 | reg = 0.0 442 | if reg_lambda > 0.0: # weight decay 443 | reg = (tf.norm(self.W1) + tf.norm(self.W2) + tf.norm(self.V2) + 444 | tf.norm(self.V) + tf.norm(self.W)) * reg_lambda 445 | L = L3 + L2 + L1 + reg 446 | Lg = L2 + L1 447 | delta = None 448 | if calc_grad == True: 449 | delta = tape.gradient(L, self.theta) 450 | return [z1,z2,z3], Lg, delta 451 | 452 | def _step(self, x, y, z_lat, thr=None): 453 | ''' 454 | Internal simulation step, without credit assignment 455 | ''' 456 | thr_ = thr 457 | y_ = y * self.y_scale 458 | 459 | z1_tm1 = z_lat[0] 460 | z2_tm1 = z_lat[1] 461 | 462 | z1 = tf.matmul(self.normalize(x),self.W1) + tf.matmul(self.normalize(z2_tm1),self.V2) + self.b1 463 | if self.rec_gamma > 0.0: 464 | L1 = tf.nn.relu(self.L1) 465 | L1 = L1 * self.M1 * (1. 
- tf.eye(self.L1.shape[0])) - L1 * tf.eye(self.L1.shape[0]) 466 | z1 = z1 - tf.matmul(z1_tm1, L1) * self.rec_gamma 467 | z1 = self.fx(z1) * (1.0 - self.alpha) + z1_tm1 * self.alpha 468 | 469 | z2 = tf.matmul(self.normalize(z1_tm1),self.W2) + tf.matmul(y_,self.V) + self.b2 470 | if self.rec_gamma > 0.0: 471 | L2 = tf.nn.relu(self.L2) 472 | L2 = L2 * self.M2 * (1. - tf.eye(self.L2.shape[0])) - L2 * tf.eye(self.L2.shape[0]) 473 | z2 = z2 - tf.matmul(z2_tm1, L2) * self.rec_gamma 474 | z2 = self.fx(z2) * (1.0 - self.alpha) + z2_tm1 * self.alpha 475 | ## calc goodness values 476 | if thr_ is None: 477 | thr_ = self.thr 478 | p1, logit1 = self.calc_goodness(z1, thr_) 479 | p2, logit2 = self.calc_goodness(z2, thr_) 480 | 481 | return [z1,z2], [logit1,logit2] 482 | 483 | def get_latent(self, x, y, K, use_y_hat=False): 484 | self.z1 = None 485 | self.z0_hat = 0.0 486 | z_lat = self.forward(x) 487 | y_hat = z_lat[len(z_lat)-1] 488 | self.z_g = tf.Variable(tf.zeros([x.shape[0], self.g_units])) 489 | y_ = y 490 | if use_y_hat == True: 491 | y_ = y_hat 492 | for k in range(K): 493 | self._infer_latent(x,y_,z_lat) 494 | #self.z0_hat = self.ofx(tf.matmul(self.normalize(self.gfx(self.z1)),self.G1)) 495 | #self.z0_hat = self.z0_hat/(K * 1.0) 496 | return self.z_g + 0 497 | 498 | def _infer_latent(self, x, y, z_lat): 499 | ''' 500 | Internal routine for inferring the generative circuit's top-most latent state activity 501 | ''' 502 | y_ = y * self.y_scale 503 | 504 | z1_tm1 = z_lat[0] 505 | z2_tm1 = z_lat[1] 506 | with tf.GradientTape() as tape: 507 | z1 = tf.matmul(self.normalize(x),self.W1) + tf.matmul(self.normalize(z2_tm1),self.V2) + self.b1 508 | if self.rec_gamma > 0.0: 509 | L1 = tf.nn.relu(self.L1) 510 | L1 = L1 * self.M1 * (1. 
- tf.eye(self.L1.shape[0])) - L1 * tf.eye(self.L1.shape[0]) 511 | z1 = z1 - tf.matmul(z1_tm1, L1) * self.rec_gamma 512 | z1 = self.fx(z1) * (1.0 - self.alpha) + z1_tm1 * self.alpha 513 | self.z1 = z1 + 0 514 | 515 | z2 = tf.matmul(self.normalize(z1_tm1),self.W2) + tf.matmul(y_,self.V) + self.b2 516 | if self.rec_gamma > 0.0: 517 | L2 = tf.nn.relu(self.L2) 518 | L2 = L2 * self.M2 * (1. - tf.eye(self.L2.shape[0])) - L2 * tf.eye(self.L2.shape[0]) 519 | z2 = z2 - tf.matmul(z2_tm1, L2) * self.rec_gamma 520 | z2 = self.fx(z2) * (1.0 - self.alpha) + z2_tm1 * self.alpha 521 | 522 | z3_bar = self.gfx(self.z_g) 523 | z2_hat = tf.matmul(self.normalize(z3_bar),self.Gy) 524 | z2_bar = self.gfx(z2_hat) 525 | e2 = z2_bar - z2#_bar 526 | L2 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e2),axis=1,keepdims=True)) 527 | 528 | d_z = tape.gradient(L2, self.z_g) 529 | self.z_g.assign(self.z_g - d_z * self.beta) 530 | 531 | z0_hat = self.ofx(tf.matmul(self.normalize(self.gfx(z1)),self.G1)) 532 | self.z0_hat = z0_hat 533 | 534 | def normalize(self, z_state, a_scale=1.0): ## norm 535 | ''' 536 | Internal routine for normalizing state vector/matrix "z_state" 537 | ''' 538 | eps = 1e-8 539 | L2 = tf.norm(z_state, ord=2, axis=1, keepdims=True) 540 | z_state = z_state / (L2 + eps) 541 | return z_state * a_scale 542 | -------------------------------------------------------------------------------- /src/plot_tsne.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022) 3 | """ 4 | 5 | import os 6 | import sys, getopt, optparse 7 | import pickle 8 | sys.path.insert(0, '../') 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | import matplotlib #.pyplot as plt 13 | matplotlib.use('Agg') 14 | import matplotlib.pyplot as plt 15 | cmap = plt.cm.jet 16 | #cmap = plt.cm.cividis # red-green color blind friendly palette 17 | 18 | 
################################################################################ 19 | 20 | # GPU arguments 21 | # read in configuration file and extract necessary variables/constants 22 | options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","model_dir=","gpu_id=","split="]) 23 | 24 | # Collect arguments from argv 25 | model_dir = "../exp/" 26 | data_dir = "../data/" 27 | use_tsne = True # (args.getArg("use_tsne").lower() == 'true') 28 | split = "test" 29 | 30 | use_gpu = False 31 | gpu_id = -1 32 | for opt, arg in options: 33 | if opt in ("--data_dir"): 34 | data_dir = arg.strip() 35 | elif opt in ("--model_dir"): 36 | model_dir = arg.strip() 37 | elif opt in ("--split"): 38 | split = arg.strip() 39 | elif opt in ("--gpu_id"): 40 | gpu_id = int(arg.strip()) 41 | use_gpu = True 42 | 43 | mid = gpu_id 44 | if mid >= 0: 45 | print(" > Using GPU ID {0}".format(mid)) 46 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid) 47 | gpu_tag = '/GPU:0' 48 | else: 49 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 50 | gpu_tag = '/CPU:0' 51 | 52 | plot_fname = "{}/lat_viz.jpg".format(model_dir) 53 | latents_fname = "{}/latents.npy".format(model_dir) 54 | 55 | #batch_size = int(args.getArg("batch_size")) #128 56 | delimiter = "\t" 57 | # xfname = args.getArg("xfname") #"../data/mnist/trainX.tsv" 58 | yfname = "{}/{}Y.npy".format(data_dir,split) 59 | Y = tf.cast(np.load(yfname),dtype=tf.float32).numpy() 60 | 61 | with tf.device(gpu_tag): 62 | 63 | z_lat = tf.cast(np.load(latents_fname),dtype=tf.float32) 64 | print("Lat.shape = {}".format(z_lat.shape)) 65 | y_sample = Y 66 | if len(y_sample.shape) == 1: 67 | print("y_init.shape = ",y_sample.shape) 68 | nC = 10 #Y.shape[1] 69 | y_sample = tf.one_hot(tf.cast(y_sample,dtype=tf.float32).numpy().astype(np.int32), depth=nC) 70 | print("y_post.shape = ",y_sample.shape) 71 | elif y_sample.shape[1] == 1: 72 | print("y_init.shape = ",y_sample.shape) 73 | nC = 10 #Y.shape[1] 74 | y_sample = 
tf.one_hot(tf.squeeze(y_sample).numpy().astype(np.int32), depth=nC) 75 | print("y_post.shape = ",y_sample.shape) 76 | 77 | max_w = -10000.0 78 | min_w = 10000.0 79 | max_w = max(max_w, float(tf.reduce_max(z_lat))) 80 | min_w = min(min_w, float(tf.reduce_min(z_lat))) 81 | print("max_z = ", max_w) 82 | print("min_z = ", min_w) 83 | print("Y.shape = ",y_sample.shape) 84 | 85 | z_top_dim = z_lat.shape[1] 86 | z_2D = None 87 | if z_top_dim != 2: 88 | from sklearn.decomposition import IncrementalPCA 89 | print(" > Projecting latents via iPCA...") 90 | if use_tsne is True: 91 | n_comp = 32 #10 #16 #50 92 | if z_lat.shape[1] < n_comp: 93 | n_comp = z_lat.shape[1] - 2 #z_top.shape[1]-2 94 | n_comp = max(2, n_comp) 95 | ipca = IncrementalPCA(n_components=n_comp, batch_size=50) 96 | ipca.fit(z_lat.numpy()) 97 | z_2D = ipca.transform(z_lat.numpy()) 98 | print("PCA.lat.shape = ",z_2D.shape) 99 | print(" > Finishing projection via t-SNE...") 100 | from sklearn.manifold import TSNE 101 | z_2D = TSNE(n_components=2,perplexity=30).fit_transform(z_2D) 102 | #z_2D.shape 103 | else: 104 | ipca = IncrementalPCA(n_components=2, batch_size=50) 105 | ipca.fit(z_lat.numpy()) 106 | z_2D = ipca.transform(z_lat.numpy()) 107 | else: 108 | z_2D = z_lat 109 | 110 | print(" > Plotting 2D latent encodings...") 111 | plt.figure(figsize=(8, 6)) 112 | plt.scatter(z_2D[:, 0], z_2D[:, 1], c=np.argmax(y_sample, 1), cmap=cmap) 113 | plt.colorbar() 114 | plt.grid() 115 | plt.savefig("{0}".format(plot_fname), dpi=300) # latents.jpg 116 | plt.clf() 117 | -------------------------------------------------------------------------------- /src/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPU_ID=0 3 | N_TRIALS=2 # number of experimental trials to run 4 | 5 | echo " >>> Running MNIST simulation!" 
6 | DATA_DIR="../data/mnist/" 7 | OUT_DIR="../exp/pff/mnist/" 8 | python sim_train.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$OUT_DIR 9 | 10 | echo " >>> Running K-MNIST simulation!" 11 | DATA_DIR="../data/kmnist/" 12 | OUT_DIR="../exp/pff/kmnist/" 13 | python sim_train.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$OUT_DIR 14 | -------------------------------------------------------------------------------- /src/sample_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022) 3 | 4 | ################################################################################ 5 | Samples a trained PFF-RNN system. 6 | ################################################################################ 7 | """ 8 | 9 | import os 10 | import sys, getopt, optparse 11 | import pickle 12 | sys.path.insert(0, '../') 13 | import tensorflow as tf 14 | import numpy as np 15 | import time 16 | import copy 17 | 18 | import matplotlib 19 | matplotlib.use('Agg') 20 | import matplotlib.pyplot as plt 21 | #cmap = plt.cm.jet 22 | 23 | from pff_rnn import PFF_RNN 24 | 25 | ################################################################################ 26 | 27 | def deserialize(fname): ## object "loading" routine 28 | fd = open(fname, 'rb') 29 | object = pickle.load( fd ) 30 | fd.close() 31 | return object 32 | 33 | def plot_img_grid(samples, fname, nx, ny, px, py, plt, rotNeg90=False): # rows, cols,... 
34 | px_dim = px 35 | py_dim = py 36 | canvas = np.empty((px_dim*nx, py_dim*ny)) 37 | ptr = 0 38 | for i in range(0,nx,1): 39 | for j in range(0,ny,1): 40 | #xs = tf.expand_dims(tf.cast(samples[ptr,:],dtype=tf.float32),axis=0) 41 | xs = np.expand_dims(samples[ptr,:],axis=0) 42 | #xs = xs.numpy() #tf.make_ndarray(x_mean) 43 | xs = xs[0].reshape(px_dim, py_dim) 44 | if rotNeg90 is True: 45 | xs = np.rot90(xs, -1) 46 | canvas[(nx-i-1)*px_dim:(nx-i)*px_dim, j*py_dim:(j+1)*py_dim] = xs 47 | ptr += 1 48 | plt.figure(figsize=(12, 14)) 49 | plt.imshow(canvas, origin="upper", cmap="gray") 50 | plt.tight_layout() 51 | plt.axis('off') 52 | plt.savefig("{0}".format(fname), bbox_inches='tight', pad_inches=0) 53 | plt.clf() 54 | 55 | seed = 69 56 | tf.random.set_seed(seed=seed) 57 | np.random.seed(seed) 58 | 59 | model_dir = "../exp/" 60 | # read in configuration file and extract necessary variables/constants 61 | options, remainder = getopt.getopt(sys.argv[1:], '', ["model_dir=","gpu_id="]) 62 | # Collect arguments from argv 63 | n_trials = 1 64 | gpu_id = -1 65 | for opt, arg in options: 66 | if opt in ("--model_dir"): 67 | model_dir = arg.strip() 68 | elif opt in ("--gpu_id"): 69 | gpu_id = int(arg.strip()) 70 | 71 | gmm_fname = "{}/prior.gmm".format(model_dir) 72 | 73 | mid = gpu_id #0 74 | if mid >= 0: 75 | print(" > Using GPU ID {0}".format(mid)) 76 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid) 77 | #gpu_tag = '/GPU:0' 78 | gpu_tag = '/GPU:0' 79 | else: 80 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 81 | gpu_tag = '/CPU:0' 82 | 83 | def get_n_comp(gmm, use_sklearn=True): 84 | if use_sklearn is True: 85 | return gmm.n_components 86 | else: 87 | return gmm.k 88 | 89 | def sample_gmm(gmm, n_samps, use_sklearn=False): 90 | if use_sklearn is True: 91 | np_samps, np_labs = gmm.sample(n_samps) 92 | z_samp = tf.cast(np_samps,dtype=tf.float32) 93 | else: 94 | z_samp, z_labs = gmm.sample(n_samps) 95 | np_labs = tf.squeeze(z_labs).numpy() 96 | y_s = tf.one_hot(np_labs, 
get_n_comp(gmm, use_sklearn=use_sklearn)) 97 | return z_samp, y_s 98 | 99 | print("----") 100 | with tf.device(gpu_tag): 101 | print(" >> Loading prior P(z): ",gmm_fname) 102 | prior = deserialize(gmm_fname) 103 | print(" >> Loading model P(x|z): ",model_dir) 104 | agent = PFF_RNN(model_dir=model_dir) 105 | print(agent.V.shape) 106 | nrow = 10 #2 107 | ncol = 10 #4 108 | n_samp = nrow * ncol # per class 109 | print(" >> Generating confabulations from P(x|z)P(z)...") 110 | for _ in range(15): ## jitter the prior 111 | z2s, _ = sample_gmm(prior, n_samp) 112 | xs = agent.sample(y=z2s) 113 | fname = "{}samples.jpg".format(model_dir) 114 | plot_img_grid(xs.numpy(), fname, nx=nrow, ny=ncol, px=28, py=28, plt=plt) 115 | plt.close() 116 | -------------------------------------------------------------------------------- /src/sim_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022) 3 | 4 | ################################################################################ 5 | Simulates the training/adaptation of a recurrent neural system composed of 6 | a representation circuit and a generative circuit, trained via the predictive forward-forward 7 | process. 8 | Note that this code focuses on datasets of gray-scale images/patterns. 9 | ################################################################################ 10 | """ 11 | 12 | import os 13 | import sys, getopt, optparse 14 | import pickle 15 | #import dill as pickle 16 | sys.path.insert(0, '../') 17 | import tensorflow as tf 18 | import numpy as np 19 | import time 20 | import copy 21 | 22 | import matplotlib 23 | matplotlib.use('Agg') 24 | import matplotlib.pyplot as plt 25 | #cmap = plt.cm.jet 26 | 27 | # import general simulation utilities 28 | from data_utils import DataLoader 29 | 30 | def plot_img_grid(samples, fname, nx, ny, px, py, plt, rotNeg90=False): # rows, cols,...
31 | ''' 32 | Visualizes a matrix of vector patterns in the form of an image grid plot. 33 | ''' 34 | px_dim = px 35 | py_dim = py 36 | canvas = np.empty((px_dim*nx, py_dim*ny)) 37 | ptr = 0 38 | for i in range(0,nx,1): 39 | for j in range(0,ny,1): 40 | #xs = tf.expand_dims(tf.cast(samples[ptr,:],dtype=tf.float32),axis=0) 41 | xs = np.expand_dims(samples[ptr,:],axis=0) 42 | #xs = xs.numpy() #tf.make_ndarray(x_mean) 43 | xs = xs[0].reshape(px_dim, py_dim) 44 | if rotNeg90 is True: 45 | xs = np.rot90(xs, -1) 46 | canvas[(nx-i-1)*px_dim:(nx-i)*px_dim, j*py_dim:(j+1)*py_dim] = xs 47 | ptr += 1 48 | plt.figure(figsize=(12, 14)) 49 | plt.imshow(canvas, origin="upper", cmap="gray") 50 | plt.tight_layout() 51 | plt.axis('off') 52 | #print(" SAVE: {0}{1}".format(out_dir,"gmm_decoded_samples.jpg")) 53 | plt.savefig("{0}".format(fname), bbox_inches='tight', pad_inches=0) 54 | plt.clf() 55 | 56 | from pff_rnn import PFF_RNN 57 | 58 | ################################################################################ 59 | 60 | seed = 69 61 | tf.random.set_seed(seed=seed) 62 | np.random.seed(seed) 63 | 64 | out_dir = "../exp/" 65 | data_dir = "../data/mnist/" 66 | # read in configuration file and extract necessary variables/constants 67 | options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","gpu_id=","n_trials=","out_dir="]) 68 | # Collect arguments from argv 69 | n_trials = 1 70 | gpu_id = -1 71 | for opt, arg in options: 72 | if opt in ("--data_dir"): 73 | data_dir = arg.strip() 74 | elif opt in ("--out_dir"): 75 | out_dir = arg.strip() 76 | elif opt in ("--gpu_id"): 77 | gpu_id = int(arg.strip()) 78 | elif opt in ("--n_trials"): 79 | n_trials = int(arg.strip()) 80 | print(" Exp out dir: ",out_dir) 81 | 82 | mid = gpu_id # 0 83 | if mid >= 0: 84 | print(" > Using GPU ID {0}".format(mid)) 85 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid) 86 | #gpu_tag = '/GPU:0' 87 | gpu_tag = '/GPU:0' 88 | else: 89 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 90 | gpu_tag = 
'/CPU:0' 91 | 92 | print(" >>> Run sim on {} w/ GPU {}".format(data_dir, mid)) 93 | 94 | xfname = "{}/trainX.npy".format(data_dir) 95 | yfname = "{}/trainY.npy".format(data_dir) 96 | 97 | X = ( tf.cast(np.load(xfname, allow_pickle=True),dtype=tf.float32) ) 98 | if len(X.shape) > 2: 99 | X = tf.reshape(X, [X.shape[0], X.shape[1] * X.shape[2]]) 100 | print(X.shape) 101 | x_dim = X.shape[1] 102 | max_val = float(tf.reduce_max(X)) 103 | if max_val > 1.0: 104 | X = X/max_val 105 | 106 | Y = ( tf.cast(np.load(yfname, allow_pickle=True),dtype=tf.float32) ) 107 | y_dim = Y.shape[1] 108 | print("Y.shape = ",Y.shape) 109 | 110 | n_iter = 60 #100 #60 # number of training iterations 111 | batch_size = 500 # batch size 112 | dev_batch_size = 500 # dev batch size 113 | dataset = DataLoader(design_matrices=[("x",X.numpy()),("y",Y.numpy())], batch_size=batch_size) 114 | 115 | print(" > Loading dev set") 116 | xfname = "{}/validX.npy".format(data_dir) 117 | yfname = "{}/validY.npy".format(data_dir) 118 | Xdev = ( tf.cast(np.load(xfname, allow_pickle=True),dtype=tf.float32) ) 119 | 120 | if len(Xdev.shape) > 2: 121 | Xdev = tf.reshape(Xdev, [Xdev.shape[0], Xdev.shape[1] * Xdev.shape[2]]) 122 | print("Xdev.shape = ",Xdev.shape) 123 | max_val = float(tf.reduce_max(Xdev)) 124 | if max_val > 1.0: 125 | Xdev = Xdev/max_val 126 | 127 | Ydev = ( tf.cast(np.load(yfname, allow_pickle=True),dtype=tf.float32) ) 128 | if len(Ydev.shape) == 1: 129 | nC = Y.shape[1] 130 | Ydev = tf.one_hot(Ydev.numpy().astype(np.int32), depth=nC) 131 | elif Ydev.shape[1] == 1: 132 | nC = 10 133 | Ydev = tf.one_hot(tf.squeeze(Ydev).numpy().astype(np.int32), depth=nC) 134 | print("Ydev.shape = ",Ydev.shape) 135 | 136 | devset = DataLoader(design_matrices=[("x",Xdev.numpy()), ("y",Ydev.numpy())], 137 | batch_size=dev_batch_size, disable_shuffle=True) 138 | 139 | def classify(agent, x): 140 | K_low = int(agent.K/2) - 1 # 3 141 | K_high = int(agent.K/2) + 1 # 5 142 | x_ = x 143 | Ey = None 144 | z_lat = 
agent.forward(x_) # do forward init pass 145 | for i in range(agent.y_dim): 146 | z_lat_ = [] 147 | for ii in range(len(z_lat)): 148 | z_lat_.append(z_lat[ii] + 0) 149 | 150 | yi = tf.ones([x.shape[0],agent.y_dim]) * tf.expand_dims(tf.one_hot(i,depth=agent.y_dim),axis=0) 151 | 152 | gi = 0.0 153 | for k in range(K_high): 154 | z_lat_, p_g = agent._step(x_, yi, z_lat_, thr=0.0) 155 | if k >= K_low and k <= K_high: # only keep goodness in middle iterations 156 | gi = ((p_g[0] + p_g[1])*0.5) + gi 157 | 158 | if i > 0: 159 | Ey = tf.concat([Ey,gi],axis=1) 160 | else: 161 | Ey = gi 162 | 163 | Ey = Ey / (3.0) 164 | y_hat = tf.nn.softmax(Ey) 165 | return y_hat, Ey 166 | 167 | def eval(agent, dataset, debug=False, save_img=True, out_dir=""): 168 | ''' 169 | Evaluates the current state of the agent given a dataset (data-loader). 170 | ''' 171 | N = 0.0 172 | Ny = 0.0 173 | Acc = 0.0 174 | Ly = 0.0 175 | Lx = 0.0 176 | tt = 0 177 | #debug = True 178 | for batch in dataset: 179 | _, x = batch[0] 180 | _, y = batch[1] 181 | N += x.shape[0] 182 | Ny += float(tf.reduce_sum(y)) 183 | 184 | y_hat, Ey = classify(agent, x) 185 | Ly += tf.reduce_sum(-tf.reduce_sum(y * tf.math.log(y_hat), axis=1, keepdims=True)) 186 | 187 | z_lat = agent.forward(x) 188 | x_hat = agent.sample(z=z_lat[len(z_lat)-2]) 189 | 190 | ex = x_hat - x 191 | Lx += tf.reduce_sum(tf.reduce_sum(tf.math.square(ex),axis=1,keepdims=True)) 192 | 193 | if debug == True: 194 | print("------------------------") 195 | print(Ey[0:4,:]) 196 | print(y_hat[0:4,:]) 197 | print("------------------------") 198 | 199 | y_ind = tf.cast(tf.argmax(y,1),dtype=tf.int32) 200 | y_pred = tf.cast(tf.argmax(Ey,1),dtype=tf.int32) 201 | comp = tf.cast(tf.equal(y_pred,y_ind),dtype=tf.float32) #* y_m 202 | Acc += tf.reduce_sum( comp ) 203 | Ly = Ly/Ny 204 | Acc = Acc/Ny 205 | Lx = Lx/Ny 206 | 207 | if save_img == True: 208 | fname = "{}/x_samples.png".format(out_dir) 209 | plot_img_grid(x_hat.numpy(), fname, nx=10, ny=10, px=28, py=28, plt=plt) 
210 | plt.close() 211 | 212 | fname = "{}/x_data.png".format(out_dir) 213 | plot_img_grid(x, fname, nx=10, ny=10, px=28, py=28, plt=plt) 214 | plt.close() 215 | 216 | return Ly, Acc, Lx 217 | 218 | print("----") 219 | with tf.device(gpu_tag): 220 | 221 | best_acc_list = [] 222 | acc_list = [] 223 | for tr in range(n_trials): 224 | acc_scores = [] # tracks acc during training w/in a trial 225 | ######################################################################## 226 | ## create model 227 | ######################################################################## 228 | model_dir = "{}/trial{}/".format(out_dir, tr) 229 | if not os.path.exists(model_dir): 230 | os.makedirs(model_dir) 231 | args = {"x_dim": x_dim, 232 | "y_dim": y_dim, 233 | "n_units": 2000, 234 | "K":12, 235 | "thr": 10.0, 236 | "eps_r": 0.01, 237 | "eps_g": 0.025} 238 | agent = PFF_RNN(args=args) 239 | ## set up optimization 240 | eta = 0.00025 # 0.0005 # for grnn 241 | reg_lambda = 0.0 242 | #reg_lambda = 0.0001 # works nice for kmnist 243 | opt = tf.keras.optimizers.Adam(eta) 244 | 245 | g_eta = 0.00025 # 0.0005 #0.001 #0.0005 #0.001 246 | g_reg_lambda = 0 247 | g_opt = tf.keras.optimizers.Adam(g_eta) 248 | ######################################################################## 249 | 250 | ######################################################################## 251 | ## begin simulation 252 | ######################################################################## 253 | Ly, Acc, Lx = eval(agent, devset, out_dir=model_dir) 254 | acc_scores.append(Acc) 255 | print("{}: L {} Acc = {} Lx {}".format(-1, Ly, Acc, Lx)) 256 | 257 | best_Acc = Acc 258 | best_Ly = Ly 259 | for t in range(n_iter): 260 | N = 0.0 261 | Ng = 0.0 262 | Lg = 0.0 263 | Ly = 0.0 264 | for batch in dataset: 265 | _, x = batch[0] 266 | _, y = batch[1] 267 | N += x.shape[0] 268 | 269 | # create negative data (x, y_neg) 270 | x_neg = x 271 | y_neg = tf.random.uniform(y.shape, 0.0, 1.0) * (1.0 - y) 272 | y_neg = 
tf.one_hot(tf.argmax(y_neg,axis=1), depth=agent.y_dim) 273 | 274 | ## create full batch 275 | x_ = tf.concat([x,x_neg],axis=0) 276 | y_ = tf.concat([y,y_neg],axis=0) 277 | lab = tf.concat([tf.ones([x.shape[0],1]),tf.zeros([x_neg.shape[0],1])],axis=0) 278 | ## update model given full batch 279 | z_lat = agent.forward(x_) 280 | Lg_t, _, Lgen_t, x_hat = agent.infer(x_, y_, lab, z_lat, agent.K, opt, g_opt, reg_lambda=reg_lambda, g_reg_lambda=g_reg_lambda) # total goodness 281 | Lg_t = Lg_t * (x.shape[0] + x_neg.shape[0]) 282 | 283 | #y_hat = y_hat[0:x.shape[0],:] 284 | y_hat = agent.classify(x) 285 | Ly_t = tf.reduce_sum(-tf.reduce_sum(y * tf.math.log(y_hat), axis=1, keepdims=True)) 286 | 287 | ## track losses 288 | Ly = Ly_t + Ly 289 | Lg = Lg_t + Lg 290 | Ng += (x.shape[0] + x_neg.shape[0]) 291 | 292 | print("\r {}: Ly = {} L = {} w/ {} samples".format(t, Ly/N, Lg/Ng, N), end="") 293 | print() 294 | print("--------------------------------------") 295 | 296 | Ly, Acc, Lx = eval(agent, devset, out_dir=model_dir) 297 | acc_scores.append(Acc) 298 | np.save("{}/dev_acc.npy".format(model_dir), np.asarray(acc_scores)) 299 | print("{}: L {} Acc = {} Lx {}".format(t, Ly, Acc, Lx)) 300 | 301 | if Acc > best_Acc: 302 | best_Acc = Acc 303 | best_Ly = Ly 304 | 305 | print(" >> Saving model to: ",model_dir) 306 | agent.save_model(model_dir) 307 | 308 | print("************") 309 | Ly, Acc, _ = eval(agent, dataset, out_dir=model_dir, save_img=False) 310 | print(" Train: Ly {} Acc = {}".format(Ly, Acc)) 311 | print("Best.Dev: Ly {} Acc = {}".format(best_Ly, best_Acc)) 312 | 313 | acc_list.append(1.0 - Acc) 314 | best_acc_list.append(1.0 - best_Acc) 315 | 316 | ############################################################################ 317 | ## calc post-trial statistics 318 | n_dec = 4 319 | mu = round(np.mean(np.asarray(best_acc_list)), n_dec) 320 | sd = round(np.std(np.asarray(best_acc_list)), n_dec) 321 | print(" Dev.Error = {:.4f} \pm {:.4f}".format(mu, sd)) 322 | 323 | ## store 
result to disk just in case... 324 | results_fname = "{}/post_train_results.txt".format(out_dir) 325 | log_t = open(results_fname,"a") 326 | log_t.write("Generalization Results:\n") 327 | log_t.write(" Dev.Error = {:.4f} \pm {:.4f}\n".format(mu, sd)) 328 | 329 | n_dec = 4 330 | mu = round(np.mean(np.asarray(acc_list)), n_dec) 331 | sd = round(np.std(np.asarray(acc_list)), n_dec) 332 | print(" Train.Error = {:.4f} \pm {:.4f}".format(mu, sd)) 333 | log_t.write("Training-Set/Optimization Results:\n") 334 | log_t.write(" Train.Error = {:.4f} \pm {:.4f}\n".format(mu, sd)) 335 | 336 | log_t.close() # close the log file 337 | --------------------------------------------------------------------------------
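`calc_goodness` and `calc_loss` in `pff_rnn.py` use a specific convention. The logit is `thr - sum(z^2)`, so a positive sample is one whose layer activity stays *below* the threshold. The active loss line is the numerically stable form of sigmoid cross-entropy. The plain-Python sketch below mirrors that logic for a single sample so the convention can be checked without TensorFlow. The helper names (`goodness`, `local_loss`, `local_loss_softplus`) are illustrative and not part of the repository. The sketch also shows that, for binary labels, the stable form agrees with the commented-out softplus variant in `calc_loss`.

```python
import math

def sigmoid(x):
    # Overflow-safe logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def softplus(x):
    # Overflow-safe log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def goodness(z, thr):
    # Per-sample analogue of calc_goodness: logit = thr - sum(z_i^2), p = P(pos), clipped.
    delta = thr - sum(v * v for v in z)
    eps = 1e-5
    p = min(max(sigmoid(delta), eps), 1.0 - eps)
    return p, delta

def local_loss(z, lab, thr):
    # Stable sigmoid cross-entropy, same form as the active CE line in calc_loss
    # (log1p(exp(-|l|)) is used here in place of log(1 + exp(-|l|))).
    _, logit = goodness(z, thr)
    return max(logit, 0.0) - logit * lab + math.log1p(math.exp(-abs(logit)))

def local_loss_softplus(z, lab, thr):
    # The commented-out alternative in calc_loss.
    _, logit = goodness(z, thr)
    return softplus(-logit) * lab + softplus(logit) * (1.0 - lab)

thr = 10.0
quiet = [0.1, 0.1]       # sum of squares 0.02, below thr
loud = [3.0, 3.0, 3.0]   # sum of squares 27, above thr
for z in (quiet, loud):
    print(z, [round(local_loss(z, lab, thr), 4) for lab in (1.0, 0.0)])
```

With `thr = 10`, the low-activity vector is cheap to label positive and expensive to label negative. The high-activity vector shows the reverse. This matches the "minimize for positive samps, maximize for negative samps" comment in `calc_goodness`.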
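In the training loop of `sim_train.py`, each negative sample reuses the same image `x` and pairs it with a wrong label. The code does this with `y_neg = uniform * (1 - y)`, followed by `argmax` and `one_hot`. The batch is then stacked with `lab = 1` for positives and `lab = 0` for negatives. Zeroing the true class means the argmax always falls on one of the incorrect classes, and since the remaining weights are i.i.d. uniform, each incorrect class is equally likely. The sketch below restates this in plain Python using `random` in place of `tf.random.uniform`. The function names are hypothetical.

```python
import random

def sample_negative_labels(y_onehot, rng=random):
    # Mirrors y_neg = uniform * (1 - y) -> argmax -> one_hot: the true class gets
    # weight 0, so the argmax is a uniformly random *incorrect* class.
    y_neg = []
    for row in y_onehot:
        noise = [rng.random() * (1.0 - v) for v in row]
        j = max(range(len(row)), key=noise.__getitem__)
        y_neg.append([1.0 if i == j else 0.0 for i in range(len(row))])
    return y_neg

def build_pos_neg_batch(x, y, rng=random):
    # Positives: (x, y) with lab = 1; negatives: (same x, y_neg) with lab = 0.
    y_neg = sample_negative_labels(y, rng)
    x_full = list(x) + list(x)
    y_full = list(y) + y_neg
    lab = [1.0] * len(x) + [0.0] * len(x)
    return x_full, y_full, lab

rng = random.Random(0)
y = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
print(build_pos_neg_batch([[0.2], [0.7]], y, rng)[1])
```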
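The lateral synapses in `step()`, `_step()` and `_infer_latent()` are transformed before use. The code computes `L = relu(L)` and then `L * M * (1 - I) - L * I`, and subtracts `rec_gamma * (z_tm1 @ L)` from the pre-activation. Negating the diagonal therefore turns it into self-excitation, while the masked off-diagonal entries inhibit competing units. This is the cross-inhibition / self-excitation effect described in the README. The actual mask comes from `create_competiion_matrix(n_units, n_group=10)`, whose code is not shown here. The sketch below therefore assumes a hypothetical block mask: 4 units split into two groups of 2.

```python
def effective_lateral(L, M):
    # Mirrors the transform in step(): L = relu(L); L = L*M*(1 - I) - L*I.
    n = len(L)
    R = [[max(v, 0.0) for v in row] for row in L]
    return [[-R[i][j] if i == j else R[i][j] * M[i][j] for j in range(n)]
            for i in range(n)]

def apply_lateral(pre, z_prev, L_eff, gamma):
    # pre - gamma * (z_prev @ L_eff), for a single row vector z_prev.
    n = len(pre)
    return [pre[j] - gamma * sum(z_prev[i] * L_eff[i][j] for i in range(n))
            for j in range(n)]

# Hypothetical mask (assumption): units {0, 1} and {2, 3} form competing groups.
M = [[1.0 if i // 2 == j // 2 else 0.0 for j in range(4)] for i in range(4)]
L = [[0.5] * 4 for _ in range(4)]
L[1][0] = -0.3  # negative raw weights are zeroed by the relu
L_eff = effective_lateral(L, M)
print(apply_lateral([0.0] * 4, [1.0, 0.0, 0.0, 0.0], L_eff, gamma=1.0))
```

When only unit 0 was active at the previous step, unit 0 is pushed up (self-excitation) and its group partner, unit 1, is pushed down. Units in the other group are unaffected because the mask zeroes those connections.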