├── LICENSE
├── README.md
├── data.tar.xz
├── exp
│   └── pff
│       ├── kmnist
│       │   └── tmp.txt
│       └── mnist
│           └── tmp.txt
├── fig
│   └── pff_config.png
└── src
├── analyze.sh
├── data_utils.py
├── eval_model.py
├── fit_gmm.py
├── pff_rnn.py
├── plot_tsne.py
├── run.sh
├── sample_model.py
└── sim_train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Alex Ororbia and Ankur Mali
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Predictive Forward-Forward Algorithm
2 | ## Bio-plausible Forward-Only Learning for Training Neural Networks
3 | Implementation of the proposed predictive forward-forward (PFF) learning algorithm for training a neurobiologically-plausible recurrent neural system. This work combines elements of predictive coding with the recently proposed forward-forward algorithm to create a novel online learning process that involves dynamically adapting two neural circuits - a representation circuit and a generative circuit. Notably, the system introduces noise injection into the latent activity updates as well as learnable lateral synapses that induce competition across neural units (emulating the cross-inhibition and self-excitation
4 | effects inherent to neural computation).
5 |
6 | # Requirements
7 | Our implementation is easy to follow and, with knowledge of basic linear algebra, one can decode the inner workings of the PFF algorithm. Please look at Algorithm 1 in our paper (in the Appendix) to better understand the overall mechanics of the inference and learning processes. We have kept the modules simple, which should make the framework convenient to extend.
8 |
9 | To run the code, you should only need the following basic packages:
10 | 1. TensorFlow (version >= 2.0)
11 | 2. Numpy
12 | 3. Matplotlib
13 | 4. Python (version >=3.5)
14 | 5. [ngc-learn](https://github.com/ago109/ngc-learn) (Some modules responsible for generating image samples depend on ngc-learn. Without it, you won't be able to use `fit_gmm.py`, which uses the mixture model
15 | in that package to retro-fit the latent prior for the PFF model's generative circuit; `sample_model.py`
16 | and `plot_tsne.py` will then have no prior distribution model to access. In that case, simply comment out the lines
17 | involving `fit_gmm.py`, `sample_model.py`, and `plot_tsne.py` in the `analyze.sh` script.)
18 |
19 | # Execution
20 |
21 | To reproduce the results from our paper, simply run the two provided Bash scripts from inside the `src/` directory (they use relative paths such as `../data/`):
22 | 1. `bash run.sh` (This will train a model for `E=60` epochs.)
23 | 2. `bash analyze.sh` (This will evaluate a trained model and produce plots/visuals.)
24 | After running the above two scripts, you can find the simulation outputs in the example
25 | experimental results directory tree that has been pre-created for you.
26 | `exp/pff/mnist/` contains the results for the MNIST model (over 2 trials) and
27 | `exp/pff/kmnist/` contains the results for the KMNIST model (over 2 trials).
28 | In each directory, the following is stored:
29 | * `post_train_results.txt` - contains development/training cross-trial accuracy values
30 | * `test_results.txt` - contains test cross-trial accuracy values
31 | * `trial0/` - contains model data for trial 0, as well as any visuals produced by `analyze.sh`
32 | * `trial1/` - contains model data for trial 1, as well as any visuals produced by `analyze.sh`
33 | (Note that you should modify the `MODEL_DIR` in `analyze.sh` to point to a particular
34 | trial's sub-directory -- the default points to trial 0, and thus only places images
35 | inside of the `trial0/` sub-directory.)
36 |
37 | Model-specific hyper-parameter defaults can be set/adjusted in `pff_rnn.py`.
38 | Training-specific hyper-parameters are available in `sim_train.py` - note that one
39 | can create/edit an arguments dictionary much like the one depicted below (inside of `sim_train.py`):
40 |
41 | ![PFF arguments dictionary configuration](fig/pff_config.png)
42 |
43 | which the `PFF_RNN()` constructor takes in as input to construct the simulation of
44 | the dual-circuit system.
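
A minimal sketch of such a dictionary, built from the default values hard-coded in `pff_rnn.py` (the exact values depicted in `fig/pff_config.png` may differ):

```python
from pff_rnn import PFF_RNN

args = {
    "x_dim": 784,     # flattened 28x28 gray-scale input
    "y_dim": 10,      # number of label classes
    "n_units": 2000,  # units per representation layer
    "g_units": 20,    # top-most latent variables of the generative circuit
    "K": 10,          # number of simulation steps
    "beta": 0.025,    # update rate for the generative circuit's top latent state
    "thr": 3.0,       # goodness threshold
    "alpha": 0.3,     # dampening factor for the latent state updates
    "eps_r": 0.01,    # noise factor for the representation circuit
    "eps_g": 0.025,   # noise factor for the generative circuit
    "seed": 69,
}
agent = PFF_RNN(args=args)
```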
45 |
46 | Tips while using this algorithm/model on your own datasets:
47 | 1. Track your local losses and adjust the model's hyper-parameters accordingly
48 | 2. Play with non-zero, small values for the weight decay coefficients `reg_lambda` (for
49 | the representation circuit) and `g_reg_lambda` (for the generative circuit), as in the
50 | sketch below - for K-MNIST, a small value (as indicated in the comments) for `reg_lambda`
51 | seemed to improve generalization performance slightly in our experience.
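
As a rough sketch of where these coefficients enter (this is not the verbatim training loop from `sim_train.py`; the optimizers, learning rates, and dummy batch are illustrative assumptions), they are forwarded to the model's `infer()` routine, reusing the `agent` constructed above:

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=2e-4)    # representation circuit
g_opt = tf.keras.optimizers.Adam(learning_rate=2e-4)  # generative circuit

x = tf.random.uniform([500, 784])  # dummy sensory batch with values in [0, 1]
y = tf.one_hot(tf.random.uniform([500], 0, 10, dtype=tf.int32), depth=10)
lab = tf.ones([500, 1])            # data "type" flags: 1 = positive, 0 = negative

z_lat = agent.forward(x)  # initial conditions via a forward pass
L, y_hat, Lg, x_hat = agent.infer(x, y, lab, z_lat, K=agent.K,
                                  opt=opt, g_opt=g_opt,
                                  reg_lambda=1e-4,   # rep-circuit weight decay
                                  g_reg_lambda=0.0)  # gen-circuit weight decay
```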
52 |
53 | # Citation
54 |
55 | If you use or adapt (portions of) this code/algorithm in any form in your project(s), or
56 | find the PFF algorithm helpful in your own work, please cite this code's source paper:
57 |
58 | ```bibtex
59 | @article{ororbia2023predictive,
60 | title={The Predictive Forward-Forward Algorithm},
61 | author={Ororbia, Alexander and Mali, Ankur},
62 | journal={arXiv preprint arXiv:2301.01452},
63 | year={2023}
64 | }
65 | ```
66 |
--------------------------------------------------------------------------------
/data.tar.xz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/data.tar.xz
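
The archive needs unpacking at the repository root before running the scripts; a minimal sketch, assuming (per the `DATA_DIR` paths in `run.sh`) that it expands to `data/mnist/` and `data/kmnist/` containing the `*X.npy`/`*Y.npy` design matrices:

```python
import tarfile

# extract the bundled datasets alongside src/ and exp/
with tarfile.open("data.tar.xz", "r:xz") as tar:
    tar.extractall(".")  # expected to yield data/mnist/ and data/kmnist/
```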
--------------------------------------------------------------------------------
/exp/pff/kmnist/tmp.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/exp/pff/kmnist/tmp.txt
--------------------------------------------------------------------------------
/exp/pff/mnist/tmp.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/exp/pff/mnist/tmp.txt
--------------------------------------------------------------------------------
/fig/pff_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ago109/predictive-forward-forward/adeb918941afaafb11bc9f1b0953dae2d7dd1f13/fig/pff_config.png
--------------------------------------------------------------------------------
/src/analyze.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPU_ID=0
3 | N_TRIALS=2
4 |
5 | ## Analyze MNIST
6 | MODEL_TOPDIR="../exp/pff/mnist/"
7 | MODEL_DIR="../exp/pff/mnist/trial0/" ## <-- choose a specific trial idx to analyze
8 | DATA_DIR="../data/mnist/"
9 |
10 | echo " ---------- Evaluating MNIST Test Performance ---------- "
11 | python eval_model.py --data_dir=$DATA_DIR --split=test --model_topdir=$MODEL_TOPDIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$MODEL_TOPDIR
12 |
13 | # for prior distribution fitting / sampling
14 | echo " ---------- Fitting MNIST Prior ---------- "
15 | python fit_gmm.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --split=train --model_dir=$MODEL_DIR
16 | echo " ---------- Sampling MNIST Model ---------- "
17 | python sample_model.py --gpu_id=$GPU_ID --model_dir=$MODEL_DIR
18 |
19 | # for latent code visualization
20 | echo " ---------- Extracting MNIST Test Latents ---------- "
21 | python fit_gmm.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test --disable_prior=1
22 | echo " ---------- Visualizing MNIST Test Latents ---------- "
23 | python plot_tsne.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test
24 |
25 |
26 | ## Analyze K-MNIST
27 | MODEL_TOPDIR="../exp/pff/kmnist/"
28 | MODEL_DIR="../exp/pff/kmnist/trial0/" ## <-- choose a specific trial idx to analyze
29 | DATA_DIR="../data/kmnist/"
30 |
31 | echo " ---------- Evaluating K-MNIST Test Performance ---------- "
32 | python eval_model.py --data_dir=$DATA_DIR --split=test --model_topdir=$MODEL_TOPDIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$MODEL_TOPDIR
33 |
34 | # for prior distribution fitting / sampling
35 | echo " ---------- Fitting K-MNIST Prior ---------- "
36 | python fit_gmm.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --split=train --model_dir=$MODEL_DIR
37 | echo " ---------- Sampling K-MNIST Model ---------- "
38 | python sample_model.py --gpu_id=$GPU_ID --model_dir=$MODEL_DIR
39 |
40 | # for latent code visualization
41 | echo " ---------- Extracting K-MNIST Test Latents ---------- "
42 | python fit_gmm.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test --disable_prior=1
43 | echo " ---------- Visualizing K-MNIST Test Latents ---------- "
44 | python plot_tsne.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --gpu_id=$GPU_ID --split=test
45 |
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Data functions and utilities.
3 |
4 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
5 | """
6 | import random
7 | import numpy as np
8 | import io
9 | import sys
10 | import math
11 |
12 | seed = 69
13 | # set the random seed - important for reproducibility of machine/deep learning experiments
14 | np.random.seed(seed)
15 |
16 | class DataLoader(object):
17 | """
18 | A data loader object, meant to allow sampling w/o replacement of one or
19 | more named design matrices. Note that this object is iterable (and
20 | implements an __iter__() method).
21 |
22 | Args:
23 | design_matrices: list of named data design matrices - [("name", matrix), ...]
24 |
25 | batch_size: number of samples to place inside a mini-batch
26 |
27 | disable_shuffle: if True, turns off sample shuffling (samples are then visited in a fixed order)
28 |
29 | ensure_equal_batches: if True, ensures sampled batches are equal in size (Default = True).
30 | Note that this means the very last batch, if it's not the same size as the rest, will
31 | reuse random samples from previously seen batches (yielding a batch with a mix of
32 | vectors sampled with and without replacement).
33 | """
34 | def __init__(self, design_matrices, batch_size, disable_shuffle=False,
35 | ensure_equal_batches=True):
36 | self.batch_size = batch_size
37 | self.ensure_equal_batches = ensure_equal_batches
38 | self.disable_shuffle = disable_shuffle
39 | self.design_matrices = design_matrices
40 | if len(design_matrices) < 1:
41 | print(" ERROR: design_matrices must contain at least one design matrix!")
42 | sys.exit(1)
43 | self.data_len = len( self.design_matrices[0][1] )
44 | self.ptrs = np.arange(0, self.data_len, 1)
45 | if self.data_len < self.batch_size:
46 | print("ERROR: batch size {} is > total number data samples {}".format(
47 | self.batch_size, self.data_len))
48 | sys.exit(1)
49 |
50 | def __iter__(self):
51 | """
52 | Yields a mini-batch of the form: [("name", batch),("name",batch),...]
53 | """
54 | if self.disable_shuffle == False:
55 | self.ptrs = np.random.permutation(self.data_len)
56 | idx = 0
57 | while idx < len(self.ptrs): # go through each sample via the sampling pointer
58 | e_idx = idx + self.batch_size
59 | if e_idx > len(self.ptrs): # prevents reaching beyond length of dataset
60 | e_idx = len(self.ptrs)
61 | # extract sampling integer pointers
62 | indices = self.ptrs[idx:e_idx]
63 | if self.ensure_equal_batches == True:
64 | if indices.shape[0] < self.batch_size:
65 | diff = self.batch_size - indices.shape[0]
66 | indices = np.concatenate((indices, self.ptrs[0:diff]))
67 | # create the actual pattern vector batch block matrices
68 | data_batch = []
69 | for dname, dmatrix in self.design_matrices:
70 | x_batch = dmatrix[indices]
71 | data_batch.append((dname, x_batch))
72 | yield data_batch
73 | idx = e_idx
74 |
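A minimal usage sketch (the arrays here are hypothetical stand-ins, not the repo's data): the loader yields, per mini-batch, a list of `("name", batch)` pairs, one pair per registered design matrix:

```python
import numpy as np
from data_utils import DataLoader

X = np.random.rand(1000, 784).astype(np.float32)                  # dummy inputs
Y = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 1000)]  # dummy one-hot labels

loader = DataLoader(design_matrices=[("x", X), ("y", Y)], batch_size=100)
for batch in loader:  # batch = [("x", x_mb), ("y", y_mb)]
    (_, x_mb), (_, y_mb) = batch
    assert x_mb.shape == (100, 784) and y_mb.shape == (100, 10)
```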
--------------------------------------------------------------------------------
/src/eval_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
3 |
4 | ################################################################################
5 | Simulates the training/adaptation of a recurrent neural system composed of
6 | a representation and a generative circuit, trained via the predictive forward-forward
7 | process.
8 | Note that this code focuses on datasets of gray-scale images/patterns.
9 | ################################################################################
10 | """
11 |
12 | import os
13 | import sys, getopt, optparse
14 | import pickle
15 | #import dill as pickle
16 | sys.path.insert(0, '../')
17 | import tensorflow as tf
18 | import numpy as np
19 | import time
20 | import copy
21 |
22 | import matplotlib
23 | matplotlib.use('Agg')
24 | import matplotlib.pyplot as plt
25 | #cmap = plt.cm.jet
26 |
27 | # import general simulation utilities
28 | from data_utils import DataLoader
29 | from pff_rnn import PFF_RNN
30 |
31 | ################################################################################
32 |
33 | seed = 69
34 | tf.random.set_seed(seed=seed)
35 | np.random.seed(seed)
36 |
37 | model_topdir = "../exp/"
38 | out_dir = "../exp/"
39 | data_dir = "../data/mnist/"
40 | split = "test"
41 | # read in configuration file and extract necessary variables/constants
42 | options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","model_topdir=",
43 | "gpu_id=","n_trials=",
44 | "out_dir=","split="])
45 | # Collect arguments from argv
46 | n_trials = 1
47 | gpu_id = -1
48 | for opt, arg in options:
49 | if opt in ("--data_dir"):
50 | data_dir = arg.strip()
51 | elif opt in ("--model_topdir"):
52 | model_topdir = arg.strip()
53 | elif opt in ("--split"):
54 | split = arg.strip()
55 | elif opt in ("--out_dir"):
56 | out_dir = arg.strip()
57 | elif opt in ("--gpu_id"):
58 | gpu_id = int(arg.strip())
59 | elif opt in ("--n_trials"):
60 | n_trials = int(arg.strip())
61 | print(" Exp out dir: ",out_dir)
62 |
63 | mid = gpu_id # 0
64 | if mid >= 0:
65 | print(" > Using GPU ID {0}".format(mid))
66 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
67 | #gpu_tag = '/GPU:0'
68 | gpu_tag = '/GPU:0'
69 | else:
70 | os.environ["CUDA_VISIBLE_DEVICES"]="-1"
71 | gpu_tag = '/CPU:0'
72 |
73 | print(" >>> Run sim on {} w/ GPU {}".format(data_dir, mid))
74 |
75 | dev_batch_size = 500 # dev batch size
76 | print(" > Loading dev set")
77 | xfname = "{}/{}X.npy".format(data_dir, split)
78 | yfname = "{}/{}Y.npy".format(data_dir, split)
79 | Xdev = ( tf.cast(np.load(xfname, allow_pickle=True),dtype=tf.float32) )
80 |
81 | if len(Xdev.shape) > 2:
82 | Xdev = tf.reshape(Xdev, [Xdev.shape[0], Xdev.shape[1] * Xdev.shape[2]])
83 | print("Xdev.shape = ",Xdev.shape)
84 | max_val = float(tf.reduce_max(Xdev))
85 | if max_val > 1.0:
86 | Xdev = Xdev/max_val
87 |
88 | Ydev = ( tf.cast(np.load(yfname, allow_pickle=True),dtype=tf.float32) )
89 | if len(Ydev.shape) == 1:
90 | nC = 10 # number of classes for these benchmarks
91 | Ydev = tf.one_hot(Ydev.numpy().astype(np.int32), depth=nC)
92 | elif Ydev.shape[1] == 1:
93 | nC = 10
94 | Ydev = tf.one_hot(tf.squeeze(Ydev).numpy().astype(np.int32), depth=nC)
95 | print("Ydev.shape = ",Ydev.shape)
96 |
97 | devset = DataLoader(design_matrices=[("x",Xdev.numpy()), ("y",Ydev.numpy())],
98 | batch_size=dev_batch_size, disable_shuffle=True)
99 |
100 | def classify(agent, x):
101 | K_low = int(agent.K/2) - 1 # 3
102 | K_high = int(agent.K/2) + 1 # 5
103 | x_ = x
104 | Ey = None
105 | z_lat = agent.forward(x_) # do forward init pass
106 | for i in range(agent.y_dim):
107 | z_lat_ = []
108 | for ii in range(len(z_lat)):
109 | z_lat_.append(z_lat[ii] + 0)
110 |
111 | yi = tf.ones([x.shape[0],agent.y_dim]) * tf.expand_dims(tf.one_hot(i,depth=agent.y_dim),axis=0)
112 |
113 | gi = 0.0
114 | for k in range(K_high):
115 | z_lat_, p_g = agent._step(x_, yi, z_lat_, thr=0.0)
116 | if k >= K_low and k <= K_high: # only keep goodness in middle iterations
117 | gi = ((p_g[0] + p_g[1])*0.5) + gi
118 |
119 | if i > 0:
120 | Ey = tf.concat([Ey,gi],axis=1)
121 | else:
122 | Ey = gi
123 |
124 | Ey = Ey / (3.0)
125 | y_hat = tf.nn.softmax(Ey)
126 | return y_hat, Ey
127 |
128 | def eval(agent, dataset, debug=False):
129 | '''
130 | Evaluates the current state of the agent given a dataset (data-loader).
131 | '''
132 | N = 0.0
133 | Ny = 0.0
134 | Acc = 0.0
135 | Ly = 0.0
136 | Lx = 0.0
137 | tt = 0
138 | #debug = True
139 | for batch in dataset:
140 | _, x = batch[0]
141 | #_, y_tag = batch[1]
142 | _, y = batch[1]
143 | N += x.shape[0]
144 | Ny += float(tf.reduce_sum(y))
145 |
146 | y_hat, Ey = classify(agent, x)
147 | Ly += tf.reduce_sum(-tf.reduce_sum(y * tf.math.log(y_hat), axis=1, keepdims=True))
148 |
149 | z_lat = agent.forward(x)
150 | x_hat = agent.sample(z=z_lat[len(z_lat)-2])
151 |
152 | ex = x_hat - x
153 | Lx += tf.reduce_sum(tf.reduce_sum(tf.math.square(ex),axis=1,keepdims=True))
154 |
155 | if debug == True:
156 | print("------------------------")
157 | print(Ey[0:4,:])
158 | print(y_hat[0:4,:])
159 | print("------------------------")
160 |
161 | #y_m = tf.squeeze(y_m)
162 | y_ind = tf.cast(tf.argmax(y,1),dtype=tf.int32)
163 | y_pred = tf.cast(tf.argmax(Ey,1),dtype=tf.int32)
164 | comp = tf.cast(tf.equal(y_pred,y_ind),dtype=tf.float32) #* y_m
165 | Acc += tf.reduce_sum( comp )
166 | print("\r Lx = {} Ly = {} Acc = {} ({} samples)".format(Lx/Ny,Ly/Ny,Acc/Ny,Ny),end="")
167 | print()
168 | Ly = Ly/Ny
169 | Acc = Acc/Ny
170 | Lx = Lx/Ny
171 |
172 | return Ly, Acc, Lx
173 |
174 | print("----")
175 | with tf.device(gpu_tag):
176 |
177 | best_acc_list = []
178 | acc_list = []
179 | for tr in range(n_trials):
180 | ########################################################################
181 | ## load model
182 | ########################################################################
183 | agent = PFF_RNN(model_dir="{}trial{}/".format(model_topdir, tr))
184 | K = agent.K
185 | ########################################################################
186 |
187 | ########################################################################
188 | Ly, Acc, Lx = eval(agent, devset)
189 | print("{}: L {} Acc = {} Lx {}".format(-1, Ly, Acc, Lx))
190 |
191 | acc_list.append(1.0 - Acc)
192 |
193 | ############################################################################
194 | ## calc post-trial statistics
195 | n_dec = 4
196 | mu = round(np.mean(np.asarray(acc_list)), n_dec)
197 | sd = round(np.std(np.asarray(acc_list)), n_dec)
198 | print(" Test.Error = {:.4f} \pm {:.4f}".format(mu, sd))
199 |
200 | ## store result to disk just in case...
201 | results_fname = "{}/test_results.txt".format(out_dir)
202 | log_t = open(results_fname,"a")
203 | log_t.write("Generalization Results:\n")
204 | log_t.write(" Test.Error = {:.4f} \pm {:.4f}\n".format(mu, sd))
205 | log_t.close()
206 |
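A quick, hypothetical way to sanity-check a trained model with the `classify()` routine above (the trial path is an example; it assumes `run.sh` has already populated it):

```python
agent = PFF_RNN(model_dir="../exp/pff/mnist/trial0/")
x_batch = Xdev[0:8]                   # a few dev-set images, shape [8, 784]
y_hat, Ey = classify(agent, x_batch)  # goodness-based scores per candidate label
y_pred = tf.argmax(Ey, axis=1)        # predicted class indices
```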
--------------------------------------------------------------------------------
/src/fit_gmm.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
3 |
4 | ################################################################################
5 | Fits a multi-modal/mixture prior to the space of a trained PFF-RNN
6 | ################################################################################
7 | """
8 |
9 | import os
10 | import sys, getopt, optparse
11 | import pickle
12 | sys.path.insert(0, '../')
13 | import tensorflow as tf
14 | import numpy as np
15 | import time
16 | import copy
17 |
18 | import matplotlib
19 | matplotlib.use('Agg')
20 | import matplotlib.pyplot as plt
21 | #cmap = plt.cm.jet
22 |
23 | from data_utils import DataLoader
24 | from pff_rnn import PFF_RNN
25 |
26 | # import general simulation utilities
27 | from ngclearn.density.gmm import GMM
28 |
29 | ################################################################################
30 |
31 | seed = 69
32 | tf.random.set_seed(seed=seed)
33 | np.random.seed(seed)
34 |
35 | disable_prior = 0
36 | data_dir = "../data/"
37 | model_dir = "../exp/"
38 | split = "train"
39 | # read in configuration file and extract necessary variables/constants
40 | options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","model_dir=","gpu_id=","split=",
41 | "disable_prior="])
42 | # Collect arguments from argv
43 | n_trials = 1
44 | gpu_id = -1
45 | for opt, arg in options:
46 | if opt in ("--data_dir"):
47 | data_dir = arg.strip()
48 | elif opt in ("--model_dir"):
49 | model_dir = arg.strip()
50 | elif opt in ("--split"):
51 | split = arg.strip()
52 | elif opt in ("--disable_prior"):
53 | disable_prior = int(arg.strip())
54 | elif opt in ("--gpu_id"):
55 | gpu_id = int(arg.strip())
56 |
57 | gmm_fname = "{}/prior.gmm".format(model_dir)
58 | latent_fname = "{}/latents.npy".format(model_dir)
59 | batch_size = 400 #200 #100 #50 #1000 #500
60 |
61 | mid = gpu_id
62 | if mid >= 0:
63 | print(" > Using GPU ID {0}".format(mid))
64 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
65 | #gpu_tag = '/GPU:0'
66 | gpu_tag = '/GPU:0'
67 | else:
68 | os.environ["CUDA_VISIBLE_DEVICES"]="-1"
69 | gpu_tag = '/CPU:0'
70 |
71 |
72 | xfname = "{}{}X.npy".format(data_dir,split)
73 | X = ( tf.cast(np.load(xfname),dtype=tf.float32) )
74 | if len(X.shape) > 2:
75 | X = tf.reshape(X, [X.shape[0], X.shape[1] * X.shape[2]])
76 | x_dim = X.shape[1]
77 | max_val = float(tf.reduce_max(X))
78 | if max_val > 1.0:
79 | X = X/max_val
80 | print("X.shape = ",X.shape)
81 |
82 | yfname = "{}{}Y.npy".format(data_dir,split)
83 | Y = ( tf.cast(np.load(yfname),dtype=tf.float32) )
84 | if len(Y.shape) == 1:
85 | print("y_init.shape = ",Y.shape)
86 | nC = 10 #Y.shape[1]
87 | Y = tf.one_hot(Y.numpy().astype(np.int32), depth=nC)
88 | print("y_post.shape = ",Y.shape)
89 | elif Y.shape[1] == 1:
90 | print("y_init.shape = ",Y.shape)
91 | nC = 10 #Y.shape[1]
92 | Y = tf.one_hot(tf.squeeze(Y).numpy().astype(np.int32), depth=nC)
93 | print("y_post.shape = ",Y.shape)
94 | y_dim = Y.shape[1]
95 |
96 | dataset = DataLoader(design_matrices=[("x",X.numpy()),("y",Y.numpy())],
97 | batch_size=batch_size, disable_shuffle=True)
98 |
99 | def calc_latent_map(agent, dataset, debug=False):
100 | '''
101 | Calculates a latent "map" or matrix containing the latent encoding of
102 | a dataset/loader.
103 | '''
104 | z = None
105 | N = 0.0
106 | L_r = 0.0
107 | for batch in dataset:
108 | _, x = batch[0]
109 | _, y = batch[1]
110 |
111 | z2 = agent.get_latent(x, y, K, use_y_hat=True)
112 | e = agent.z0_hat - x
113 | Li = tf.reduce_sum(tf.math.square(e) * 0.5)
114 | L_r = Li + L_r
115 |
116 | if z is not None:
117 | z = tf.concat([z,z2],axis=0)
118 | else:
119 | z = z2
120 | N += x.shape[0]
121 | L_r = L_r / N
122 | return z, L_r
123 |
124 |
125 | print("----")
126 | with tf.device(gpu_tag):
127 | print(" >> Loading model from: ",model_dir)
128 | agent = PFF_RNN(model_dir=model_dir)
129 | agent.z1 = None
130 | agent.z0_hat = None
131 | K = agent.K
132 |
133 | z_map, Lr = calc_latent_map(agent, dataset)
134 | print("MSE: ",float(Lr))
135 |
136 | # save latent space map to disk
137 | print(" > Saving latents to disk: lat.shape = ",z_map.shape)
138 | np.save(latent_fname, z_map)
139 |
140 | max_w = -10000.0
141 | min_w = 10000.0
142 | max_w = max(max_w, float(tf.reduce_max(z_map)))
143 | min_w = min(min_w, float(tf.reduce_min(z_map)))
144 | print("max_z = ", max_w)
145 | print("min_z = ", min_w)
146 |
147 | if disable_prior == 0:
148 | print(" > Estimating latent density...")
149 | n_comp = 10 # number of components that will define the prior P(z)
150 | lat_density = GMM(k=n_comp)
151 | lat_density.fit(z_map)
152 |
153 | print(" > Saving density estimator to: {0}".format("gmm.pkl"))
154 | fd = open("{0}".format(gmm_fname), 'wb')
155 | pickle.dump(lat_density, fd)
156 | fd.close()
157 |
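A small sketch of how the serialized prior is consumed later, mirroring the deserialization in `sample_model.py` (the trial path is an example):

```python
import pickle

with open("../exp/pff/mnist/trial0/prior.gmm", "rb") as fd:
    prior = pickle.load(fd)         # the ngc-learn GMM fitted above
z_samp, z_labs = prior.sample(100)  # latent samples plus mixture-component labels
```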
--------------------------------------------------------------------------------
/src/pff_rnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
3 |
4 | This file contains model constructor and its credit assignment code.
5 | """
6 |
7 | import os
8 | import sys
9 | import copy
10 | import pickle
11 | #import dill as pickle
12 | import tensorflow as tf
13 | import numpy as np
14 |
15 | ### generic routines/functions
16 |
17 | def serialize(fname, object): ## object "saving" routine
18 | fd = open(fname, 'wb')
19 | pickle.dump(object, fd)
20 | fd.close()
21 |
22 | def deserialize(fname): ## object "loading" routine
23 | fd = open(fname, 'rb')
24 | object = pickle.load( fd )
25 | fd.close()
26 | return object
27 |
28 | @tf.function
29 | def create_competition_matrix(z_dim, n_group, beta_scale=1.0, alpha_scale=1.0):
30 | """
31 | Competition matrix initialization function, adapted from
32 | (Ororbia & Kifer 2022; Nature Communications).
33 | """
34 | diag = tf.eye(z_dim)
35 | V_l = None
36 | g_shift = 0
37 | while (z_dim - (n_group + g_shift)) >= 0:
38 | if g_shift > 0:
39 | left = tf.zeros([1,g_shift])
40 | middle = tf.ones([1,n_group])
41 | right = tf.zeros([1,z_dim - (n_group + g_shift)])
42 | slice = tf.concat([left,middle,right],axis=1)
43 | for n in range(n_group):
44 | V_l = tf.concat([V_l,slice],axis=0)
45 | else:
46 | middle = tf.ones([1,n_group])
47 | right = tf.zeros([1,z_dim - n_group])
48 | slice = tf.concat([middle,right],axis=1)
49 | for n in range(n_group):
50 | if V_l is not None:
51 | V_l = tf.concat([V_l,slice],axis=0)
52 | else:
53 | V_l = slice
54 | g_shift += n_group
55 | V_l = V_l * (1.0 - diag) * beta_scale + diag * alpha_scale
56 | return V_l
57 |
58 | @tf.function
59 | def softmax(x, tau=0.0): ## temperature-controlled softmax activation
60 | if tau > 0.0:
61 | x = x / tau
62 | max_x = tf.expand_dims( tf.reduce_max(x, axis=1), axis=1)
63 | exp_x = tf.exp(tf.subtract(x, max_x))
64 | return exp_x / tf.expand_dims( tf.reduce_sum(exp_x, axis=1), axis=1)
65 |
66 | @tf.custom_gradient
67 | def _relu(x): ## FF/PFF relu variant activation
68 | ## modified relu
69 | out = tf.nn.relu(x)
70 | def grad(upstream):
71 | #dx = tf.cast(tf.math.greater_equal(x, 0.0), dtype=tf.float32) # d_relu/d_x
72 | dx = tf.ones(x.shape) # pretend like derivatives exist for zero values
73 | return upstream * dx
74 | return out, grad
75 | @tf.function
76 | def clip_fx(x): ## hard-clip activation
77 | return tf.clip_by_value(x, 0.0, 1.0)
78 |
79 | ### begin constructor definition
80 |
81 | class PFF_RNN:
82 | """
83 | Basic implementation of the predictive forward-forward (PFF) algorithm for a recurrent
84 | neural network from (Ororbia & Mali 2022).
85 |
86 | A "model" in this framework is defined as a pair containing an arguments
87 | dictionary and a theta construct, i.e., (args, theta).
88 | """
89 | def __init__(self, args=None, model_dir=None):
90 | theta_r = None
91 | theta_g = None
92 | if model_dir is not None:
93 | args = deserialize("{}config.args".format(model_dir))
94 | theta_r = deserialize("{}rep_params.theta".format(model_dir))
95 | theta_g = deserialize("{}gen_params.theta".format(model_dir))
96 | if args is None:
97 | print("ERROR: no model arguments provided...")
98 | sys.exit(1)
99 | self.args = args
100 | ## collect hyper-parameters
101 | self.seed = 69
102 | if args.get("seed") is not None:
103 | self.seed = args["seed"]
104 | self.x_dim = args["x_dim"]
105 | self.y_dim = args["y_dim"]
106 | self.n_units = 2000
107 | if args.get("n_units") is not None:
108 | self.n_units = args["n_units"]
109 |
110 | self.g_units = 20 # number of top-most latent variables for generative circuit
111 | if args.get("g_units") is not None:
112 | self.g_units = args["g_units"]
113 | self.K = 10
114 | if args.get("K") is not None:
115 | self.K = args["K"]
116 | self.beta = 0.025 #0.05 #0.1
117 | if args.get("beta") is not None:
118 | self.beta = args["beta"]
119 |
120 | self.gen_gamma = 1.0
121 | self.rec_gamma = 1.0
122 | self.thr = 3.0 # goodness threshold
123 | if args.get("thr") is not None:
124 | self.thr = args["thr"]
125 | self.alpha = 0.3 # dampening factor
126 | if args.get("alpha") is not None:
127 | self.alpha = args["alpha"]
128 |
129 | self.y_scale = 5.0 # 1.0
130 | ## noise factors for the latent state updates
131 | self.eps_r = 0.01 # noise factor for representation circuit
132 | if args.get("eps_r") is not None:
133 | self.eps_r = args["eps_r"]
134 | self.eps_g = 0.025 # noise factor for generative circuit
135 | if args.get("eps_g") is not None:
136 | self.eps_g = args["eps_g"]
137 |
138 | ## Set up parameter construct - theta_r and theta_g
139 | if theta_r is None: ## representation circuit params
140 | initializer = tf.compat.v1.keras.initializers.Orthogonal()
141 | self.b1 = tf.Variable(initializer([1, self.n_units]))
142 | self.b2 = tf.Variable(initializer([1, self.n_units]))
143 | self.W1 = tf.Variable(initializer([self.x_dim, self.n_units]))
144 | self.V2 = tf.Variable(initializer([self.n_units, self.n_units])) # inner feedback
145 | self.W2 = tf.Variable(initializer([self.n_units, self.n_units]))
146 | self.V = tf.Variable(initializer([self.y_dim, self.n_units])) # top to inner feedback
147 | self.W = tf.Variable(initializer([self.n_units, self.y_dim])) # softmax output
148 | self.b = tf.Variable(tf.zeros([1, self.y_dim]))
149 | initializer = tf.compat.v1.keras.initializers.RandomUniform(minval=0.0, maxval=0.05)
150 |
151 | self.M1 = create_competition_matrix(self.n_units, n_group=10)
152 | self.M2 = create_competition_matrix(self.n_units, n_group=10)
153 | self.L1 = tf.Variable(initializer([self.n_units, self.n_units])) # lateral lyr 1
154 | self.L2 = tf.Variable(initializer([self.n_units, self.n_units])) # lateral lyr 2
155 | theta_r = [self.L1,self.L2,self.b1,self.b2,self.W1,self.W2,self.V2,self.V,self.W,self.b] ## theta_r
156 | else:
157 | self.M1 = create_competition_matrix(self.n_units, n_group=10)
158 | self.M2 = create_competition_matrix(self.n_units, n_group=10)
159 | self.L1 = theta_r[0]
160 | self.L2 = theta_r[1]
161 | self.b1 = theta_r[2]
162 | self.b2 = theta_r[3]
163 | self.W1 = theta_r[4]
164 | self.W2 = theta_r[5]
165 | self.V2 = theta_r[6]
166 | self.V = theta_r[7]
167 | self.W = theta_r[8]
168 | self.b = theta_r[9]
169 | self.theta = theta_r
170 |
171 | if theta_g is None: ## generative circuit params
172 | self.Gy = tf.Variable(initializer([self.g_units, self.n_units]))
173 | self.G2 = tf.Variable(initializer([self.n_units, self.n_units]))
174 | self.G1 = tf.Variable(initializer([self.n_units, self.x_dim]))
175 | theta_g = [self.Gy,self.G1,self.G2] ## theta_g
176 | else:
177 | self.Gy = theta_g[0]
178 | self.G1 = theta_g[1]
179 | self.G2 = theta_g[2]
180 | self.theta_g = theta_g
181 |
182 | ## activation functions and latent states/stat variables
183 | self.z_g = None # top-most generative latent state
184 | self.fx = _relu ## internal activation function
185 | self.ofx = clip_fx ## predictive activation output function
186 | self.gfx = _relu ## generative activation output function
187 | self.z1 = None
188 | self.z0_hat = None
189 |
190 | def save_model(self, model_dir):
191 | """
192 | Save current model configuration and synaptic parameters (of both
193 | representation & generative circuits) to disk.
194 |
195 | Args:
196 | model_dir: directory to save model config
197 | """
198 | if not os.path.exists(model_dir):
199 | os.makedirs(model_dir)
200 | serialize("{}config.args".format(model_dir), self.args)
201 | serialize("{}rep_params.theta".format(model_dir), self.theta)
202 | serialize("{}gen_params.theta".format(model_dir), self.theta_g)
203 |
204 | def calc_goodness(self, z, thr):
205 | """
206 | Calculates the "goodness" of an activation vector.
207 |
208 | Args:
209 | z: activation vector/matrix
210 | thr: goodness threshold
211 |
212 | Returns:
213 | (probability that z is positive data, goodness logit of z)
214 | """
215 | z_sqr = tf.math.square(z)
216 | delta = tf.reduce_sum(z_sqr, axis=1, keepdims=True)
217 | #delta = delta - thr # maximize for positive samps, minimize for negative samps
218 | delta = -delta + thr # minimize for positive samps, maximize for negative samps
219 | # gets the probability P(pos)
220 | p = tf.nn.sigmoid(delta)
221 | eps = 1e-5
222 | p = tf.clip_by_value(p, eps, 1.0 - eps)
223 | return p, delta
224 |
225 | def calc_loss(self, z, lab, thr, keep_batch=False):
226 | """
227 | Calculates the local loss of an activation vector.
228 |
229 | Args:
230 | z: activation vector/matrix (vector/matrix)
231 | lab: data "type" binary label (1 for pos, 0 for neg) (vector/matrix)
232 | thr: goodness threshold
233 |
234 | Returns:
235 | local goodness loss of z (per-sample if keep_batch is True, else a mean scalar)
236 | """
237 | p, logit = self.calc_goodness(z, thr)
238 | ## the loss below is what the original PFF paper used & adheres to Eqn 3
239 | CE = tf.math.maximum(logit, 0) - logit * lab + tf.math.log(1. + tf.math.exp(-tf.math.abs(logit)))
240 | ## the commented-out loss below, however, also works just fine
241 | #CE = tf.nn.softplus(-logit) * lab + tf.nn.softplus(logit) * (1.0 - lab)
242 | L = tf.reduce_sum(CE, axis=1, keepdims=True)
243 | if keep_batch == True:
244 | return L
245 | L = tf.reduce_mean(L)
246 | return L
247 |
248 | def forward(self, x):
249 | """
250 | Forward propagates x through the representation circuit
251 |
252 | Args:
253 | x: sensory input (vector/matrix)
254 |
255 | Returns:
256 | list of layer-wise activities
257 | """
258 | z1 = self.fx(tf.matmul(self.normalize(x), self.W1) + self.b1)
259 | z2 = self.fx(tf.matmul(self.normalize(z1), self.W2) + self.b2)
260 | z3 = softmax(tf.matmul(self.normalize(z2), self.W) + self.b)
261 | return [z1,z2,z3] # return all latents
262 |
263 | def classify(self, x):
264 | """
265 | Categorizes sensory input x
266 |
267 | Args:
268 | x: sensory input (vector/matrix)
269 |
270 | Returns:
271 | y_hat, probability distribution over labels (vector/matrix)
272 | """
273 | z = self.forward(x)
274 | y_hat = z[len(z)-1]
275 | return y_hat
276 |
277 | def infer(self, x, y, lab, z_lat, K, opt=None, g_opt=None, reg_lambda=0.0,
278 | zero_y=False, g_reg_lambda=0.0):
279 | """
280 | Simulates the PFF intertwined inference-and-learning process for
281 | the underlying dual-circuit system.
282 |
283 | Args:
284 | x: sensory input (vector/matrix)
285 | y: sensory class label (vector/matrix)
286 | lab: data "type" binary label (1 for pos, 0 for neg) (vector/matrix)
287 | z_lat: list of initial conditions for model's representation activities
288 | (usually provided by an initial forward pass with .forward(x) )
289 | K: number of simulation steps
290 | opt: rep circuit optimizer
291 | g_opt: gen circuit optimizer
292 | reg_lambda: regularization coefficient (for representation synapses)
293 | zero_y: "zero out" the y-vector top-down context
294 | g_reg_lambda: regularization coefficient (for generative synapses)
295 |
296 | Returns:
297 | (global energy value (goodness + regularization), label distribution matrix,
298 | generative loss, x reconstruction)
299 | """
300 | if self.rec_gamma > 0.0:
301 | self.theta = [self.L1,self.L2,self.b1,self.b2,self.W1,self.W2,self.V2,self.V,self.W,self.b]
302 | else:
303 | self.theta = [self.b1,self.b2,self.W1,self.W2,self.V2,self.V,self.W,self.b]
304 |
305 | calc_grad = False
306 | if opt is not None:
307 | calc_grad = True
308 | # update generative model
309 | self.z_g = tf.Variable(tf.zeros([x.shape[0], self.g_units]))
310 | for k in range(K):
311 | # update representation model
312 | z_lat, L, delta = self.step(x,y,lab,z_lat,calc_grad=calc_grad,
313 | zero_y=zero_y, reg_lambda=reg_lambda)
314 | if opt is not None: ## update synapses
315 | bound = 1.0 #5.0 # 1.0
316 | for l in range(len(delta)):
317 | delta[l] = tf.clip_by_value(delta[l], -bound, bound) # clip update by projection
318 | opt.apply_gradients(zip(delta, self.theta))
319 | if self.gen_gamma > 0.0: ## update generative model
320 | grad_f = True #False
321 | Lg, x_hat = self.update_generator(x,y,z_lat,g_opt,reg_lambda=g_reg_lambda,grad_f=grad_f)
322 | else:
323 | Lg = 0.0
324 | x_hat = x * 0
325 |
326 | y_hat = z_lat[len(z_lat)-1]
327 | return L, y_hat, Lg, x_hat
328 |
329 | def sample(self, n_s=0, z=None, y=None): # samples generative circuit
330 | """
331 | Samples the generative circuit within this current neural system.
332 |
333 | Args:
334 | n_s: number of samples to synthesize/confabulate
335 | z: top-most externally produced input sample (from a prior); (vector/matrix)
336 | y: sensory class label (vector/matrix)
337 |
338 | Returns:
339 | samples of the bottom-most sensory layer
340 | """
341 | if z is None:
342 | eps_sigma = 0.05
343 | #eps = tf.random.normal([y.shape[0],self.n_units], 0.0, eps_sigma) #* 0
344 | z_in = self.normalize(self.gfx(y))
345 | z2 = self.gfx(tf.matmul(z_in,self.Gy))# + self.c)
346 | else:
347 | z2 = self.gfx(z)
348 | #z2 = self.gfx(z2)
349 | z1 = self.gfx(tf.matmul(self.normalize(z2),self.G2))# + self.c2)
350 | #z0 = tf.matmul(self.normalize(z1),self.G1) #
351 | z0 = self.ofx(tf.matmul(self.normalize(z1),self.G1))# + self.c1)
352 | return z0
353 |
354 | def update_generator(self, x, y, z_lat, opt, reg_lambda=0.0001, grad_f=True):
355 | '''
356 | Internal routine for adjusting synapses of generative circuit
357 | '''
358 | z0 = x
359 | z1 = z_lat[0]
360 | z2 = z_lat[1]
361 | eps_sigma = self.eps_g #0.025 #0.055 # 0.05 #0.02 #0.1
362 | eps1 = tf.random.normal([x.shape[0],self.n_units], 0.0, eps_sigma) #* 0
363 | eps2 = tf.random.normal([x.shape[0],self.n_units], 0.0, eps_sigma) #* 0
364 | with tf.GradientTape(persistent=True) as tape:
365 | z3_bar = self.gfx(self.z_g)
366 | z2_hat = tf.matmul(self.normalize(z3_bar),self.Gy) #+ self.c
367 | #z2_hat = self.fx(tf.matmul(self.z_g,self.Gy))
368 | z2_bar = self.gfx(z2_hat) #z2 #self.fx(z2 + eps2)
369 | z1_hat = tf.matmul(tf.stop_gradient(self.normalize(self.gfx(z2 + eps2))),self.G2) #+ self.c2
370 | #z1_hat = self.fx(tf.matmul(z2_bar,self.G2))
371 | z1_bar = self.gfx(z1_hat) #z1 #self.fx(z1 + eps1)
372 | #z0_hat = tf.matmul(z1_bar,self.G1) #
373 | z0_hat = self.ofx(tf.matmul(tf.stop_gradient(self.normalize(self.gfx(z1 + eps1))),self.G1)) #+ self.c1
374 |
375 | e2 = z2_bar - z2#_bar
376 | L2 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e2),axis=1,keepdims=True))
377 | e1 = z1_bar - z1#_bar
378 | L1 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e1),axis=1,keepdims=True))
379 | e0 = z0_hat - z0
380 | L0 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e0),axis=1,keepdims=True))
381 |
382 | reg = 0.0
383 | if reg_lambda > 0.0: # weight decay
384 | reg = (tf.norm(self.G1) + tf.norm(self.G2) + tf.norm(self.Gy)) * reg_lambda
385 | L = L2 + L1 + L0 + reg
386 | Lg = L2 + L1 + L0
387 | if grad_f == True:
388 | delta = tape.gradient(L, self.theta_g)
389 | bound = 1.0 #5.0 # 1.0
390 | for l in range(len(delta)):
391 | delta[l] = tf.clip_by_value(delta[l], -bound, bound) # clip update by projection
392 | opt.apply_gradients(zip(delta, self.theta_g))
393 | # for l in range(len(self.theta_g)):
394 | # self.theta_g[l].assign(tf.clip_by_norm(self.theta_g[l], 5.0, axes=1))
395 |
396 | d_z = tape.gradient(L2, self.z_g)
397 | self.z_g.assign(self.z_g - d_z * self.beta)
398 |
399 | return Lg, z0_hat # return generative loss
400 |
401 | def step(self, x, y, lab, z_lat, calc_grad=False, zero_y=False, reg_lambda=0.0):
402 | '''
403 | Internal full simulation step routine (inference and credit assignment)
404 | '''
405 | y_ = y * self.y_scale
406 | Npos = tf.reduce_sum(lab)
407 |
408 | eps_sigma = self.eps_r # 0.01 #0.05 # kmnist
409 | #eps_sigma = 0.025 # mnist
410 | eps1 = tf.random.normal([x.shape[0],self.n_units], 0.0, 1.0) * eps_sigma
411 | eps2 = tf.random.normal([x.shape[0],self.n_units], 0.0, 1.0) * eps_sigma
412 |
413 | with tf.GradientTape() as tape:
414 | z1_tm1 = tf.stop_gradient(z_lat[0])
415 | z2_tm1 = tf.stop_gradient(z_lat[1])
416 |
417 | z1 = tf.matmul(self.normalize(x),self.W1) + tf.matmul(self.normalize(z2_tm1),self.V2) + self.b1 + eps1
418 | if self.rec_gamma > 0.0:
419 | L1 = tf.nn.relu(self.L1)
420 | L1 = L1 * self.M1 * (1. - tf.eye(self.L1.shape[0])) - L1 * tf.eye(self.L1.shape[0])
421 | z1 = z1 - tf.matmul(z1_tm1, L1) * self.rec_gamma
422 | z1 = self.fx(z1) * (1.0 - self.alpha) + z1_tm1 * self.alpha
423 |
424 | z2 = tf.matmul(self.normalize(z1_tm1),self.W2) + tf.matmul(y_,self.V) + self.b2 + eps2
425 | if self.rec_gamma > 0.0:
426 | L2 = tf.nn.relu(self.L2)
427 | L2 = L2 * self.M2 * (1. - tf.eye(self.L2.shape[0])) - L2 * tf.eye(self.L2.shape[0])
428 | z2 = z2 - tf.matmul(z2_tm1, L2) * self.rec_gamma
429 | z2 = self.fx(z2) * (1.0 - self.alpha) + z2_tm1 * self.alpha
430 |
431 | z3 = softmax(tf.matmul(self.normalize(z2_tm1),self.W) + self.b) #tf.nn.softmax(tf.matmul(self.normalize(z2),self.W) + self.b)
432 |
433 | ## calc loss
434 | L1 = self.calc_loss(z1, lab, thr=self.thr)
435 | L2 = self.calc_loss(z2, lab, thr=self.thr)
436 | if zero_y == True:
437 | L3 = tf.reduce_sum(-tf.reduce_sum(y_ * tf.math.log(z3), axis=1, keepdims=True)) * 0
438 | else:
439 | L3 = tf.reduce_sum(-tf.reduce_sum(y_ * tf.math.log(z3) * lab, axis=1, keepdims=True) * lab)
440 | L3 = L3 / Npos
441 | reg = 0.0
442 | if reg_lambda > 0.0: # weight decay
443 | reg = (tf.norm(self.W1) + tf.norm(self.W2) + tf.norm(self.V2) +
444 | tf.norm(self.V) + tf.norm(self.W)) * reg_lambda
445 | L = L3 + L2 + L1 + reg
446 | Lg = L2 + L1
447 | delta = None
448 | if calc_grad == True:
449 | delta = tape.gradient(L, self.theta)
450 | return [z1,z2,z3], Lg, delta
451 |
452 | def _step(self, x, y, z_lat, thr=None):
453 | '''
454 | Internal simulation step, without credit assignment
455 | '''
456 | thr_ = thr
457 | y_ = y * self.y_scale
458 |
459 | z1_tm1 = z_lat[0]
460 | z2_tm1 = z_lat[1]
461 |
462 | z1 = tf.matmul(self.normalize(x),self.W1) + tf.matmul(self.normalize(z2_tm1),self.V2) + self.b1
463 | if self.rec_gamma > 0.0:
464 | L1 = tf.nn.relu(self.L1)
465 | L1 = L1 * self.M1 * (1. - tf.eye(self.L1.shape[0])) - L1 * tf.eye(self.L1.shape[0])
466 | z1 = z1 - tf.matmul(z1_tm1, L1) * self.rec_gamma
467 | z1 = self.fx(z1) * (1.0 - self.alpha) + z1_tm1 * self.alpha
468 |
469 | z2 = tf.matmul(self.normalize(z1_tm1),self.W2) + tf.matmul(y_,self.V) + self.b2
470 | if self.rec_gamma > 0.0:
471 | L2 = tf.nn.relu(self.L2)
472 | L2 = L2 * self.M2 * (1. - tf.eye(self.L2.shape[0])) - L2 * tf.eye(self.L2.shape[0])
473 | z2 = z2 - tf.matmul(z2_tm1, L2) * self.rec_gamma
474 | z2 = self.fx(z2) * (1.0 - self.alpha) + z2_tm1 * self.alpha
475 | ## calc goodness values
476 | if thr_ is None:
477 | thr_ = self.thr
478 | p1, logit1 = self.calc_goodness(z1, thr_)
479 | p2, logit2 = self.calc_goodness(z2, thr_)
480 |
481 | return [z1,z2], [logit1,logit2]
482 |
483 | def get_latent(self, x, y, K, use_y_hat=False):
484 | self.z1 = None
485 | self.z0_hat = 0.0
486 | z_lat = self.forward(x)
487 | y_hat = z_lat[len(z_lat)-1]
488 | self.z_g = tf.Variable(tf.zeros([x.shape[0], self.g_units]))
489 | y_ = y
490 | if use_y_hat == True:
491 | y_ = y_hat
492 | for k in range(K):
493 | self._infer_latent(x,y_,z_lat)
494 | #self.z0_hat = self.ofx(tf.matmul(self.normalize(self.gfx(self.z1)),self.G1))
495 | #self.z0_hat = self.z0_hat/(K * 1.0)
496 | return self.z_g + 0
497 |
498 | def _infer_latent(self, x, y, z_lat):
499 | '''
500 | Internal routine for inferring the generative circuit's top-most latent state activity
501 | '''
502 | y_ = y * self.y_scale
503 |
504 | z1_tm1 = z_lat[0]
505 | z2_tm1 = z_lat[1]
506 | with tf.GradientTape() as tape:
507 | z1 = tf.matmul(self.normalize(x),self.W1) + tf.matmul(self.normalize(z2_tm1),self.V2) + self.b1
508 | if self.rec_gamma > 0.0:
509 | L1 = tf.nn.relu(self.L1)
510 | L1 = L1 * self.M1 * (1. - tf.eye(self.L1.shape[0])) - L1 * tf.eye(self.L1.shape[0])
511 | z1 = z1 - tf.matmul(z1_tm1, L1) * self.rec_gamma
512 | z1 = self.fx(z1) * (1.0 - self.alpha) + z1_tm1 * self.alpha
513 | self.z1 = z1 + 0
514 |
515 | z2 = tf.matmul(self.normalize(z1_tm1),self.W2) + tf.matmul(y_,self.V) + self.b2
516 | if self.rec_gamma > 0.0:
517 | L2 = tf.nn.relu(self.L2)
518 | L2 = L2 * self.M2 * (1. - tf.eye(self.L2.shape[0])) - L2 * tf.eye(self.L2.shape[0])
519 | z2 = z2 - tf.matmul(z2_tm1, L2) * self.rec_gamma
520 | z2 = self.fx(z2) * (1.0 - self.alpha) + z2_tm1 * self.alpha
521 |
522 | z3_bar = self.gfx(self.z_g)
523 | z2_hat = tf.matmul(self.normalize(z3_bar),self.Gy)
524 | z2_bar = self.gfx(z2_hat)
525 | e2 = z2_bar - z2#_bar
526 | L2 = tf.reduce_mean(tf.reduce_sum(tf.math.square(e2),axis=1,keepdims=True))
527 |
528 | d_z = tape.gradient(L2, self.z_g)
529 | self.z_g.assign(self.z_g - d_z * self.beta)
530 |
531 | z0_hat = self.ofx(tf.matmul(self.normalize(self.gfx(z1)),self.G1))
532 | self.z0_hat = z0_hat
533 |
534 | def normalize(self, z_state, a_scale=1.0): ## norm
535 | '''
536 | Internal routine for normalizing state vector/matrix "z_state"
537 | '''
538 | eps = 1e-8
539 | L2 = tf.norm(z_state, ord=2, axis=1, keepdims=True)
540 | z_state = z_state / (L2 + eps)
541 | return z_state * a_scale
542 |
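Two details above that may help when reading the code. First, `calc_goodness()` scores a layer via `p = sigmoid(thr - sum_j z_j^2)`, so, as its in-code comment notes, positive samples are driven toward *low* squared activity, the reverse of the original forward-forward convention. Second, the competition masks `M1`/`M2` are block-structured; here is a tiny standalone sketch with a hypothetical 4-unit, group-of-2 configuration (the model itself uses `n_units=2000` with `n_group=10`):

```python
from pff_rnn import create_competition_matrix

# within a group, off-diagonal entries carry beta_scale (cross-unit competition)
# and diagonal entries carry alpha_scale (self-excitation); groups do not interact
V = create_competition_matrix(z_dim=4, n_group=2, beta_scale=1.0, alpha_scale=1.0)
print(V.numpy())
# [[1. 1. 0. 0.]
#  [1. 1. 0. 0.]
#  [0. 0. 1. 1.]
#  [0. 0. 1. 1.]]
```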
--------------------------------------------------------------------------------
/src/plot_tsne.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
3 | """
4 |
5 | import os
6 | import sys, getopt, optparse
7 | import pickle
8 | sys.path.insert(0, '../')
9 | import tensorflow as tf
10 | import numpy as np
11 |
12 | import matplotlib #.pyplot as plt
13 | matplotlib.use('Agg')
14 | import matplotlib.pyplot as plt
15 | cmap = plt.cm.jet
16 | #cmap = plt.cm.cividis # red-green color blind friendly palette
17 |
18 | ################################################################################
19 |
20 | # GPU arguments
21 | # read in configuration file and extract necessary variables/constants
22 | options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","model_dir=","gpu_id=","split="])
23 |
24 | # Collect arguments from argv
25 | model_dir = "../exp/"
26 | data_dir = "../data/"
27 | use_tsne = True # (args.getArg("use_tsne").lower() == 'true')
28 | split = "test"
29 |
30 | use_gpu = False
31 | gpu_id = -1
32 | for opt, arg in options:
33 | if opt in ("--data_dir"):
34 | data_dir = arg.strip()
35 | elif opt in ("--model_dir"):
36 | model_dir = arg.strip()
37 | elif opt in ("--split"):
38 | split = arg.strip()
39 | elif opt in ("--gpu_id"):
40 | gpu_id = int(arg.strip())
41 | use_gpu = True
42 |
43 | mid = gpu_id
44 | if mid >= 0:
45 | print(" > Using GPU ID {0}".format(mid))
46 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
47 | gpu_tag = '/GPU:0'
48 | else:
49 | os.environ["CUDA_VISIBLE_DEVICES"]="-1"
50 | gpu_tag = '/CPU:0'
51 |
52 | plot_fname = "{}/lat_viz.jpg".format(model_dir)
53 | latents_fname = "{}/latents.npy".format(model_dir)
54 |
55 | #batch_size = int(args.getArg("batch_size")) #128
56 | delimiter = "\t"
57 | # xfname = args.getArg("xfname") #"../data/mnist/trainX.tsv"
58 | yfname = "{}/{}Y.npy".format(data_dir,split)
59 | Y = tf.cast(np.load(yfname),dtype=tf.float32).numpy()
60 |
61 | with tf.device(gpu_tag):
62 |
63 | z_lat = tf.cast(np.load(latents_fname),dtype=tf.float32)
64 | print("Lat.shape = {}".format(z_lat.shape))
65 | y_sample = Y
66 | if len(y_sample.shape) == 1:
67 | print("y_init.shape = ",y_sample.shape)
68 | nC = 10 #Y.shape[1]
69 | y_sample = tf.one_hot(tf.cast(y_sample,dtype=tf.float32).numpy().astype(np.int32), depth=nC)
70 | print("y_post.shape = ",y_sample.shape)
71 | elif y_sample.shape[1] == 1:
72 | print("y_init.shape = ",y_sample.shape)
73 | nC = 10 #Y.shape[1]
74 | y_sample = tf.one_hot(tf.squeeze(y_sample).numpy().astype(np.int32), depth=nC)
75 | print("y_post.shape = ",y_sample.shape)
76 |
77 | max_w = -10000.0
78 | min_w = 10000.0
79 | max_w = max(max_w, float(tf.reduce_max(z_lat)))
80 | min_w = min(min_w, float(tf.reduce_min(z_lat)))
81 | print("max_z = ", max_w)
82 | print("min_z = ", min_w)
83 | print("Y.shape = ",y_sample.shape)
84 |
85 | z_top_dim = z_lat.shape[1]
86 | z_2D = None
87 | if z_top_dim != 2:
88 | from sklearn.decomposition import IncrementalPCA
89 | print(" > Projecting latents via iPCA...")
90 | if use_tsne is True:
91 | n_comp = 32 #10 #16 #50
92 | if z_lat.shape[1] < n_comp:
93 | n_comp = z_lat.shape[1] - 2 #z_top.shape[1]-2
94 | n_comp = max(2, n_comp)
95 | ipca = IncrementalPCA(n_components=n_comp, batch_size=50)
96 | ipca.fit(z_lat.numpy())
97 | z_2D = ipca.transform(z_lat.numpy())
98 | print("PCA.lat.shape = ",z_2D.shape)
99 | print(" > Finishing projection via t-SNE...")
100 | from sklearn.manifold import TSNE
101 | z_2D = TSNE(n_components=2,perplexity=30).fit_transform(z_2D)
102 | #z_2D.shape
103 | else:
104 | ipca = IncrementalPCA(n_components=2, batch_size=50)
105 | ipca.fit(z_lat.numpy())
106 | z_2D = ipca.transform(z_lat.numpy())
107 | else:
108 | z_2D = z_lat
109 |
110 | print(" > Plotting 2D latent encodings...")
111 | plt.figure(figsize=(8, 6))
112 | plt.scatter(z_2D[:, 0], z_2D[:, 1], c=np.argmax(y_sample, 1), cmap=cmap)
113 | plt.colorbar()
114 | plt.grid()
115 | plt.savefig("{0}".format(plot_fname), dpi=300) # latents.jpg
116 | plt.clf()
117 |
--------------------------------------------------------------------------------
/src/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPU_ID=0
3 | N_TRIALS=2 # number of experimental trials to run
4 |
5 | echo " >>> Running MNIST simulation!"
6 | DATA_DIR="../data/mnist/"
7 | OUT_DIR="../exp/pff/mnist/"
8 | python sim_train.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$OUT_DIR
9 |
10 | echo " >>> Running K-MNIST simulation!"
11 | DATA_DIR="../data/kmnist/"
12 | OUT_DIR="../exp/pff/kmnist/"
13 | python sim_train.py --data_dir=$DATA_DIR --gpu_id=$GPU_ID --n_trials=$N_TRIALS --out_dir=$OUT_DIR
14 |
--------------------------------------------------------------------------------
/src/sample_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
3 |
4 | ################################################################################
5 | Samples a trained PFF-RNN system.
6 | ################################################################################
7 | """
8 |
9 | import os
10 | import sys, getopt, optparse
11 | import pickle
12 | sys.path.insert(0, '../')
13 | import tensorflow as tf
14 | import numpy as np
15 | import time
16 | import copy
17 |
18 | import matplotlib
19 | matplotlib.use('Agg')
20 | import matplotlib.pyplot as plt
21 | #cmap = plt.cm.jet
22 |
23 | from pff_rnn import PFF_RNN
24 |
25 | ################################################################################
26 |
27 | def deserialize(fname): ## object "loading" routine
28 | fd = open(fname, 'rb')
29 | object = pickle.load( fd )
30 | fd.close()
31 | return object
32 |
33 | def plot_img_grid(samples, fname, nx, ny, px, py, plt, rotNeg90=False): # rows, cols,...
34 | px_dim = px
35 | py_dim = py
36 | canvas = np.empty((px_dim*nx, py_dim*ny))
37 | ptr = 0
38 | for i in range(0,nx,1):
39 | for j in range(0,ny,1):
40 | #xs = tf.expand_dims(tf.cast(samples[ptr,:],dtype=tf.float32),axis=0)
41 | xs = np.expand_dims(samples[ptr,:],axis=0)
42 | #xs = xs.numpy() #tf.make_ndarray(x_mean)
43 | xs = xs[0].reshape(px_dim, py_dim)
44 | if rotNeg90 is True:
45 | xs = np.rot90(xs, -1)
46 | canvas[(nx-i-1)*px_dim:(nx-i)*px_dim, j*py_dim:(j+1)*py_dim] = xs
47 | ptr += 1
48 | plt.figure(figsize=(12, 14))
49 | plt.imshow(canvas, origin="upper", cmap="gray")
50 | plt.tight_layout()
51 | plt.axis('off')
52 | plt.savefig("{0}".format(fname), bbox_inches='tight', pad_inches=0)
53 | plt.clf()
54 |
55 | seed = 69
56 | tf.random.set_seed(seed=seed)
57 | np.random.seed(seed)
58 |
59 | model_dir = "../exp/"
60 | # read in configuration file and extract necessary variables/constants
61 | options, remainder = getopt.getopt(sys.argv[1:], '', ["model_dir=","gpu_id="])
62 | # Collect arguments from argv
63 | n_trials = 1
64 | gpu_id = -1
65 | for opt, arg in options:
66 | if opt in ("--model_dir"):
67 | model_dir = arg.strip()
68 | elif opt in ("--gpu_id"):
69 | gpu_id = int(arg.strip())
70 |
71 | gmm_fname = "{}/prior.gmm".format(model_dir)
72 |
73 | mid = gpu_id #0
74 | if mid >= 0:
75 | print(" > Using GPU ID {0}".format(mid))
76 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
77 | #gpu_tag = '/GPU:0'
78 | gpu_tag = '/GPU:0'
79 | else:
80 | os.environ["CUDA_VISIBLE_DEVICES"]="-1"
81 | gpu_tag = '/CPU:0'
82 |
83 | def get_n_comp(gmm, use_sklearn=True):
84 | if use_sklearn is True:
85 | return gmm.n_components
86 | else:
87 | return gmm.k
88 |
89 | def sample_gmm(gmm, n_samps, use_sklearn=False):
90 | if use_sklearn is True:
91 | np_samps, np_labs = gmm.sample(n_samps)
92 | z_samp = tf.cast(np_samps,dtype=tf.float32)
93 | else:
94 | z_samp, z_labs = gmm.sample(n_samps)
95 | np_labs = tf.squeeze(z_labs).numpy()
96 | y_s = tf.one_hot(np_labs, get_n_comp(gmm, use_sklearn=use_sklearn))
97 | return z_samp, y_s
98 |
99 | print("----")
100 | with tf.device(gpu_tag):
101 | print(" >> Loading prior P(z): ",gmm_fname)
102 | prior = deserialize(gmm_fname)
103 | print(" >> Loading model P(x|z): ",model_dir)
104 | agent = PFF_RNN(model_dir=model_dir)
105 | print(agent.V.shape)
106 | nrow = 10 #2
107 | ncol = 10 #4
108 | n_samp = nrow * ncol # per class
109 | print(" >> Generating confabulations from P(x|z)P(z)...")
110 | for _ in range(15): ## jitter the prior
111 | z2s, _ = sample_gmm(prior, n_samp)
112 | xs = agent.sample(y=z2s)
113 | fname = "{}samples.jpg".format(model_dir)
114 | plot_img_grid(xs.numpy(), fname, nx=nrow, ny=ncol, px=28, py=28, plt=plt)
115 | plt.close()
116 |
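A note on the wiring above: the latents drawn from the GMM prior enter `agent.sample()` through its `y` argument; with `z=None`, `PFF_RNN.sample()` treats that input as the top-most state and decodes it down through `Gy`, `G2`, and `G1`. In condensed form (names as in the script; the output filename is an example):

```python
z2s, _ = sample_gmm(prior, n_samp)  # top-level latents from the fitted prior P(z)
xs = agent.sample(y=z2s)            # z=None, so the latents flow in via y
plot_img_grid(xs.numpy(), "samples.jpg", nx=10, ny=10, px=28, py=28, plt=plt)
```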
--------------------------------------------------------------------------------
/src/sim_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Code for paper "The Predictive Forward-Forward Algorithm" (Ororbia & Mali, 2022)
3 |
4 | ################################################################################
5 | Simulates the training/adaptation of a recurrent neural system composed of
6 | a representation and a generative circuit, trained via the predictive forward-forward
7 | process.
8 | Note that this code focuses on datasets of gray-scale images/patterns.
9 | ################################################################################
10 | """
11 |
12 | import os
13 | import sys, getopt, optparse
14 | import pickle
15 | #import dill as pickle
16 | sys.path.insert(0, '../')
17 | import tensorflow as tf
18 | import numpy as np
19 | import time
20 | import copy
21 |
22 | import matplotlib
23 | matplotlib.use('Agg')
24 | import matplotlib.pyplot as plt
25 | #cmap = plt.cm.jet
26 |
27 | # import general simulation utilities
28 | from data_utils import DataLoader
29 |
30 | def plot_img_grid(samples, fname, nx, ny, px, py, plt, rotNeg90=False): # rows, cols,...
31 | '''
32 | Visualizes a matrix of vector patterns in the form of an image grid plot.
33 | '''
34 | px_dim = px
35 | py_dim = py
36 | canvas = np.empty((px_dim*nx, py_dim*ny))
37 | ptr = 0
38 | for i in range(0,nx,1):
39 | for j in range(0,ny,1):
40 | #xs = tf.expand_dims(tf.cast(samples[ptr,:],dtype=tf.float32),axis=0)
41 | xs = np.expand_dims(samples[ptr,:],axis=0)
42 | #xs = xs.numpy() #tf.make_ndarray(x_mean)
43 | xs = xs[0].reshape(px_dim, py_dim)
44 | if rotNeg90 is True:
45 | xs = np.rot90(xs, -1)
46 | canvas[(nx-i-1)*px_dim:(nx-i)*px_dim, j*py_dim:(j+1)*py_dim] = xs
47 | ptr += 1
48 | plt.figure(figsize=(12, 14))
49 | plt.imshow(canvas, origin="upper", cmap="gray")
50 | plt.tight_layout()
51 | plt.axis('off')
52 | # save the grid image, trimming any surrounding whitespace
53 | plt.savefig("{0}".format(fname), bbox_inches='tight', pad_inches=0)
54 | plt.clf()
55 |
56 | from pff_rnn import PFF_RNN
57 |
58 | ################################################################################
59 |
60 | seed = 69
61 | tf.random.set_seed(seed=seed)
62 | np.random.seed(seed)
63 |
64 | out_dir = "../exp/"
65 | data_dir = "../data/mnist/"
66 | # parse command-line options
67 | options, remainder = getopt.getopt(sys.argv[1:], '', ["data_dir=","gpu_id=","n_trials=","out_dir="])
68 | # collect arguments from argv, falling back to the defaults below
69 | n_trials = 1
70 | gpu_id = -1
71 | for opt, arg in options:
72 | if opt == "--data_dir":
73 | data_dir = arg.strip()
74 | elif opt == "--out_dir":
75 | out_dir = arg.strip()
76 | elif opt == "--gpu_id":
77 | gpu_id = int(arg.strip())
78 | elif opt == "--n_trials":
79 | n_trials = int(arg.strip())
80 | print(" Exp out dir: ",out_dir)
81 |
82 | mid = gpu_id
83 | if mid >= 0:
84 | print(" > Using GPU ID {0}".format(mid))
85 | os.environ["CUDA_VISIBLE_DEVICES"]="{0}".format(mid)
86 | # route computation through the first visible GPU
87 | gpu_tag = '/GPU:0'
88 | else:
89 | os.environ["CUDA_VISIBLE_DEVICES"]="-1"
90 | gpu_tag = '/CPU:0'
91 |
92 | print(" >>> Run sim on {} w/ GPU {}".format(data_dir, mid))
93 |
94 | xfname = "{}/trainX.npy".format(data_dir)
95 | yfname = "{}/trainY.npy".format(data_dir)
96 |
97 | X = ( tf.cast(np.load(xfname, allow_pickle=True),dtype=tf.float32) )
98 | if len(X.shape) > 2:
99 | X = tf.reshape(X, [X.shape[0], X.shape[1] * X.shape[2]])
100 | print(X.shape)
101 | x_dim = X.shape[1]
102 | max_val = float(tf.reduce_max(X))
103 | if max_val > 1.0:
104 | X = X/max_val # rescale pixel intensities to lie in [0,1]
105 |
106 | Y = ( tf.cast(np.load(yfname, allow_pickle=True),dtype=tf.float32) )
107 | y_dim = Y.shape[1]
108 | print("Y.shape = ",Y.shape)
109 |
110 | n_iter = 60 # number of training iterations
111 | batch_size = 500 # batch size
112 | dev_batch_size = 500 # dev batch size
113 | dataset = DataLoader(design_matrices=[("x",X.numpy()),("y",Y.numpy())], batch_size=batch_size)
114 |
115 | print(" > Loading dev set")
116 | xfname = "{}/validX.npy".format(data_dir)
117 | yfname = "{}/validY.npy".format(data_dir)
118 | Xdev = ( tf.cast(np.load(xfname, allow_pickle=True),dtype=tf.float32) )
119 |
120 | if len(Xdev.shape) > 2:
121 | Xdev = tf.reshape(Xdev, [Xdev.shape[0], Xdev.shape[1] * Xdev.shape[2]])
122 | print("Xdev.shape = ",Xdev.shape)
123 | max_val = float(tf.reduce_max(Xdev))
124 | if max_val > 1.0:
125 | Xdev = Xdev/max_val
126 |
127 | Ydev = ( tf.cast(np.load(yfname, allow_pickle=True),dtype=tf.float32) )
128 | if len(Ydev.shape) == 1:
129 | nC = Y.shape[1] # infer the class count from the training labels
130 | Ydev = tf.one_hot(Ydev.numpy().astype(np.int32), depth=nC)
131 | elif Ydev.shape[1] == 1:
132 | nC = 10 # labels stored as a single integer column; assume 10 classes
133 | Ydev = tf.one_hot(tf.squeeze(Ydev).numpy().astype(np.int32), depth=nC)
134 | print("Ydev.shape = ",Ydev.shape)
135 |
136 | devset = DataLoader(design_matrices=[("x",Xdev.numpy()), ("y",Ydev.numpy())],
137 | batch_size=dev_batch_size, disable_shuffle=True)
138 |
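# goodness-based classification: for each candidate class, clamp its one-hot
# code alongside the input, let the circuit settle for K_high steps, and
# accumulate the averaged goodness of the middle settling steps; the class
# whose clamped code yields the highest total goodness is the prediction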
139 | def classify(agent, x):
140 | K_low = int(agent.K/2) - 1 # lower edge of the middle settling steps
141 | K_high = int(agent.K/2) + 1 # upper edge of the middle settling steps
142 | x_ = x
143 | Ey = None
144 | z_lat = agent.forward(x_) # do forward init pass
145 | for i in range(agent.y_dim):
146 | z_lat_ = []
147 | for ii in range(len(z_lat)):
148 | z_lat_.append(z_lat[ii] + 0) # "+ 0" forces a fresh tensor copy
149 |
150 | yi = tf.ones([x.shape[0],agent.y_dim]) * tf.expand_dims(tf.one_hot(i,depth=agent.y_dim),axis=0)
151 |
152 | gi = 0.0
153 | for k in range(K_high):
154 | z_lat_, p_g = agent._step(x_, yi, z_lat_, thr=0.0)
155 | if k >= K_low and k <= K_high: # only keep goodness in middle iterations
156 | gi = ((p_g[0] + p_g[1])*0.5) + gi
157 |
158 | if i > 0:
159 | Ey = tf.concat([Ey,gi],axis=1)
160 | else:
161 | Ey = gi
162 |
163 | Ey = Ey / 3.0 # rescale accumulated goodness before the softmax
164 | y_hat = tf.nn.softmax(Ey)
165 | return y_hat, Ey
166 |
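# usage sketch: y_hat, Ey = classify(agent, x_batch) returns a softmax
# distribution over the classes along with the raw per-class goodness scores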
167 | def eval(agent, dataset, debug=False, save_img=True, out_dir=""):
168 | '''
169 | Evaluates the agent's current state on a dataset (data-loader); returns mean label cross-entropy (Ly), accuracy (Acc), and reconstruction error (Lx).
170 | '''
171 | N = 0.0
172 | Ny = 0.0
173 | Acc = 0.0
174 | Ly = 0.0
175 | Lx = 0.0
176 | tt = 0
177 | #debug = True # flip on to print example goodness scores/predictions below
178 | for batch in dataset:
179 | _, x = batch[0]
180 | _, y = batch[1]
181 | N += x.shape[0]
182 | Ny += float(tf.reduce_sum(y))
183 |
184 | y_hat, Ey = classify(agent, x)
185 | Ly += tf.reduce_sum(-tf.reduce_sum(y * tf.math.log(y_hat), axis=1, keepdims=True))
186 |
187 | z_lat = agent.forward(x)
188 | x_hat = agent.sample(z=z_lat[-2]) # reconstruct from the penultimate latent layer
189 |
190 | ex = x_hat - x
191 | Lx += tf.reduce_sum(tf.reduce_sum(tf.math.square(ex),axis=1,keepdims=True))
192 |
193 | if debug == True:
194 | print("------------------------")
195 | print(Ey[0:4,:])
196 | print(y_hat[0:4,:])
197 | print("------------------------")
198 |
199 | y_ind = tf.cast(tf.argmax(y,1),dtype=tf.int32)
200 | y_pred = tf.cast(tf.argmax(Ey,1),dtype=tf.int32)
201 | comp = tf.cast(tf.equal(y_pred,y_ind),dtype=tf.float32)
202 | Acc += tf.reduce_sum( comp )
203 | Ly = Ly/Ny
204 | Acc = Acc/Ny
205 | Lx = Lx/Ny
206 |
207 | if save_img == True:
208 | fname = "{}/x_samples.png".format(out_dir)
209 | plot_img_grid(x_hat.numpy(), fname, nx=10, ny=10, px=28, py=28, plt=plt)
210 | plt.close()
211 |
212 | fname = "{}/x_data.png".format(out_dir)
213 | plot_img_grid(x, fname, nx=10, ny=10, px=28, py=28, plt=plt)
214 | plt.close()
215 |
216 | return Ly, Acc, Lx
217 |
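# main simulation: for each trial, build a fresh PFF_RNN agent, then for
# n_iter passes over the training set construct positive/negative batches and
# adapt the representation and generative circuits online, tracking dev accuracy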
218 | print("----")
219 | with tf.device(gpu_tag):
220 |
221 | best_acc_list = []
222 | acc_list = []
223 | for tr in range(n_trials):
224 | acc_scores = [] # tracks acc during training w/in a trial
225 | ########################################################################
226 | ## create model
227 | ########################################################################
228 | model_dir = "{}/trial{}/".format(out_dir, tr)
229 | if not os.path.exists(model_dir):
230 | os.makedirs(model_dir)
231 | args = {"x_dim": x_dim,
232 | "y_dim": y_dim,
233 | "n_units": 2000, # hidden units per layer
234 | "K": 12, # number of settling steps per stimulus
235 | "thr": 10.0, # goodness threshold
236 | "eps_r": 0.01, # representation-circuit update/noise constant (see pff_rnn.py)
237 | "eps_g": 0.025} # generative-circuit update/noise constant (see pff_rnn.py)
238 | agent = PFF_RNN(args=args)
239 | ## set up optimization
240 | eta = 0.00025 # step size for the representation circuit
241 | reg_lambda = 0.0
242 | #reg_lambda = 0.0001 # works nicely for KMNIST
243 | opt = tf.keras.optimizers.Adam(eta)
244 |
245 | g_eta = 0.00025 # step size for the generative circuit
246 | g_reg_lambda = 0.0
247 | g_opt = tf.keras.optimizers.Adam(g_eta)
248 | ########################################################################
249 |
250 | ########################################################################
251 | ## begin simulation
252 | ########################################################################
253 | Ly, Acc, Lx = eval(agent, devset, out_dir=model_dir)
254 | acc_scores.append(Acc)
255 | print("{}: L {} Acc = {} Lx {}".format(-1, Ly, Acc, Lx))
256 |
257 | best_Acc = Acc
258 | best_Ly = Ly
259 | for t in range(n_iter):
260 | N = 0.0
261 | Ng = 0.0
262 | Lg = 0.0
263 | Ly = 0.0
264 | for batch in dataset:
265 | _, x = batch[0]
266 | _, y = batch[1]
267 | N += x.shape[0]
268 |
269 | # create negative data (x, y_neg): pair each image with a random incorrect label
270 | x_neg = x
271 | y_neg = tf.random.uniform(y.shape, 0.0, 1.0) * (1.0 - y) # zero out the true class
272 | y_neg = tf.one_hot(tf.argmax(y_neg,axis=1), depth=agent.y_dim) # pick a wrong class uniformly at random
273 |
274 | ## create full batch
275 | x_ = tf.concat([x,x_neg],axis=0)
276 | y_ = tf.concat([y,y_neg],axis=0)
277 | lab = tf.concat([tf.ones([x.shape[0],1]),tf.zeros([x_neg.shape[0],1])],axis=0) # 1 = positive sample, 0 = negative
278 | ## update model given full batch
279 | z_lat = agent.forward(x_)
280 | Lg_t, _, Lgen_t, x_hat = agent.infer(x_, y_, lab, z_lat, agent.K, opt, g_opt, reg_lambda=reg_lambda, g_reg_lambda=g_reg_lambda) # total goodness
281 | Lg_t = Lg_t * (x.shape[0] + x_neg.shape[0]) # convert the mean loss back to a sum
282 |
283 | # measure the label loss on the positive samples only
284 | y_hat = agent.classify(x)
285 | Ly_t = tf.reduce_sum(-tf.reduce_sum(y * tf.math.log(y_hat), axis=1, keepdims=True))
286 |
287 | ## track losses
288 | Ly = Ly_t + Ly
289 | Lg = Lg_t + Lg
290 | Ng += (x.shape[0] + x_neg.shape[0])
291 |
292 | print("\r {}: Ly = {} L = {} w/ {} samples".format(t, Ly/N, Lg/Ng, N), end="")
293 | print()
294 | print("--------------------------------------")
295 |
296 | Ly, Acc, Lx = eval(agent, devset, out_dir=model_dir)
297 | acc_scores.append(Acc)
298 | np.save("{}/dev_acc.npy".format(model_dir), np.asarray(acc_scores))
299 | print("{}: L {} Acc = {} Lx {}".format(t, Ly, Acc, Lx))
300 |
301 | if Acc > best_Acc: # checkpoint whenever dev accuracy improves
302 | best_Acc = Acc
303 | best_Ly = Ly
304 |
305 | print(" >> Saving model to: ",model_dir)
306 | agent.save_model(model_dir)
307 |
308 | print("************")
309 | Ly, Acc, _ = eval(agent, dataset, out_dir=model_dir, save_img=False)
310 | print(" Train: Ly {} Acc = {}".format(Ly, Acc))
311 | print("Best.Dev: Ly {} Acc = {}".format(best_Ly, best_Acc))
312 |
313 | acc_list.append(1.0 - Acc) # convert final accuracies to error rates
314 | best_acc_list.append(1.0 - best_Acc)
315 |
316 | ############################################################################
317 | ## calc post-trial statistics
318 | n_dec = 4
319 | mu = round(np.mean(np.asarray(best_acc_list)), n_dec)
320 | sd = round(np.std(np.asarray(best_acc_list)), n_dec)
321 | print(" Dev.Error = {:.4f} +/- {:.4f}".format(mu, sd))
322 |
323 | ## store result to disk just in case...
324 | results_fname = "{}/post_train_results.txt".format(out_dir)
325 | log_t = open(results_fname,"a")
326 | log_t.write("Generalization Results:\n")
327 | log_t.write(" Dev.Error = {:.4f} \pm {:.4f}\n".format(mu, sd))
328 |
329 | n_dec = 4
330 | mu = round(np.mean(np.asarray(acc_list)), n_dec)
331 | sd = round(np.std(np.asarray(acc_list)), n_dec)
332 | print(" Train.Error = {:.4f} +/- {:.4f}".format(mu, sd))
333 | log_t.write("Training-Set/Optimization Results:\n")
334 | log_t.write(" Train.Error = {:.4f} \pm {:.4f}\n".format(mu, sd))
335 |
336 | log_t.close() # close the log file
337 |
--------------------------------------------------------------------------------