├── .gitignore
├── LICENSE
├── README.md
├── imgs
│   ├── BP_vs_FF.png
│   ├── layer.png
│   └── pos_neg.png
└── main.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Mohammad Pezeshki

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch_forward_forward
Implementation of the forward-forward (FF) training algorithm, an alternative to back-propagation
---

Below is my understanding of the FF algorithm presented in [Geoffrey Hinton's talk at NeurIPS 2022](https://www.cs.toronto.edu/~hinton/FFA13.pdf).\
Conventional backprop computes gradients by successive applications of the chain rule, from the objective function back to the parameters. FF instead trains each layer with its own local objective function, so errors never need to be propagated backward through the network.

![](./imgs/BP_vs_FF.png)

The local objective function is designed to push a layer's goodness (here, the mean of its squared activations) above a threshold for positive samples and below that threshold for negative samples.

A positive sample $s$ is a real datapoint with a large $P(s)$ under the training distribution.\
A negative sample $s'$ is a fake datapoint with a small $P(s')$ under the training distribution.

![](./imgs/layer.png)

Among the many ways of generating positive/negative samples, this implementation uses, for MNIST:\
Positive sample $s = merge(x, y)$, the image and its true label\
Negative sample $s' = merge(x, y_{random})$, the image and a random (typically incorrect) label

![](./imgs/pos_neg.png)

After training all the layers, to classify a test image $x$, we form the pair $s = (x, y)$ for every label $0 \leq y < 10$ and pick the $y$ that maximizes the network's accumulated goodness.
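
Concretely, for a layer with activations $h \in \mathbb{R}^d$, the goodness is the mean squared activation $g = \frac{1}{d}\sum_{j=1}^{d} h_j^2$. Each layer is trained to minimize, averaged over all positive and negative samples, the softplus loss

$$\log\left(1 + e^{\,\theta - g_{pos}}\right) \text{ for positive samples}, \qquad \log\left(1 + e^{\,g_{neg} - \theta}\right) \text{ for negative samples},$$

which pushes $g_{pos}$ above the threshold $\theta$ (2.0 here) and $g_{neg}$ below it; this is exactly the loss implemented in `Layer.train` of `main.py`.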

With this implementation (a 784-500-500 fully connected network trained on a single batch of 50,000 MNIST images), the training and test errors are:
```
> python main.py
train error: 0.06754004955291748
test error: 0.06840002536773682
```

--------------------------------------------------------------------------------
/imgs/BP_vs_FF.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/pytorch_forward_forward/1c7a2114bbf85a055dfb50bb13f1516534e9f35d/imgs/BP_vs_FF.png

--------------------------------------------------------------------------------
/imgs/layer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/pytorch_forward_forward/1c7a2114bbf85a055dfb50bb13f1516534e9f35d/imgs/layer.png

--------------------------------------------------------------------------------
/imgs/pos_neg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/pytorch_forward_forward/1c7a2114bbf85a055dfb50bb13f1516534e9f35d/imgs/pos_neg.png

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader

# Use the GPU when available; the script also runs (slowly) on a CPU-only machine.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):
    # The default batch sizes load each split as a single batch,
    # i.e. the network below is trained full-batch.
    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda x: torch.flatten(x))])

    train_loader = DataLoader(
        MNIST('./data/', train=True,
              download=True,
              transform=transform),
        batch_size=train_batch_size, shuffle=True)

    test_loader = DataLoader(
        MNIST('./data/', train=False,
              download=True,
              transform=transform),
        batch_size=test_batch_size, shuffle=False)

    return train_loader, test_loader


def overlay_y_on_x(x, y):
    """Replace the first 10 pixels of data [x] with a one-hot encoding of label [y]."""
    x_ = x.clone()
    x_[:, :10] *= 0.0
    x_[range(x.shape[0]), y] = x.max()
    return x_


class Net(torch.nn.Module):

    def __init__(self, dims):
        super().__init__()
        # A plain Python list suffices here: each Layer owns its optimizer and
        # trains itself, so the layers need not be registered with Net.
        self.layers = []
        for d in range(len(dims) - 1):
            self.layers += [Layer(dims[d], dims[d + 1]).to(DEVICE)]

    def predict(self, x):
        # Overlay every candidate label on x, accumulate the goodness of all
        # layers, and return the label with the highest total goodness.
        goodness_per_label = []
        for label in range(10):
            h = overlay_y_on_x(x, label)
            goodness = []
            for layer in self.layers:
                h = layer(h)
                goodness += [h.pow(2).mean(1)]
            goodness_per_label += [sum(goodness).unsqueeze(1)]
        goodness_per_label = torch.cat(goodness_per_label, 1)
        return goodness_per_label.argmax(1)

    def train(self, x_pos, x_neg):
        # Note: this intentionally shadows nn.Module.train(). The layers are
        # trained greedily, one after another, each on the previous layer's
        # (detached) outputs.
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')
            h_pos, h_neg = layer.train(h_pos, h_neg)


class Layer(nn.Linear):
    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=0.03)
        self.threshold = 2.0
        self.num_epochs = 1000
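
    # The incoming activity vector is normalized to unit length before the
    # linear map in forward(). The previous layer's goodness is carried in the
    # *length* of its output, so normalizing erases it and forces this layer
    # to rely only on the relative activities (the direction) of its input.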
    def forward(self, x):
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(
            torch.mm(x_direction, self.weight.T) +
            self.bias.unsqueeze(0))

    def train(self, x_pos, x_neg):
        for i in tqdm(range(self.num_epochs)):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # The following loss pushes positive (negative) samples to
            # goodness values larger (smaller) than self.threshold.
            loss = torch.log(1 + torch.exp(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold]))).mean()
            self.opt.zero_grad()
            # This backward() only computes derivatives of the local loss
            # w.r.t. this layer's parameters; no error signal reaches earlier
            # layers, so it is not back-propagation through the network.
            loss.backward()
            self.opt.step()
        # Detach the outputs so the next layer treats them as fixed inputs.
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()


def visualize_sample(data, name='', idx=0):
    reshaped = data[idx].cpu().reshape(28, 28)
    plt.figure(figsize=(4, 4))
    plt.title(name)
    plt.imshow(reshaped, cmap="gray")
    plt.show()


if __name__ == "__main__":
    torch.manual_seed(1234)
    train_loader, test_loader = MNIST_loaders()

    net = Net([784, 500, 500])
    x, y = next(iter(train_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    x_pos = overlay_y_on_x(x, y)
    # Negative samples: the same images overlaid with shuffled (mostly wrong) labels.
    rnd = torch.randperm(x.size(0))
    x_neg = overlay_y_on_x(x, y[rnd])

    for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):
        visualize_sample(data, name)

    net.train(x_pos, x_neg)

    print('train error:', 1.0 - net.predict(x).eq(y).float().mean().item())

    x_te, y_te = next(iter(test_loader))
    x_te, y_te = x_te.to(DEVICE), y_te.to(DEVICE)

    print('test error:', 1.0 - net.predict(x_te).eq(y_te).float().mean().item())
--------------------------------------------------------------------------------