├── .gitignore
├── LICENSE
├── README.md
├── mnist_bp.py
├── mnist_ff.py
├── network.py
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
runs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Carlo Alberto Barbano

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# forward-forward-pytorch

PyTorch implementation of Hinton's Forward-Forward (FF) algorithm with hard
negative sampling. `mnist_ff.py` trains a fully connected network on MNIST with
the FF algorithm (run it with `python mnist_ff.py`); `mnist_bp.py` trains the
same architecture with standard backpropagation as a baseline. Both scripts
download MNIST to `~/data` on first run, and `mnist_ff.py` writes TensorBoard
logs to the gitignored `runs/` directory.
--------------------------------------------------------------------------------
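An editorial note on the objective optimized in mnist_ff.py (listed below):
each layer is trained so that the norm of its activations, the "goodness",
rises above a threshold theta for positive inputs (an image paired with its
true label) and falls below theta for negative inputs (an image paired with a
wrong label). A minimal sketch of that per-layer loss, matching the train()
code in mnist_ff.py (the function name here is ours):

import torch

def ff_layer_loss(z_pos, z_neg, theta=10.0):
    # z_pos / z_neg: (batch, features) activations of one layer for positive /
    # negative inputs. log(1 + exp(t)) is softplus: positives are pushed above
    # theta, negatives below it. theta=10 matches Opts.theta in mnist_ff.py.
    positive_loss = torch.log(1 + torch.exp(theta - z_pos.norm(dim=-1))).mean()
    negative_loss = torch.log(1 + torch.exp(z_neg.norm(dim=-1) - theta)).mean()
    return positive_loss + negative_loss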
/mnist_bp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import network

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from util import accuracy, set_seed


class Opts:
    layer_size = 2000
    batch_size = 1000

    lr = 0.1
    weight_decay = 0
    epochs = 60

    seed = 0
    device = 'cuda'


@torch.no_grad()
def test(network_bp, test_loader, opts):
    all_outputs = []
    all_labels = []

    for (x_test, y_test) in test_loader:
        x_test, y_test = x_test.to(opts.device), y_test.to(opts.device)
        x_test = x_test.view(x_test.shape[0], -1)  # flatten images to vectors
        acts = network_bp(x_test)
        all_outputs.append(acts)
        all_labels.append(y_test)

    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    top1 = accuracy(all_outputs, all_labels, topk=(1,))[0]
    return top1


def train(network_bp, optimizer, train_loader, opts):
    running_loss = 0.

    for (x, y_ground) in train_loader:
        x, y_ground = x.to(opts.device), y_ground.to(opts.device)
        x = x.view(opts.batch_size, -1)  # safe: drop_last=True guarantees full batches

        with torch.enable_grad():
            ys = network_bp(x)
            loss = F.cross_entropy(ys, y_ground)
            loss.backward()

        running_loss += loss.detach()

        optimizer.step()
        optimizer.zero_grad()

    running_loss /= len(train_loader)
    return running_loss


def main(opts):
    set_seed(opts.seed)

    T_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    T_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_loader = DataLoader(MNIST("~/data", train=True, download=True, transform=T_train),
                              batch_size=opts.batch_size, shuffle=True, drop_last=True)

    test_loader = DataLoader(MNIST("~/data", train=False, download=True, transform=T_test),
                             batch_size=opts.batch_size, shuffle=True)

    size = opts.layer_size
    network_bp = network.Network(dims=[28*28, size, size, size, 10], ff=False).to(opts.device)
    print(network_bp)

    optimizer = torch.optim.SGD(network_bp.parameters(),
                                lr=opts.lr,
                                weight_decay=opts.weight_decay)

    best_acc = 0.
    for step in range(1, opts.epochs+1):
        running_ce = train(network_bp, optimizer, train_loader, opts)

        top1 = test(network_bp, test_loader, opts)
        if top1 > best_acc:
            best_acc = top1
        print(f"Step {step:04d} CE: {running_ce:.4f} acc@1: {top1:.2f}")

    print('Best acc:', best_acc)


if __name__ == '__main__':
    opts = Opts()
    main(opts)
--------------------------------------------------------------------------------
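A note on the baseline above: with ff=False, Network (see network.py) builds
the classifier out of the same Block used for forward-forward, so every block
after the first L2-normalizes its input and the final 10-unit layer also goes
through ReLU; cross-entropy is therefore computed on non-negative, clipped
logits. A hypothetical plain-PyTorch restatement of that stack, for
illustration only (the class below is ours, not part of the repo):

import torch.nn as nn
import torch.nn.functional as F

class BPEquivalent(nn.Module):
    # Restates Network(dims=[28*28, size, size, size, 10], ff=False).
    def __init__(self, size=2000):
        super().__init__()
        self.fcs = nn.ModuleList([
            nn.Linear(28*28, size, bias=False),
            nn.Linear(size, size, bias=False),
            nn.Linear(size, size, bias=False),
            nn.Linear(size, 10, bias=False),
        ])

    def forward(self, x):
        for i, fc in enumerate(self.fcs):
            if i > 0:
                x = F.normalize(x, dim=1)  # blocks after the first normalize their input
            x = F.relu(fc(x))              # note: ReLU hits the final 10-dim output too
        return x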
/mnist_ff.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import network

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.tensorboard.writer import SummaryWriter
from util import set_seed, accuracy


class Opts:
    hard_negatives = True
    layer_size = 2000

    batch_size = 200
    lr = 0.0001
    weight_decay = 0
    epochs = 60
    steps_per_block = 60
    theta = 10.

    seed = 0
    device = 'cuda'


def norm_y(y_one_hot: torch.Tensor):
    # Labels are concatenated to the flattened image, so they are normalized
    # with the same MNIST mean/std as the pixels.
    return y_one_hot.sub(0.1307).div(0.3081)


@torch.no_grad()
def test(network_ff, linear_cf, test_loader, opts):
    all_outputs = []
    all_labels = []
    all_logits = []

    for (x_test, y_test) in test_loader:
        x_test, y_test = x_test.to(opts.device), y_test.to(opts.device)
        x_test = x_test.view(x_test.shape[0], -1)

        acts_for_labels = []

        # slow method: try every possible label and score it by per-layer goodness
        for label in range(10):
            test_label = torch.full_like(y_test, label)
            test_label = norm_y(F.one_hot(test_label, num_classes=10))
            x_with_labels = torch.cat((x_test, test_label), dim=1)

            acts = network_ff(x_with_labels)
            acts = acts.norm(dim=-1)  # goodness of each layer
            acts_for_labels.append(acts)

        # these are logits, shaped (batch, 10 labels, n_blocks)
        acts_for_labels = torch.stack(acts_for_labels, dim=1)
        all_outputs.append(acts_for_labels)
        all_labels.append(y_test)

        # quick method: a single pass with a neutral label (uniform 0.1), then
        # the linear classifier predicts the class from the block embeddings
        neutral_label = norm_y(torch.full((x_test.shape[0], 10), 0.1, device=opts.device))
        acts = network_ff(torch.cat((x_test, neutral_label), dim=1))
        logits = linear_cf(acts.view(acts.shape[0], -1))
        all_logits.append(logits)

    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    all_logits = torch.cat(all_logits)

    slow_acc = accuracy(all_outputs.mean(dim=-1), all_labels, topk=(1,))[0]  # goodness averaged over layers
    fast_acc = accuracy(all_logits, all_labels, topk=(1,))[0]
    return slow_acc, fast_acc
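# Why norm_y: the one-hot label is concatenated to the flattened image and both
# pass through the same MNIST normalization (mean 0.1307, std 0.3081), so the
# label channels live on the same scale as the pixels. A worked check of the
# three values a label channel can take (an illustrative helper of ours, unused
# by the code above and below):
def _norm_y_reference_values(mean=0.1307, std=0.3081):
    off = (0.0 - mean) / std      # "off" classes              -> about -0.424
    on = (1.0 - mean) / std       # the "on" class             -> about  2.822
    neutral = (0.1 - mean) / std  # neutral 0.1 (quick method) -> about -0.100
    return off, on, neutral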
def train(network_ff, optimizers, linear_cf, optimizer_cf, train_loader, start_block, opts):
    running_loss = 0.
    running_ce = 0.

    for (x, y_pos) in train_loader:
        x, y_pos = x.to(opts.device), y_pos.to(opts.device)
        x = x.view(opts.batch_size, -1)

        # positive pairs: images with their true label
        y_pos_one_hot = norm_y(F.one_hot(y_pos, num_classes=10))
        x_pos = torch.cat((x, y_pos_one_hot), dim=1)

        # embed the batch with a neutral label (uniform 0.1, normalized like
        # every other label so this matches the quick method in test()); these
        # embeddings train the linear classifier and supply hard negatives
        with torch.no_grad():
            neutral_label = norm_y(torch.full_like(y_pos_one_hot, 0.1))
            ys = network_ff(torch.cat((x, neutral_label), dim=1))

        with torch.enable_grad():
            logits = linear_cf(ys.view(ys.shape[0], -1).detach())
            ce = F.cross_entropy(logits, y_pos)
            ce.backward()
            running_ce += ce.detach()

        optimizer_cf.step()
        optimizer_cf.zero_grad()

        # hard negative pairs: images paired with the classifier's wrong predictions
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        idx = torch.where(preds != y_pos)  # keep only the misclassified ones
        y_hard_one_hot = norm_y(F.one_hot(preds, num_classes=10))
        x_hard = torch.cat((x, y_hard_one_hot), dim=1)[idx]

        # negative pairs from random labels; a random label occasionally equals
        # the true one, but keeping those accidental positives seems to work
        # better than filtering them out with torch.where(y_rand != y_pos)
        y_rand = torch.randint(0, 10, (opts.batch_size,), device=opts.device)
        y_rand_one_hot = norm_y(F.one_hot(y_rand, num_classes=10))
        x_rand = torch.cat((x, y_rand_one_hot), dim=1)

        x_neg = x_rand
        if opts.hard_negatives:
            x_neg = torch.cat((x_neg, x_hard), dim=0)

        with torch.enable_grad():
            z_pos = network_ff(x_pos, cat=False)
            z_neg = network_ff(x_neg, cat=False)

            for i, (zp, zn) in enumerate(zip(z_pos, z_neg)):
                if i < start_block:
                    continue  # earlier blocks are frozen

                # push goodness (activation norm) above theta for positives,
                # below theta for negatives
                positive_loss = torch.log(1 + torch.exp(-zp.norm(dim=-1) + opts.theta)).mean()
                negative_loss = torch.log(1 + torch.exp(zn.norm(dim=-1) - opts.theta)).mean()
                loss = positive_loss + negative_loss
                loss.backward()  # gradients stay local: blocks detach their outputs

                running_loss += loss.detach()
                optimizers[i].step()
                optimizers[i].zero_grad()

    running_loss /= len(train_loader)
    running_ce /= len(train_loader)

    return running_loss, running_ce
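# Numerical side note: torch.log(1 + torch.exp(t)), as used in train() above,
# overflows to inf for large t (float32 exp overflows near t ~ 89). An
# equivalent, overflow-safe formulation would use F.softplus, since
# softplus(t) == log(1 + exp(t)). A suggested variant, not what train() uses:
def ff_layer_loss_stable(zp, zn, theta=10.0):
    positive_loss = F.softplus(theta - zp.norm(dim=-1)).mean()
    negative_loss = F.softplus(zn.norm(dim=-1) - theta).mean()
    return positive_loss + negative_loss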
def main(opts):
    set_seed(opts.seed)

    T_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    T_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_loader = DataLoader(MNIST("~/data", train=True, download=True, transform=T_train),
                              batch_size=opts.batch_size, shuffle=True, drop_last=True, num_workers=8,
                              persistent_workers=True)

    test_loader = DataLoader(MNIST("~/data", train=False, download=True, transform=T_test),
                             batch_size=opts.batch_size, shuffle=True, num_workers=8,
                             persistent_workers=True)

    size = opts.layer_size
    network_ff = network.Network(dims=[28*28 + 10, size, size, size, size]).to(opts.device)
    print(network_ff)

    # Create one optimizer for every ReLU layer (block)
    optimizers = [
        torch.optim.Adam(block.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
        for block in network_ff.blocks.children()
    ]

    # Softmax layer for predicting classes from the concatenated block
    # embeddings (fast method)
    linear_cf = nn.Linear(size*network_ff.n_blocks, 10).to(opts.device)
    optimizer_cf = torch.optim.Adam(linear_cf.parameters(), lr=0.0001)

    writer = SummaryWriter()

    start_block = 0
    for step in range(1, opts.epochs+1):
        running_loss, running_ce = train(network_ff, optimizers, linear_cf, optimizer_cf,
                                         train_loader, start_block, opts)
        # every steps_per_block epochs, freeze the lowest still-trainable block
        if step % opts.steps_per_block == 0:
            if start_block+1 < network_ff.n_blocks:
                start_block += 1
                print("Freezing block", start_block-1)

        writer.add_scalar("train/loss", running_loss, step)
        writer.add_scalar("train/ce", running_ce, step)

        train_slow_acc, train_fast_acc = test(network_ff, linear_cf, train_loader, opts)
        test_slow_acc, test_fast_acc = test(network_ff, linear_cf, test_loader, opts)

        writer.add_scalar("acc_fast/train", train_fast_acc, step)
        writer.add_scalar("acc_fast/test", test_fast_acc, step)
        writer.add_scalar("acc_slow/train", train_slow_acc, step)
        writer.add_scalar("acc_slow/test", test_slow_acc, step)

        print(f"Step {step:03d} Loss: {running_loss:.4f} CE: {running_ce:.4f}",
              f"-- TRAIN: fast {train_fast_acc:.2f} (err {(100. - train_fast_acc):.2f}) slow {train_slow_acc:.2f} (err {(100. - train_slow_acc):.2f})",
              f"-- TEST: fast {test_fast_acc:.2f} (err {(100. - test_fast_acc):.2f}) slow {test_slow_acc:.2f} (err {(100. - test_slow_acc):.2f})")


if __name__ == '__main__':
    opts = Opts()
    main(opts)
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    def __init__(self, in_dim, out_dim, normalize_input, ff, bias=False):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.relu = nn.ReLU(inplace=True)
        self.normalize_input = normalize_input
        self.ff = ff

    def forward(self, x):
        if self.normalize_input:
            # Forward only the direction of the previous activity: normalizing
            # hides the previous layer's goodness (its norm) from this layer.
            x = F.normalize(x, dim=1)

        x = self.relu(self.fc(x))
        # Store the post-ReLU activations; Network.forward exposes them, and
        # their norm is this block's "goodness".
        self.x = x

        if self.ff:
            # Detach so gradients never cross block boundaries: each block is
            # trained only by its own local FF loss.
            return x.detach()
        return x


class Network(nn.Module):
    def __init__(self, dims, ff=True):
        super().__init__()

        blocks = [Block(dims[0], dims[1], normalize_input=False, ff=ff)]
        for in_dim, out_dim in zip(dims[1:-1], dims[2:]):
            blocks.append(Block(in_dim, out_dim, normalize_input=True, ff=ff))

        # Sequential both chains the blocks in forward and gives a readable
        # repr when the network is printed
        self.blocks = nn.Sequential(*blocks)
        self.n_blocks = len(blocks)
        self.ff = ff

    def forward(self, x, cat=True):
        x = self.blocks(x)

        if not self.ff:
            return x

        xs = [b.x for b in self.blocks.children()]

        if not cat:
            return xs  # list of per-block activations (for the local losses)
        return torch.stack(xs, dim=1)  # (batch, n_blocks, features)
--------------------------------------------------------------------------------
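The .detach() in Block.forward above is what makes training in mnist_ff.py
layer-local: each block receives the previous block's output as a constant, so
calling backward() on one block's goodness loss populates only that block's
parameters. A small sketch that checks this property, assuming network.py as
listed above (toy sizes, illustrative only):

import torch
from network import Network

net = Network(dims=[20, 8, 8], ff=True)   # two toy blocks
zs = net(torch.randn(4, 20), cat=False)   # list of per-block activations
zs[-1].norm(dim=-1).mean().backward()     # goodness-style loss on the last block only

for i, block in enumerate(net.blocks.children()):
    # prints: 0 False, 1 True -- gradients never cross block boundaries
    print(i, block.fc.weight.grad is not None)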
/util.py:
--------------------------------------------------------------------------------
import torch
import random
import numpy as np
import os


def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=True trades bit-exact reproducibility for cuDNN autotuning speed
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(seed)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size).item())
        return res
--------------------------------------------------------------------------------
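A caveat about set_seed above: it deliberately keeps cudnn.benchmark = True and
cudnn.deterministic = False, so two runs with the same seed are comparable but
not bit-exact. If exact reproducibility matters more than speed, a stricter
variant could look like this sketch (our suggestion, not the repo's behavior):

import os
import random

import numpy as np
import torch


def set_seed_strict(seed):
    # same seeding as util.set_seed, but trades cuDNN autotuning speed for
    # bit-exact repeatability
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False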