├── LICENSE ├── README.md ├── function.py ├── main.py ├── models.py ├── module.py ├── setup.py ├── synth_dataset_gen.py ├── talks └── LightOnAIMeetUp.pdf ├── train.py └── training_algorithms_topologies.png /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Direct Random Target Projection (DRTP) - PyTorch-based implementation 2 | 3 | > *Copyright (C) 2019, Université catholique de Louvain (UCLouvain), Belgium.* 4 | 5 | > *The source code is free: you can redistribute it and/or modify it under the terms of the Apache v2.0 license.* 6 | 7 | > *The software and materials distributed under this license are provided in the hope that it will be useful on an **'as is' basis, without warranties or conditions of any kind, either expressed or implied; without even the implied warranty of merchantability or fitness for a particular purpose**. See the Apache v2.0 license for more details.* 8 | 9 | > *You should have received a copy of the Apache v2.0 license along with the source code files (see [LICENSE](LICENSE) file). If not, see .* 10 | 11 | | ![Topologies](training_algorithms_topologies.png) | 12 | |:--:| 13 | | *Fig. 1 - (a) Backpropagation of error algorithm (BP). (b) Feedback alignment (FA) [Lillicrap et al., Nat. Comms., 2016]. (c) Direct feedback alignment (DFA) [Nokland, NIPS, 2016]. (d) Proposed direct random target projection (DRTP) algorithm.* | 14 | 15 | The provided source files contain the PyTorch-based code for training fully-connected and convolutional networks using the following algorithms, as summarized in Fig. 1: 16 | * the backpropagation of error algorithm (BP), 17 | * feedback alignment (FA) [Lillicrap et al., Nat. Comms., 2016], 18 | * direct feedback alignment (DFA) [Nokland, NIPS, 2016], 19 | * the **proposed direct random target projection (DRTP)** algorithm, which solves both the weight transport and the update locking problems. 20 | 21 | In order to reproduce our experiments in the associated paper (see below), the error-sign-based variant of DFA (sDFA) and shallow learning are also available. 22 | 23 | In case you decide to use the source code for academic or commercial use, we would appreciate if you let us know; **feedback is welcome**. Upon usage of the source code, **please cite the associated paper** (available [here](https://www.frontiersin.org/articles/10.3389/fnins.2021.629892/full)): 24 | 25 | > C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: Fixed Random Learning Signals Allow for Feedforward Training of Deep Neural Networks," *Frontiers in Neuroscience*, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 26 | 27 | Instructions on how to use the code are available in the [main.py](main.py) source file. 28 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "function.py" - Functional definition of the TrainingHook class (module.py). 23 | 24 | Project: DRTP - Direct Random Target Projection 25 | 26 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 27 | 28 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 29 | Fixed random learning signals allow for feedforward training of deep neural networks," 30 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 31 | 32 | ------------------------------------------------------------------------------ 33 | """ 34 | 35 | 36 | import torch 37 | from torch.autograd import Function 38 | from numpy import prod 39 | 40 | class HookFunction(Function): 41 | @staticmethod 42 | def forward(ctx, input, labels, y, fixed_fb_weights, train_mode): 43 | if train_mode in ["DFA", "sDFA", "DRTP"]: 44 | ctx.save_for_backward(input, labels, y, fixed_fb_weights) 45 | ctx.in1 = train_mode 46 | return input 47 | 48 | @staticmethod 49 | def backward(ctx, grad_output): 50 | train_mode = ctx.in1 51 | if train_mode == "BP": 52 | return grad_output, None, None, None, None 53 | elif train_mode == "shallow": 54 | grad_output.data.zero_() 55 | return grad_output, None, None, None, None 56 | 57 | input, labels, y, fixed_fb_weights = ctx.saved_variables 58 | if train_mode == "DFA": 59 | grad_output_est = (y-labels).mm(fixed_fb_weights.view(-1,prod(fixed_fb_weights.shape[1:]))).view(grad_output.shape) 60 | elif train_mode == "sDFA": 61 | grad_output_est = torch.sign(y-labels).mm(fixed_fb_weights.view(-1,prod(fixed_fb_weights.shape[1:]))).view(grad_output.shape) 62 | elif train_mode == "DRTP": 63 | grad_output_est = labels.mm(fixed_fb_weights.view(-1,prod(fixed_fb_weights.shape[1:]))).view(grad_output.shape) 64 | else: 65 | raise NameError("=== ERROR: training mode " + str(train_mode) + " not supported") 66 | 67 | return grad_output_est, None, None, None, None 68 | 69 | trainingHook = HookFunction.apply 70 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "main.py" - Main file for training fully-connected and convolutional networks using backpropagation (BP), 23 | feedback alignment (FA) [Lillicrap, Nat. Comms, 2016], direct feedback alignment (DFA) [Nokland, NIPS, 2016], 24 | and the proposed direct random target projection (DRTP). 25 | Example: use the following command to reach ~70% accuracy on the test set of CIFAR-10 using DRTP: 26 | python main.py --dataset CIFAR10aug --train-mode DRTP --epochs 200 --freeze-conv-layers 27 | --dropout 0.05 --topology CONV_64_3_1_1_CONV_256_3_1_1_FC_2000_FC_10 28 | --loss CE --output-act none --lr 5e-4 29 | 30 | Project: DRTP - Direct Random Target Projection 31 | 32 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 33 | 34 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 35 | Fixed random learning signals allow for feedforward training of deep neural networks," 36 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 37 | 38 | ------------------------------------------------------------------------------ 39 | """ 40 | 41 | 42 | import argparse 43 | import train 44 | import setup 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser(description='Training fully-connected and convolutional networks using backpropagation (BP), feedback alignment (FA), direct feedback alignment (DFA), and direct random target projection (DRTP)') 48 | # General 49 | parser.add_argument('--cpu', action='store_true', default=False, help='Disable CUDA and run on CPU.') 50 | # Dataset 51 | parser.add_argument('--dataset', type=str, choices = ['regression_synth', 'classification_synth', 'MNIST', 'CIFAR10', 'CIFAR10aug'], default='MNIST', help='Choice of the dataset: synthetic regression (regression_synth), synthetic classification (classification_synth), MNIST (MNIST), CIFAR-10 (CIFAR10), CIFAR-10 with data augmentation (CIFAR10aug). Synthetic datasets must have been generated previously with synth_dataset_gen.py. Default: MNIST.') 52 | # Training 53 | parser.add_argument('--train-mode', choices = ['BP','FA','DFA','DRTP','sDFA','shallow'], default='DRTP', help='Choice of the training algorithm - backpropagation (BP), feedback alignment (FA), direct feedback alignment (DFA), direct random target propagation (DRTP), error-sign-based DFA (sDFA), shallow learning with all layers freezed but the last one that is BP-trained (shallow). Default: DRTP.') 54 | parser.add_argument('--optimizer', choices = ['SGD', 'NAG', 'Adam', 'RMSprop'], default='NAG', help='Choice of the optimizer - stochastic gradient descent with 0.9 momentum (SGD), SGD with 0.9 momentum and Nesterov-accelerated gradients (NAG), Adam (Adam), and RMSprop (RMSprop). Default: NAG.') 55 | parser.add_argument('--loss', choices = ['MSE', 'BCE', 'CE'], default='BCE', help='Choice of loss function - mean squared error (MSE), binary cross entropy (BCE), cross entropy (CE, which already contains a logsoftmax activation function). Default: BCE.') 56 | parser.add_argument('--freeze-conv-layers', action='store_true', default=False, help='Disable training of convolutional layers and keeps the weights at their initialized values.') 57 | parser.add_argument('--fc-zero-init', action='store_true', default=False, help='Initializes fully-connected weights to zero instead of the default He uniform initialization.') 58 | parser.add_argument('--dropout', type=float, default=0, help='Dropout probability (applied only to fully-connected layers). Default: 0.') 59 | parser.add_argument('--trials', type=int, default=1, help='Number of training trials Default: 1.') 60 | parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs Default: 100.') 61 | parser.add_argument('--batch-size', type=int, default=100, help='Input batch size for training. Default: 100.') 62 | parser.add_argument('--test-batch-size', type=int, default=1000, help='Input batch size for testing Default: 1000.') 63 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate. Default: 1e-4.') 64 | # Network 65 | parser.add_argument('--topology', type=str, default='CONV_32_5_1_2_FC_1000_FC_10', help='Choice of network topology. Format for convolutional layers: CONV_{output channels}_{kernel size}_{stride}_{padding}. Format for fully-connected layers: FC_{output units}.') 66 | parser.add_argument('--conv-act', type=str, choices = {'tanh', 'sigmoid', 'relu'}, default='tanh', help='Type of activation for the convolutional layers - Tanh (tanh), Sigmoid (sigmoid), ReLU (relu). Default: tanh.') 67 | parser.add_argument('--hidden-act', type=str, choices = {'tanh', 'sigmoid', 'relu'}, default='tanh', help='Type of activation for the fully-connected hidden layers - Tanh (tanh), Sigmoid (sigmoid), ReLU (relu). Default: tanh.') 68 | parser.add_argument('--output-act', type=str, choices = {'sigmoid', 'tanh', 'none'}, default='sigmoid', help='Type of activation for the network output layer - Sigmoid (sigmoid), Tanh (tanh), none (none). Default: sigmoid.') 69 | 70 | args = parser.parse_args() 71 | 72 | (device, train_loader, traintest_loader, test_loader) = setup.setup(args) 73 | train.train(args, device, train_loader, traintest_loader, test_loader) 74 | 75 | if __name__ == '__main__': 76 | main() -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "models.py" - Construction of arbitrary network topologies. 23 | 24 | Project: DRTP - Direct Random Target Projection 25 | 26 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 27 | 28 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 29 | Fixed random learning signals allow for feedforward training of deep neural networks," 30 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 31 | 32 | ------------------------------------------------------------------------------ 33 | """ 34 | 35 | 36 | 37 | import torch 38 | import torch.nn as nn 39 | import torch.nn.functional as F 40 | import function 41 | from module import FA_wrapper, TrainingHook 42 | 43 | class NetworkBuilder(nn.Module): 44 | """ 45 | This version of the network builder assumes stride-2 pooling operations. 46 | """ 47 | def __init__(self, topology, input_size, input_channels, label_features, train_batch_size, train_mode, dropout, conv_act, hidden_act, output_act, fc_zero_init, loss, device): 48 | super(NetworkBuilder, self).__init__() 49 | self.apply_softmax = (output_act == "none") and (loss == "CE") 50 | 51 | self.layers = nn.ModuleList() 52 | if (train_mode == "DFA") or (train_mode == "sDFA"): 53 | self.y = torch.zeros(train_batch_size, label_features, device=device) 54 | self.y.requires_grad = False 55 | else: 56 | self.y = None 57 | 58 | topology = topology.split('_') 59 | topology_layers = [] 60 | num_layers = 0 61 | for elem in topology: 62 | if not any(i.isdigit() for i in elem): 63 | num_layers += 1 64 | topology_layers.append([]) 65 | topology_layers[num_layers-1].append(elem) 66 | for i in range(num_layers): 67 | layer = topology_layers[i] 68 | try: 69 | if layer[0] == "CONV": 70 | in_channels = input_channels if (i==0) else out_channels 71 | out_channels = int(layer[1]) 72 | input_dim = input_size if (i==0) else int(output_dim/2) #/2 accounts for pooling operation of the previous convolutional layer 73 | output_dim = int((input_dim - int(layer[2]) + 2*int(layer[4]))/int(layer[3]))+1 74 | self.layers.append(CNN_block( 75 | in_channels=in_channels, 76 | out_channels=int(layer[1]), 77 | kernel_size=int(layer[2]), 78 | stride=int(layer[3]), 79 | padding=int(layer[4]), 80 | bias=True, 81 | activation=conv_act, 82 | dim_hook=[label_features,out_channels,output_dim,output_dim], 83 | label_features=label_features, 84 | train_mode=train_mode 85 | )) 86 | elif layer[0] == "FC": 87 | if (i==0): 88 | input_dim = pow(input_size,2)*input_channels 89 | self.conv_to_fc = 0 90 | elif topology_layers[i-1][0]=="CONV": 91 | input_dim = pow(int(output_dim/2),2)*int(topology_layers[i-1][1]) #/2 accounts for pooling operation of the previous convolutional layer 92 | self.conv_to_fc = i 93 | else: 94 | input_dim = output_dim 95 | output_dim = int(layer[1]) 96 | output_layer = (i == (num_layers-1)) 97 | self.layers.append(FC_block( 98 | in_features=input_dim, 99 | out_features=output_dim, 100 | bias=True, 101 | activation=output_act if output_layer else hidden_act, 102 | dropout=dropout, 103 | dim_hook=None if output_layer else [label_features,output_dim], 104 | label_features=label_features, 105 | fc_zero_init=fc_zero_init, 106 | train_mode=("BP" if (train_mode != "FA") else "FA") if output_layer else train_mode 107 | )) 108 | else: 109 | raise NameError("=== ERROR: layer construct " + str(elem) + " not supported") 110 | except ValueError as e: 111 | raise ValueError("=== ERROR: unsupported layer parameter format: " + str(e)) 112 | 113 | def forward(self, x, labels): 114 | for i in range(len(self.layers)): 115 | if i == self.conv_to_fc: 116 | x = x.reshape(x.size(0), -1) 117 | x = self.layers[i](x, labels, self.y) 118 | 119 | if x.requires_grad and (self.y is not None): 120 | if self.apply_softmax: 121 | self.y.data.copy_(F.softmax(input=x.data, dim=1)) # in-place update, only happens with (s)DFA 122 | else: 123 | self.y.data.copy_(x.data) # in-place update, only happens with (s)DFA 124 | 125 | return x 126 | 127 | 128 | class FC_block(nn.Module): 129 | def __init__(self, in_features, out_features, bias, activation, dropout, dim_hook, label_features, fc_zero_init, train_mode): 130 | super(FC_block, self).__init__() 131 | 132 | self.dropout = dropout 133 | self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=bias) 134 | if fc_zero_init: 135 | torch.zero_(self.fc.weight.data) 136 | if train_mode == 'FA': 137 | self.fc = FA_wrapper(module=self.fc, layer_type='fc', dim=self.fc.weight.shape) 138 | self.act = Activation(activation) 139 | if dropout != 0: 140 | self.drop = nn.Dropout(p=dropout) 141 | self.hook = TrainingHook(label_features=label_features, dim_hook=dim_hook, train_mode=train_mode) 142 | 143 | def forward(self, x, labels, y): 144 | if self.dropout != 0: 145 | x = self.drop(x) 146 | x = self.fc(x) 147 | x = self.act(x) 148 | x = self.hook(x, labels, y) 149 | return x 150 | 151 | 152 | class CNN_block(nn.Module): 153 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias, activation, dim_hook, label_features, train_mode): 154 | super(CNN_block, self).__init__() 155 | 156 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) 157 | if train_mode == 'FA': 158 | self.conv = FA_wrapper(module=self.conv, layer_type='conv', dim=self.conv.weight.shape, stride=stride, padding=padding) 159 | self.act = Activation(activation) 160 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 161 | self.hook = TrainingHook(label_features=label_features, dim_hook=dim_hook, train_mode=train_mode) 162 | 163 | def forward(self, x, labels, y): 164 | x = self.conv(x) 165 | x = self.act(x) 166 | x = self.hook(x, labels, y) 167 | x = self.pool(x) 168 | return x 169 | 170 | 171 | class Activation(nn.Module): 172 | def __init__(self, activation): 173 | super(Activation, self).__init__() 174 | 175 | if activation == "tanh": 176 | self.act = nn.Tanh() 177 | elif activation == "sigmoid": 178 | self.act = nn.Sigmoid() 179 | elif activation == "relu": 180 | self.act = nn.ReLU() 181 | elif activation == "none": 182 | self.act = None 183 | else: 184 | raise NameError("=== ERROR: activation " + str(activation) + " not supported") 185 | 186 | def forward(self, x): 187 | if self.act == None: 188 | return x 189 | else: 190 | return self.act(x) -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "module.py" - Definition of hooks that allow performing FA, DFA, and DRTP training. 23 | 24 | Project: DRTP - Direct Random Target Projection 25 | 26 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 27 | 28 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 29 | Fixed random learning signals allow for feedforward training of deep neural networks," 30 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 31 | 32 | ------------------------------------------------------------------------------ 33 | """ 34 | 35 | import torch 36 | import torch.nn as nn 37 | from function import trainingHook 38 | 39 | 40 | class FA_wrapper(nn.Module): 41 | def __init__(self, module, layer_type, dim, stride=None, padding=None): 42 | super(FA_wrapper, self).__init__() 43 | self.module = module 44 | self.layer_type = layer_type 45 | self.stride = stride 46 | self.padding = padding 47 | self.output_grad = None 48 | self.x_shape = None 49 | 50 | # FA feedback weights definition 51 | self.fixed_fb_weights = nn.Parameter(torch.Tensor(torch.Size(dim))) 52 | self.reset_weights() 53 | 54 | def forward(self, x): 55 | if x.requires_grad: 56 | x.register_hook(self.FA_hook_pre) 57 | self.x_shape = x.shape 58 | x = self.module(x) 59 | x.register_hook(self.FA_hook_post) 60 | return x 61 | else: 62 | return self.module(x) 63 | 64 | def reset_weights(self): 65 | torch.nn.init.kaiming_uniform_(self.fixed_fb_weights) 66 | self.fixed_fb_weights.requires_grad = False 67 | 68 | def FA_hook_pre(self, grad): 69 | if self.output_grad is not None: 70 | if (self.layer_type == "fc"): 71 | return self.output_grad.mm(self.fixed_fb_weights) 72 | elif (self.layer_type == "conv"): 73 | return torch.nn.grad.conv2d_input(self.x_shape, self.fixed_fb_weights, self.output_grad, self.stride, self.padding) 74 | else: 75 | raise NameError("=== ERROR: layer type " + str(self.layer_type) + " is not supported in FA wrapper") 76 | else: 77 | return grad 78 | 79 | def FA_hook_post(self, grad): 80 | self.output_grad = grad 81 | return grad 82 | 83 | 84 | class TrainingHook(nn.Module): 85 | def __init__(self, label_features, dim_hook, train_mode): 86 | super(TrainingHook, self).__init__() 87 | self.train_mode = train_mode 88 | assert train_mode in ["BP", "FA", "DFA", "DRTP", "sDFA", "shallow"], "=== ERROR: Unsupported hook training mode " + train_mode + "." 89 | 90 | # Feedback weights definition (FA feedback weights are handled in the FA_wrapper class) 91 | if self.train_mode in ["DFA", "DRTP", "sDFA"]: 92 | self.fixed_fb_weights = nn.Parameter(torch.Tensor(torch.Size(dim_hook))) 93 | self.reset_weights() 94 | else: 95 | self.fixed_fb_weights = None 96 | 97 | def reset_weights(self): 98 | torch.nn.init.kaiming_uniform_(self.fixed_fb_weights) 99 | self.fixed_fb_weights.requires_grad = False 100 | 101 | def forward(self, input, labels, y): 102 | return trainingHook(input, labels, y, self.fixed_fb_weights, self.train_mode if (self.train_mode != "FA") else "BP") #FA is handled in FA_wrapper, not in TrainingHook 103 | 104 | def __repr__(self): 105 | return self.__class__.__name__ + ' (' + self.train_mode + ')' 106 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "setup.py" - Setup configuration and dataset loading. 23 | 24 | Project: DRTP - Direct Random Target Projection 25 | 26 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 27 | 28 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 29 | Fixed random learning signals allow for feedforward training of deep neural networks," 30 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 31 | 32 | ------------------------------------------------------------------------------ 33 | """ 34 | 35 | 36 | import torch 37 | import torchvision 38 | from torchvision import transforms,datasets 39 | import numpy as np 40 | import os 41 | import sys 42 | import subprocess 43 | 44 | 45 | class SynthDataset(torch.utils.data.Dataset): 46 | 47 | def __init__(self, select, type): 48 | self.dataset, self.input_size, self.input_channels, self.label_features = torch.load( './DATASETS/'+select+'/'+type+'.pt') 49 | 50 | def __len__(self): 51 | return len(self.dataset[1]) 52 | 53 | def __getitem__(self, index): 54 | return self.dataset[0][index], self.dataset[1][index] 55 | 56 | def setup(args): 57 | args.cuda = not args.cpu and torch.cuda.is_available() 58 | if args.cuda: 59 | print("=== The available CUDA GPU will be used for computations.") 60 | memory_load = get_gpu_memory_usage() 61 | cuda_device = np.argmin(memory_load).item() 62 | torch.cuda.set_device(cuda_device) 63 | device = torch.cuda.current_device() 64 | else: 65 | device = torch.device('cpu') 66 | 67 | kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {} 68 | if args.dataset == "regression_synth": 69 | print("=== Loading the synthetic regression dataset...") 70 | (train_loader, traintest_loader, test_loader) = load_dataset_regression_synth(args, kwargs) 71 | elif args.dataset == "classification_synth": 72 | print("=== Loading the synthetic classification dataset...") 73 | (train_loader, traintest_loader, test_loader) = load_dataset_classification_synth(args, kwargs) 74 | elif args.dataset == "MNIST": 75 | print("=== Loading the MNIST dataset...") 76 | (train_loader, traintest_loader, test_loader) = load_dataset_mnist(args, kwargs) 77 | elif args.dataset == "CIFAR10": 78 | print("=== Loading the CIFAR-10 dataset...") 79 | (train_loader, traintest_loader, test_loader) = load_dataset_cifar10(args, kwargs) 80 | elif args.dataset == "CIFAR10aug": 81 | print("=== Loading and augmenting the CIFAR-10 dataset...") 82 | (train_loader, traintest_loader, test_loader) = load_dataset_cifar10_augmented(args, kwargs) 83 | else: 84 | print("=== ERROR - Unsupported dataset ===") 85 | sys.exit(1) 86 | args.regression = (args.dataset == "regression_synth") 87 | 88 | return (device, train_loader, traintest_loader, test_loader) 89 | 90 | def get_gpu_memory_usage(): 91 | if sys.platform == "win32": 92 | curr_dir = os.getcwd() 93 | nvsmi_dir = r"C:\Program Files\NVIDIA Corporation\NVSMI" 94 | os.chdir(nvsmi_dir) 95 | result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used','--format=csv,nounits,noheader']) 96 | os.chdir(curr_dir) 97 | else: 98 | result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used','--format=csv,nounits,noheader']) 99 | gpu_memory = [int(x) for x in result.decode('utf-8').strip().split('\n')] 100 | return gpu_memory 101 | 102 | def load_dataset_regression_synth(args, kwargs): 103 | 104 | trainset = SynthDataset("regression","train") 105 | testset = SynthDataset("regression", "test") 106 | 107 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True , **kwargs) 108 | traintest_loader = torch.utils.data.DataLoader(trainset, batch_size=args.test_batch_size, shuffle=False, **kwargs) 109 | test_loader = torch.utils.data.DataLoader(testset , batch_size=args.test_batch_size, shuffle=False, **kwargs) 110 | 111 | args.input_size = trainset.input_size 112 | args.input_channels = trainset.input_channels 113 | args.label_features = trainset.label_features 114 | 115 | return (train_loader, traintest_loader, test_loader) 116 | 117 | def load_dataset_classification_synth(args, kwargs): 118 | 119 | trainset = SynthDataset("classification","train") 120 | testset = SynthDataset("classification", "test") 121 | 122 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True , **kwargs) 123 | traintest_loader = torch.utils.data.DataLoader(trainset, batch_size=args.test_batch_size, shuffle=False, **kwargs) 124 | test_loader = torch.utils.data.DataLoader(testset , batch_size=args.test_batch_size, shuffle=False, **kwargs) 125 | 126 | args.input_size = trainset.input_size 127 | args.input_channels = trainset.input_channels 128 | args.label_features = trainset.label_features 129 | 130 | return (train_loader, traintest_loader, test_loader) 131 | 132 | def load_dataset_mnist(args, kwargs): 133 | train_loader = torch.utils.data.DataLoader(datasets.MNIST('./DATASETS', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.0,),(1.0,))])), batch_size=args.batch_size, shuffle=True , **kwargs) 134 | traintest_loader = torch.utils.data.DataLoader(datasets.MNIST('./DATASETS', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.0,),(1.0,))])), batch_size=args.test_batch_size, shuffle=False, **kwargs) 135 | test_loader = torch.utils.data.DataLoader(datasets.MNIST('./DATASETS', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.0,),(1.0,))])), batch_size=args.test_batch_size, shuffle=False, **kwargs) 136 | 137 | args.input_size = 28 138 | args.input_channels = 1 139 | args.label_features = 10 140 | 141 | return (train_loader, traintest_loader, test_loader) 142 | 143 | def load_dataset_cifar10(args, kwargs): 144 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 145 | transform_cifar10 = transforms.Compose([transforms.ToTensor(),normalize,]) 146 | 147 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./DATASETS', train=True, download=True, transform=transform_cifar10), batch_size=args.batch_size, shuffle=True , **kwargs) 148 | traintest_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./DATASETS', train=True, download=True, transform=transform_cifar10), batch_size=args.test_batch_size, shuffle=False, **kwargs) 149 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./DATASETS', train=False, download=True, transform=transform_cifar10), batch_size=args.test_batch_size, shuffle=False, **kwargs) 150 | 151 | args.input_size = 32 152 | args.input_channels = 3 153 | args.label_features = 10 154 | 155 | return (train_loader, traintest_loader, test_loader) 156 | 157 | def load_dataset_cifar10_augmented(args, kwargs): 158 | #Source: https://zhenye-na.github.io/2018/09/28/pytorch-cnn-cifar10.html 159 | 160 | transform_train = transforms.Compose([ 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 164 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]), 165 | ]) 166 | 167 | # Normalize the test set same as training set without augmentation 168 | transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], std=[x/255.0 for x in [63.0, 62.1, 66.7]]),]) 169 | 170 | trainset = torchvision.datasets.CIFAR10('./DATASETS', train=True, download=True, transform=transform_train) 171 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True) 172 | 173 | traintestset = torchvision.datasets.CIFAR10('./DATASETS', train=True, download=True, transform=transform_test) 174 | traintest_loader = torch.utils.data.DataLoader(traintestset, batch_size=args.test_batch_size, shuffle=False) 175 | 176 | testset = torchvision.datasets.CIFAR10('./DATASETS', train=False, download=True, transform=transform_test) 177 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False) 178 | 179 | args.input_size = 32 180 | args.input_channels = 3 181 | args.label_features = 10 182 | 183 | return (train_loader, traintest_loader, test_loader) -------------------------------------------------------------------------------- /synth_dataset_gen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "synth_dataset_gen.py" - Generation of synthetic regression and classification datasets. 23 | Launch with command 'python synth_dataset_gen.py' 24 | 25 | Project: DRTP - Direct Random Target Projection 26 | 27 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 28 | 29 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 30 | Fixed random learning signals allow for feedforward training of deep neural networks," 31 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 32 | 33 | ------------------------------------------------------------------------------ 34 | """ 35 | 36 | 37 | import torch 38 | import numpy as np 39 | from sklearn.datasets import make_classification 40 | import os 41 | import math 42 | 43 | def gen_regression(n_train, n_test, n_classes, input_size_sqrt): 44 | 45 | input_dim = pow(input_size_sqrt,2) 46 | for dataset in ["train","test"]: 47 | n = n_train if (dataset=="train") else n_test 48 | norms = math.pi*(torch.rand(n)*2-1) 49 | X = torch.normal(mean=norms.unsqueeze(1).repeat(1,input_dim),std=1) 50 | t = torch.zeros(n, n_classes) 51 | for i in range(n_classes): 52 | t[:,i] = torch.cos(torch.mean(X,dim=1)+math.pi*(i-4.5)/9) 53 | if dataset=="train": 54 | X_train, t_train = X, t 55 | else: 56 | X_test, t_test = X, t 57 | 58 | if not os.path.exists('./DATASETS/regression'): 59 | os.makedirs('./DATASETS/regression') 60 | torch.save(((X_train,t_train), input_size_sqrt, 1, n_classes), "./DATASETS/regression/train.pt") 61 | torch.save(((X_test ,t_test ), input_size_sqrt, 1, n_classes), "./DATASETS/regression/test.pt") 62 | 63 | 64 | def gen_classification(n_train, n_test, n_classes, n_pix_sqrt, n_inf, n_clusters_per_class=5, class_sep=4.5, random_state=0): 65 | 66 | n_samples = n_train + n_test 67 | input_dim = pow(n_pix_sqrt,2) 68 | X, y = make_classification(n_samples=n_samples, n_features=input_dim, n_informative=n_inf, n_redundant=0, n_repeated=0, n_classes=n_classes, n_clusters_per_class=n_clusters_per_class, scale=np.ones(shape=(input_dim,)), class_sep=class_sep, shuffle=True, random_state=random_state) 69 | X = torch.Tensor(X) 70 | y = torch.Tensor(y).long() 71 | X_train, y_train = X[0:n_train, :], y[0:n_train] 72 | X_test, y_test = X[n_train:, :], y[n_train:] 73 | 74 | if not os.path.exists('./DATASETS/classification'): 75 | os.makedirs('./DATASETS/classification') 76 | torch.save(((X_train,y_train), n_pix_sqrt, 1, n_classes), "./DATASETS/classification/train.pt") 77 | torch.save(((X_test ,y_test ), n_pix_sqrt, 1, n_classes), "./DATASETS/classification/test.pt") 78 | 79 | 80 | if __name__ == '__main__': 81 | gen_regression(n_train=5000, n_test=1000, n_classes=10, input_size_sqrt=16) 82 | gen_classification(n_train=25000, n_test=5000, n_classes=10, n_pix_sqrt=16, n_inf=128) -------------------------------------------------------------------------------- /talks/LightOnAIMeetUp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChFrenkel/DirectRandomTargetProjection/dfe02b81b33864f38cea9b6ecfdeef2c984e71d9/talks/LightOnAIMeetUp.pdf -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | ------------------------------------------------------------------------------ 5 | 6 | Copyright (C) 2019 Université catholique de Louvain (UCLouvain), Belgium. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | ------------------------------------------------------------------------------ 21 | 22 | "train.py" - Initializing the network, optimizer and loss for training and testing. 23 | 24 | Project: DRTP - Direct Random Target Projection 25 | 26 | Authors: C. Frenkel and M. Lefebvre, Université catholique de Louvain (UCLouvain), 09/2019 27 | 28 | Cite/paper: C. Frenkel, M. Lefebvre and D. Bol, "Learning without feedback: 29 | Fixed random learning signals allow for feedforward training of deep neural networks," 30 | Frontiers in Neuroscience, vol. 15, no. 629892, 2021. doi: 10.3389/fnins.2021.629892 31 | 32 | ------------------------------------------------------------------------------ 33 | """ 34 | 35 | import torch.nn as nn 36 | import torch.nn.functional as F 37 | import torch.optim as optim 38 | import torch.utils.data 39 | import models 40 | from tqdm import tqdm 41 | 42 | def train(args, device, train_loader, traintest_loader, test_loader): 43 | torch.manual_seed(42) 44 | 45 | for trial in range(1,args.trials+1): 46 | # Network topology 47 | model = models.NetworkBuilder(args.topology, input_size=args.input_size, input_channels=args.input_channels, label_features=args.label_features, train_batch_size=args.batch_size, train_mode=args.train_mode, dropout=args.dropout, conv_act=args.conv_act, hidden_act=args.hidden_act, output_act=args.output_act, fc_zero_init=args.fc_zero_init, loss=args.loss, device=device) 48 | 49 | if args.cuda: 50 | model.cuda() 51 | 52 | if (args.trials > 1): 53 | print('\nIn trial {} of {}'.format(trial,args.trials)) 54 | if (trial == 1): 55 | print("=== Model ===" ) 56 | print(model) 57 | 58 | # Optimizer 59 | if args.optimizer == 'SGD': 60 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=False) 61 | elif args.optimizer == 'NAG': 62 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True) 63 | elif args.optimizer == 'Adam': 64 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 65 | elif args.optimizer == 'RMSprop': 66 | optimizer = optim.RMSprop(model.parameters(), lr=args.lr) 67 | else: 68 | raise NameError("=== ERROR: optimizer " + str(args.optimizer) + " not supported") 69 | 70 | # Loss function 71 | if args.loss == 'MSE': 72 | loss = (F.mse_loss, (lambda l : l)) 73 | elif args.loss == 'BCE': 74 | loss = (F.binary_cross_entropy, (lambda l : l)) 75 | elif args.loss == 'CE': 76 | loss = (F.cross_entropy, (lambda l : torch.max(l, 1)[1])) 77 | else: 78 | raise NameError("=== ERROR: loss " + str(args.loss) + " not supported") 79 | 80 | print("\n\n=== Starting model training with %d epochs:\n" % (args.epochs,)) 81 | for epoch in range(1, args.epochs + 1): 82 | # Training 83 | train_epoch(args, model, device, train_loader, optimizer, loss) 84 | 85 | # Compute accuracy on training and testing set 86 | print("\nSummary of epoch %d:" % (epoch)) 87 | test_epoch(args, model, device, traintest_loader, loss, 'Train') 88 | test_epoch(args, model, device, test_loader, loss, 'Test') 89 | 90 | 91 | def train_epoch(args, model, device, train_loader, optimizer, loss): 92 | model.train() 93 | 94 | if args.freeze_conv_layers: 95 | for i in range(model.conv_to_fc): 96 | for param in model.layers[i].conv.parameters(): 97 | param.requires_grad = False 98 | 99 | for batch_idx, (data, label) in enumerate(tqdm(train_loader)): 100 | data, label = data.to(device), label.to(device) 101 | if args.regression: 102 | targets = label 103 | else: 104 | targets = torch.zeros(label.shape[0], args.label_features, device=device).scatter_(1, label.unsqueeze(1), 1.0) 105 | optimizer.zero_grad() 106 | output = model(data, targets) 107 | loss_val = loss[0](output, loss[1](targets)) 108 | loss_val.backward() 109 | optimizer.step() 110 | 111 | 112 | def test_epoch(args, model, device, test_loader, loss, phase): 113 | model.eval() 114 | 115 | test_loss, correct = 0, 0 116 | len_dataset = len(test_loader.dataset) 117 | 118 | with torch.no_grad(): 119 | for data, label in test_loader: 120 | data, label = data.to(device), label.to(device) 121 | if args.regression: 122 | targets = label 123 | else: 124 | targets = torch.zeros(label.shape[0], args.label_features, device=device).scatter_(1, label.unsqueeze(1), 1.0) 125 | output = model(data, None) 126 | test_loss += loss[0](output, loss[1](targets), reduction='sum').item() 127 | pred = output.max(1, keepdim=True)[1] 128 | if not args.regression: 129 | correct += pred.eq(label.view_as(pred)).sum().item() 130 | 131 | loss = test_loss / len_dataset 132 | if not args.regression: 133 | acc = 100. * correct / len_dataset 134 | print("\t[%5sing set] Loss: %6f, Accuracy: %6.2f%%" % (phase, loss, acc)) 135 | else: 136 | print("\t[%5sing set] Loss: %6f" % (phase, loss)) -------------------------------------------------------------------------------- /training_algorithms_topologies.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChFrenkel/DirectRandomTargetProjection/dfe02b81b33864f38cea9b6ecfdeef2c984e71d9/training_algorithms_topologies.png --------------------------------------------------------------------------------