├── LICENSE ├── README.md ├── combine_nets.py ├── data ├── .DS_Store ├── cifar10 │ └── .DS_Store └── mnist │ └── .DS_Store ├── datasets.py ├── experiment.py ├── logs └── .DS_Store ├── matching ├── .DS_Store ├── __init__.py ├── pfnm.py └── pfnm_communication.py └── model.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Probabilistic Federated Neural Matching 2 | 3 | 4 | This is the code accompanying the ICML 2019 paper "Bayesian Nonparametric Federated Learning of Neural Networks" 5 | Paper link: [http://proceedings.mlr.press/v97/yurochkin19a.html] 6 | 7 | #### Requirements to run the code: 8 | --- 9 | 10 | 1. Python 3.6 11 | 2. PyTorch 0.4 12 | 3. Scikit-learn 13 | 4. Matplotlib 14 | 5. 
NumPy 15 | 16 | 17 | #### Important source files: 18 | --- 19 | 20 | 1. `experiment.py`: Main entry point to the code; used to run all experiments 21 | 2. `matching/pfnm.py`: Contains the PFNM matching code for single-communication federated learning 22 | 3. `matching/pfnm_communication.py`: Contains the PFNM matching code for federated learning with multiple communication rounds 23 | 24 | 25 | #### Sample Commands: 26 | --- 27 | 28 | 1. MNIST Heterogeneous 10 batches 29 | 30 | `python experiment.py --logdir "logs/mnist_test" --dataset "mnist" --datadir "data/mnist/" --net_config "784, 100, 10" --n_nets 10 --partition "hetero-dir" --experiment "u-ensemble,pdm,pdm_iterative" --lr 0.01 --epochs 10 --reg 1e-6 --communication_rounds 5 --lr_decay 0.99 --iter_epochs 5` 31 | 32 | 2. CIFAR-10 Heterogeneous 10 batches 33 | 34 | `python experiment.py --logdir "logs/cifar10_test" --dataset "cifar10" --datadir "data/cifar10/" --net_config "3072, 100, 10" --n_nets 10 --partition "hetero-dir" --experiment "u-ensemble,pdm,pdm_iterative" --lr 0.001 --epochs 10 --reg 1e-5 --communication_rounds 5 --lr_decay 0.99 --iter_epochs 5` 35 | 36 | 37 | #### Important arguments: 38 | --- 39 | 40 | 41 | The following arguments to `experiment.py` control the important parameters of the experiment: 42 | 43 | 1. `net_config`: Defines the local network architecture as a comma-separated list of layer sizes. Ex: "784, 100, 100, 10" defines a network with a 784-dimensional input, two hidden layers of 100 neurons each, and 10 output classes. 44 | 2. `n_nets`: Number of local networks. This is denoted by "J" in the paper. 45 | 3. `partition`: Kind of data partition. Values: homo, hetero-dir 46 | 4. `experiment`: Defines which experiments will be executed. Values: u-ensemble (Uniform ensemble), pdm (PFNM matching), pdm_iterative (PFNM with extra communications) 47 | 5. `communication_rounds`: Number of communication rounds between the local learners and the master network when running PFNM with multiple communications. 48 | 49 | 50 | #### Output: 51 | --- 52 | 53 | Some of the output is printed to the terminal; however, the majority of the information is logged to a log file in the specified log directory. 54 | 55 | 56 | ### Citing PFNM 57 | --- 58 | 59 | ``` 60 | @InProceedings{pmlr-v97-yurochkin19a, 61 | title = {{B}ayesian Nonparametric Federated Learning of Neural Networks}, 62 | author = {Yurochkin, Mikhail and Agarwal, Mayank and Ghosh, Soumya and Greenewald, Kristjan and Hoang, Nghia and Khazaeni, Yasaman}, 63 | booktitle = {Proceedings of the 36th International Conference on Machine Learning}, 64 | pages = {7252--7261}, 65 | year = {2019}, 66 | editor = {Chaudhuri, Kamalika and Salakhutdinov, Ruslan}, 67 | volume = {97}, 68 | series = {Proceedings of Machine Learning Research}, 69 | address = {Long Beach, California, USA}, 70 | month = {09--15 Jun}, 71 | publisher = {PMLR}, 72 | pdf = {http://proceedings.mlr.press/v97/yurochkin19a/yurochkin19a.pdf}, 73 | url = {http://proceedings.mlr.press/v97/yurochkin19a.html}, 74 | abstract = {In federated learning problems, data is scattered across different servers and exchanging or pooling it is often impractical or prohibited. We develop a Bayesian nonparametric framework for federated learning with neural networks. Each data server is assumed to provide local neural network weights, which are modeled through our framework. We then develop an inference approach that allows us to synthesize a more expressive global network without additional supervision, data pooling and with as few as a single communication round.
We then demonstrate the efficacy of our approach on federated learning problems simulated from two popular image classification datasets.} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /combine_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from model import FcNet 5 | 6 | from matching.pfnm import layer_group_descent as pdm_multilayer_group_descent 7 | from matching.pfnm_communication import layer_group_descent as pdm_iterative_layer_group_descent 8 | from matching.pfnm_communication import build_init as pdm_build_init 9 | 10 | from itertools import product 11 | from sklearn.metrics import confusion_matrix 12 | 13 | def prepare_weight_matrix(n_classes, weights: dict): 14 | weights_list = {} 15 | 16 | for net_i, cls_cnts in weights.items(): 17 | cls = np.array(list(cls_cnts.keys())) 18 | cnts = np.array(list(cls_cnts.values())) 19 | weights_list[net_i] = np.array([0] * n_classes, dtype=np.float32) 20 | weights_list[net_i][cls] = cnts 21 | weights_list[net_i] = torch.from_numpy(weights_list[net_i]).view(1, -1) 22 | 23 | return weights_list 24 | 25 | 26 | def prepare_uniform_weights(n_classes, net_cnt, fill_val=1): 27 | weights_list = {} 28 | 29 | for net_i in range(net_cnt): 30 | temp = np.array([fill_val] * n_classes, dtype=np.float32) 31 | weights_list[net_i] = torch.from_numpy(temp).view(1, -1) 32 | 33 | return weights_list 34 | 35 | 36 | def prepare_sanity_weights(n_classes, net_cnt): 37 | return prepare_uniform_weights(n_classes, net_cnt, fill_val=0) 38 | 39 | 40 | def normalize_weights(weights): 41 | Z = np.array([]) 42 | eps = 1e-6 43 | weights_norm = {} 44 | 45 | for _, weight in weights.items(): 46 | if len(Z) == 0: 47 | Z = weight.data.numpy() 48 | else: 49 | Z = Z + weight.data.numpy() 50 | 51 | for mi, weight in weights.items(): 52 | weights_norm[mi] = weight / torch.from_numpy(Z + eps) 53 | 54 | return weights_norm 55 | 56 | 57 | def get_weighted_average_pred(models: list, weights: dict, x): 58 | out_weighted = None 59 | 60 | # Compute the predictions 61 | for model_i, model in enumerate(models): 62 | out = F.softmax(model(x), dim=-1) # (N, C) 63 | 64 | if out_weighted is None: 65 | out_weighted = (out * weights[model_i]) 66 | else: 67 | out_weighted += (out * weights[model_i]) 68 | 69 | return out_weighted 70 | 71 | 72 | def compute_ensemble_accuracy(models: list, dataloader, n_classes, train_cls_counts=None, uniform_weights=False, sanity_weights=False): 73 | 74 | correct, total = 0, 0 75 | true_labels_list, pred_labels_list = np.array([]), np.array([]) 76 | 77 | was_training = [False]*len(models) 78 | for i, model in enumerate(models): 79 | if model.training: 80 | was_training[i] = True 81 | model.eval() 82 | 83 | if uniform_weights is True: 84 | weights_list = prepare_uniform_weights(n_classes, len(models)) 85 | elif sanity_weights is True: 86 | weights_list = prepare_sanity_weights(n_classes, len(models)) 87 | else: 88 | weights_list = prepare_weight_matrix(n_classes, train_cls_counts) 89 | 90 | weights_norm = normalize_weights(weights_list) 91 | 92 | with torch.no_grad(): 93 | for batch_idx, (x, target) in enumerate(dataloader): 94 | target = target.long() 95 | out = get_weighted_average_pred(models, weights_norm, x) 96 | 97 | _, pred_label = torch.max(out, 1) 98 | 99 | total += x.data.size()[0] 100 | correct += (pred_label == target.data).sum().item() 101 | 102 | pred_labels_list = 
np.append(pred_labels_list, pred_label.numpy()) 103 | true_labels_list = np.append(true_labels_list, target.data.numpy()) 104 | 105 | print(correct, total) 106 | 107 | conf_matrix = confusion_matrix(true_labels_list, pred_labels_list) 108 | 109 | for i, model in enumerate(models): 110 | if was_training[i]: 111 | model.train() 112 | 113 | return correct / float(total), conf_matrix 114 | 115 | 116 | def pdm_prepare_weights(nets): 117 | weights = [] 118 | 119 | for net_i, net in enumerate(nets): 120 | layer_i = 0 121 | statedict = net.state_dict() 122 | net_weights = [] 123 | while True: 124 | 125 | if ('layers.%d.weight' % layer_i) not in statedict.keys(): 126 | break 127 | 128 | layer_weight = statedict['layers.%d.weight' % layer_i].numpy().T 129 | layer_bias = statedict['layers.%d.bias' % layer_i].numpy() 130 | 131 | net_weights.extend([layer_weight, layer_bias]) 132 | layer_i += 1 133 | 134 | weights.append(net_weights) 135 | 136 | return weights 137 | 138 | 139 | def pdm_prepare_freq(cls_freqs, n_classes): 140 | freqs = [] 141 | 142 | for net_i in sorted(cls_freqs.keys()): 143 | net_freqs = [0] * n_classes 144 | 145 | for cls_i in cls_freqs[net_i]: 146 | net_freqs[cls_i] = cls_freqs[net_i][cls_i] 147 | 148 | freqs.append(np.array(net_freqs)) 149 | 150 | return freqs 151 | 152 | def compute_pdm_net_accuracy(weights, train_dl, test_dl, n_classes): 153 | 154 | dims = [] 155 | dims.append(weights[0].shape[0]) 156 | 157 | for i in range(0, len(weights), 2): 158 | dims.append(weights[i].shape[1]) 159 | 160 | ip_dim = dims[0] 161 | op_dim = dims[-1] 162 | hidden_dims = dims[1:-1] 163 | 164 | pdm_net = FcNet(ip_dim, hidden_dims, op_dim) 165 | statedict = pdm_net.state_dict() 166 | 167 | # print(pdm_net) 168 | 169 | i = 0 170 | layer_i = 0 171 | while i < len(weights): 172 | weight = weights[i] 173 | i += 1 174 | bias = weights[i] 175 | i += 1 176 | 177 | statedict['layers.%d.weight' % layer_i] = torch.from_numpy(weight.T) 178 | statedict['layers.%d.bias' % layer_i] = torch.from_numpy(bias) 179 | layer_i += 1 180 | 181 | pdm_net.load_state_dict(statedict) 182 | 183 | train_acc, conf_matrix_train = compute_ensemble_accuracy([pdm_net], train_dl, n_classes, uniform_weights=True) 184 | test_acc, conf_matrix_test = compute_ensemble_accuracy([pdm_net], test_dl, n_classes, uniform_weights=True) 185 | 186 | return train_acc, test_acc, conf_matrix_train, conf_matrix_test 187 | 188 | 189 | def compute_pdm_matching_multilayer(models, train_dl, test_dl, cls_freqs, n_classes, sigma0=None, it=0, sigma=None, gamma=None): 190 | batch_weights = pdm_prepare_weights(models) 191 | batch_freqs = pdm_prepare_freq(cls_freqs, n_classes) 192 | res = {} 193 | best_test_acc, best_train_acc, best_weights, best_sigma, best_gamma, best_sigma0 = -1, -1, None, -1, -1, -1 194 | 195 | gammas = [1.0, 1e-3, 50.0] if gamma is None else [gamma] 196 | sigmas = [1.0, 0.1, 0.5] if sigma is None else [sigma] 197 | sigma0s = [1.0, 10.0] if sigma0 is None else [sigma0] 198 | 199 | for gamma, sigma, sigma0 in product(gammas, sigmas, sigma0s): 200 | print("Gamma: ", gamma, "Sigma: ", sigma, "Sigma0: ", sigma0) 201 | 202 | hungarian_weights = pdm_multilayer_group_descent( 203 | batch_weights, sigma0_layers=sigma0, sigma_layers=sigma, batch_frequencies=batch_freqs, it=it, gamma_layers=gamma 204 | ) 205 | 206 | train_acc, test_acc, _, _ = compute_pdm_net_accuracy(hungarian_weights, train_dl, test_dl, n_classes) 207 | 208 | key = (sigma0, sigma, gamma) 209 | res[key] = {} 210 | res[key]['shapes'] = list(map(lambda x: x.shape, hungarian_weights)) 
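# Note: within this hyperparameter grid search, res[key] records, for each (gamma, sigma, sigma0) setting, the shapes of the matched global layers and the train/test accuracy of the resulting global network; the setting with the highest training accuracy is selected further below.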
211 | res[key]['train_accuracy'] = train_acc 212 | res[key]['test_accuracy'] = test_acc 213 | 214 | print('Sigma0: %s. Sigma: %s. Shapes: %s, Accuracy: %f' % ( 215 | str(sigma0), str(sigma), str(res[key]['shapes']), test_acc)) 216 | 217 | if train_acc > best_train_acc: 218 | best_test_acc = test_acc 219 | best_train_acc = train_acc 220 | best_weights = hungarian_weights 221 | best_sigma = sigma 222 | best_gamma = gamma 223 | best_sigma0 = sigma0 224 | 225 | print('Best sigma0: %f, Best sigma: %f, Best Gamma: %f, Best accuracy (Test): %f. Training acc: %f' % ( 226 | best_sigma0, best_sigma, best_gamma, best_test_acc, best_train_acc)) 227 | 228 | return (best_sigma0, best_sigma, best_gamma, best_test_acc, best_train_acc, best_weights, res) 229 | 230 | 231 | def compute_iterative_pdm_matching(models, train_dl, test_dl, cls_freqs, n_classes, sigma, sigma0, gamma, it, old_assignment=None): 232 | 233 | batch_weights = pdm_prepare_weights(models) 234 | batch_freqs = pdm_prepare_freq(cls_freqs, n_classes) 235 | 236 | hungarian_weights, assignments = pdm_iterative_layer_group_descent( 237 | batch_weights, batch_freqs, sigma_layers=sigma, sigma0_layers=sigma0, gamma_layers=gamma, it=it, assignments_old=old_assignment 238 | ) 239 | 240 | train_acc, test_acc, conf_matrix_train, conf_matrix_test = compute_pdm_net_accuracy(hungarian_weights, train_dl, test_dl, n_classes) 241 | 242 | batch_weights_new = [pdm_build_init(hungarian_weights, assignments, j) for j in range(len(models))] 243 | matched_net_shapes = list(map(lambda x: x.shape, hungarian_weights)) 244 | 245 | return batch_weights_new, train_acc, test_acc, matched_net_shapes, assignments, hungarian_weights, conf_matrix_train, conf_matrix_test 246 | 247 | 248 | def flatten_weights(weights_j): 249 | flat_weights = np.hstack((weights_j[0].T, weights_j[1].reshape(-1,1), weights_j[2])) 250 | return flat_weights 251 | 252 | 253 | def build_network(clusters, batch_weights, D): 254 | cluster_network = [clusters[:,:D].T, clusters[:,D].T, clusters[:,(D+1):]] 255 | bias = np.mean(batch_weights, axis=0)[-1] 256 | cluster_network += [bias] 257 | return cluster_network -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/probabilistic-federated-neural-matching/f44cf4281944fae46cdce1b8bc7cde3e7c44bd70/data/.DS_Store -------------------------------------------------------------------------------- /data/cifar10/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/probabilistic-federated-neural-matching/f44cf4281944fae46cdce1b8bc7cde3e7c44bd70/data/cifar10/.DS_Store -------------------------------------------------------------------------------- /data/mnist/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/probabilistic-federated-neural-matching/f44cf4281944fae46cdce1b8bc7cde3e7c44bd70/data/mnist/.DS_Store -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import numpy as np 4 | from torchvision.datasets import MNIST, CIFAR10 5 | 6 | class MNIST_truncated(data.Dataset): 7 | 8 | def __init__(self, root, dataidxs=None, train=True, transform=None, 
target_transform=None, download=False): 9 | 10 | self.root = root 11 | self.dataidxs = dataidxs 12 | self.train = train 13 | self.transform = transform 14 | self.target_transform = target_transform 15 | self.download = download 16 | 17 | self.data, self.target = self.__build_truncated_dataset__() 18 | 19 | def __build_truncated_dataset__(self): 20 | 21 | mnist_dataobj = MNIST(self.root, self.train, self.transform, self.target_transform, self.download) 22 | 23 | data = mnist_dataobj.data 24 | target = mnist_dataobj.targets 25 | 26 | if self.dataidxs is not None: 27 | data = data[self.dataidxs] 28 | target = target[self.dataidxs] 29 | 30 | return data, target 31 | 32 | def __getitem__(self, index): 33 | """ 34 | Args: 35 | index (int): Index 36 | 37 | Returns: 38 | tuple: (image, target) where target is index of the target class. 39 | """ 40 | img, target = self.data[index], self.target[index] 41 | 42 | # doing this so that it is consistent with all other datasets 43 | # to return a PIL Image 44 | img = Image.fromarray(img.numpy(), mode='L') 45 | 46 | if self.transform is not None: 47 | img = self.transform(img) 48 | 49 | if self.target_transform is not None: 50 | target = self.target_transform(target) 51 | 52 | return img, target 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | 57 | 58 | class CIFAR10_truncated(data.Dataset): 59 | 60 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 61 | 62 | self.root = root 63 | self.dataidxs = dataidxs 64 | self.train = train 65 | self.transform = transform 66 | self.target_transform = target_transform 67 | self.download = download 68 | 69 | self.data, self.target = self.__build_truncated_dataset__() 70 | 71 | def __build_truncated_dataset__(self): 72 | 73 | cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) 74 | 75 | data = np.array(cifar_dataobj.data) 76 | target = np.array(cifar_dataobj.targets) 77 | 78 | if self.dataidxs is not None: 79 | data = data[self.dataidxs] 80 | target = target[self.dataidxs] 81 | 82 | return data, target 83 | 84 | def __getitem__(self, index): 85 | """ 86 | Args: 87 | index (int): Index 88 | 89 | Returns: 90 | tuple: (image, target) where target is index of the target class. 
91 | """ 92 | img, target = self.data[index], self.target[index] 93 | 94 | if self.transform is not None: 95 | img = self.transform(img) 96 | 97 | if self.target_transform is not None: 98 | target = self.target_transform(target) 99 | 100 | return img, target 101 | 102 | def __len__(self): 103 | return len(self.data) -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | import logging 9 | import torchvision.transforms as transforms 10 | import torch.utils.data as data 11 | from itertools import product 12 | import copy 13 | from sklearn.metrics import confusion_matrix 14 | 15 | from model import FcNet 16 | from datasets import MNIST_truncated, CIFAR10_truncated 17 | 18 | from combine_nets import compute_ensemble_accuracy, compute_pdm_matching_multilayer, compute_iterative_pdm_matching 19 | 20 | def mkdirs(dirpath): 21 | try: 22 | os.makedirs(dirpath) 23 | except Exception as _: 24 | pass 25 | 26 | def get_parser(): 27 | 28 | parser = argparse.ArgumentParser() 29 | 30 | parser.add_argument('--logdir', type=str, required=True, help='Log directory path') 31 | parser.add_argument('--dropout_p', type=float, required=False, default=0.0, help="Dropout probability. Default=0.0") 32 | parser.add_argument('--dataset', type=str, required=True, help="Dataset [mnist/cifar10]") 33 | parser.add_argument('--datadir', type=str, required=False, default="./data/mnist", help="Data directory") 34 | parser.add_argument('--init_seed', type=int, required=False, default=0, help="Random seed") 35 | 36 | parser.add_argument('--net_config', type=lambda x: list(map(int, x.split(', ')))) 37 | 38 | parser.add_argument('--n_nets', type=int , required=True, help="Number of nets to initialize") 39 | parser.add_argument('--partition', type=str, required=True, help="Partition = homo/hetero/hetero-dir") 40 | parser.add_argument('--experiment', required=True, type=lambda s: s.split(','), help="Type of experiment to run. 
[none/w-ensemble/u-ensemble/pdm/all]") 41 | parser.add_argument('--trials', type=int, required=False, default=1, help="Number of trials for each run") 42 | 43 | parser.add_argument('--lr', type=float, required=True, help="Learning rate") 44 | parser.add_argument('--epochs', type=int, required=True, help="Epochs") 45 | parser.add_argument('--reg', type=float, required=True, help="L2 regularization strength") 46 | 47 | parser.add_argument('--alpha', type=float, required=False, default=0.5, help="Dirichlet distribution constant used for data partitioning") 48 | 49 | parser.add_argument('--communication_rounds', type=int, required=False, default=None, help="How many iterations of PDM matching should be done") 50 | parser.add_argument('--lr_decay', type=float, required=False, default=1.0, help="Decay LR after every PDM iterative communication") 51 | parser.add_argument('--iter_epochs', type=int, required=False, default=None, help="Epochs for PDM-iterative method") 52 | parser.add_argument('--reg_fac', type=float, required=False, default=0.0, help="Regularization factor for PDM Iter") 53 | 54 | parser.add_argument('--pdm_sig', type=float, required=False, default=1.0, help="PDM sigma param") 55 | parser.add_argument('--pdm_sig0', type=float, required=False, default=1.0, help="PDM sigma0 param") 56 | parser.add_argument('--pdm_gamma', type=float, required=False, default=1.0, help="PDM gamma param") 57 | 58 | return parser 59 | 60 | def load_mnist_data(datadir): 61 | 62 | transform = transforms.Compose([transforms.ToTensor()]) 63 | 64 | mnist_train_ds = MNIST_truncated(datadir, train=True, download=True, transform=transform) 65 | mnist_test_ds = MNIST_truncated(datadir, train=False, download=True, transform=transform) 66 | 67 | X_train, y_train = mnist_train_ds.data, mnist_train_ds.target 68 | X_test, y_test = mnist_test_ds.data, mnist_test_ds.target 69 | 70 | X_train = X_train.data.numpy() 71 | y_train = y_train.data.numpy() 72 | X_test = X_test.data.numpy() 73 | y_test = y_test.data.numpy() 74 | 75 | return (X_train, y_train, X_test, y_test) 76 | 77 | def load_cifar10_data(datadir): 78 | 79 | transform = transforms.Compose([transforms.ToTensor()]) 80 | 81 | cifar10_train_ds = CIFAR10_truncated(datadir, train=True, download=True, transform=transform) 82 | cifar10_test_ds = CIFAR10_truncated(datadir, train=False, download=True, transform=transform) 83 | 84 | X_train, y_train = cifar10_train_ds.data, cifar10_train_ds.target 85 | X_test, y_test = cifar10_test_ds.data, cifar10_test_ds.target 86 | 87 | return (X_train, y_train, X_test, y_test) 88 | 89 | 90 | def parse_class_dist(net_class_config): 91 | 92 | cls_net_map = {} 93 | 94 | for net_idx, net_classes in enumerate(net_class_config): 95 | for net_cls in net_classes: 96 | if net_cls not in cls_net_map: 97 | cls_net_map[net_cls] = [] 98 | cls_net_map[net_cls].append(net_idx) 99 | 100 | return cls_net_map 101 | 102 | def record_net_data_stats(y_train, net_dataidx_map, logdir): 103 | 104 | net_cls_counts = {} 105 | 106 | for net_i, dataidx in net_dataidx_map.items(): 107 | unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True) 108 | tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))} 109 | net_cls_counts[net_i] = tmp 110 | 111 | logging.debug('Data statistics: %s' % str(net_cls_counts)) 112 | 113 | return net_cls_counts 114 | 115 | 116 | def partition_data(dataset, datadir, logdir, partition, n_nets, alpha=0.5): 117 | 118 | if dataset == 'mnist': 119 | X_train, y_train, X_test, y_test = load_mnist_data(datadir) 120 | elif dataset == 
'cifar10': 121 | X_train, y_train, X_test, y_test = load_cifar10_data(datadir) 122 | 123 | n_train = X_train.shape[0] 124 | 125 | if partition == "homo": 126 | idxs = np.random.permutation(n_train) 127 | batch_idxs = np.array_split(idxs, n_nets) 128 | net_dataidx_map = {i: batch_idxs[i] for i in range(n_nets)} 129 | 130 | elif partition == "hetero-dir": 131 | min_size = 0 132 | K = 10 133 | N = y_train.shape[0] 134 | net_dataidx_map = {} 135 | 136 | while min_size < 10: 137 | idx_batch = [[] for _ in range(n_nets)] 138 | for k in range(K): 139 | idx_k = np.where(y_train == k)[0] 140 | np.random.shuffle(idx_k) 141 | proportions = np.random.dirichlet(np.repeat(alpha, n_nets)) 142 | ## Balance 143 | proportions = np.array([p*(len(idx_j)> Pre-Training Training accuracy: %f' % train_acc) 232 | logging.debug('>> Pre-Training Test accuracy: %f' % test_acc) 233 | 234 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0, amsgrad=True) # L2_reg=0 because it's manually added later 235 | 236 | criterion = nn.CrossEntropyLoss() 237 | 238 | cnt = 0 239 | losses, running_losses = [], [] 240 | 241 | for epoch in range(epochs): 242 | for batch_idx, (x, target) in enumerate(train_dataloader): 243 | 244 | l2_reg = torch.zeros(1) 245 | l2_reg.requires_grad = True 246 | 247 | optimizer.zero_grad() 248 | x.requires_grad = True 249 | target.requires_grad = False 250 | target = target.long() 251 | 252 | out = net(x) 253 | loss = criterion(out, target) 254 | 255 | if reg_base_weights is None: 256 | # Apply standard L2-regularization 257 | for param in net.parameters(): 258 | l2_reg = l2_reg + 0.5 * torch.pow(param, 2).sum() 259 | else: 260 | # Apply Iterative PDM regularization 261 | for pname, param in net.named_parameters(): 262 | if "bias" in pname: 263 | continue 264 | 265 | layer_i = int(pname.split('.')[1]) 266 | 267 | if pname.split('.')[2] == "weight": 268 | weight_i = layer_i * 2 269 | transpose = True 270 | 271 | ref_param = reg_base_weights[weight_i] 272 | ref_param = ref_param.T if transpose else ref_param 273 | 274 | l2_reg = l2_reg + 0.5 * torch.pow((param - torch.from_numpy(ref_param).float()), 2).sum() 275 | 276 | loss = loss + reg * l2_reg 277 | 278 | loss.backward() 279 | optimizer.step() 280 | 281 | cnt += 1 282 | losses.append(loss.item()) 283 | 284 | logging.debug('Epoch: %d Loss: %f L2 loss: %f' % (epoch, loss.item(), reg*l2_reg)) 285 | 286 | train_acc = compute_accuracy(net, train_dataloader) 287 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True) 288 | 289 | logging.debug('>> Training accuracy: %f' % train_acc) 290 | logging.debug('>> Test accuracy: %f' % test_acc) 291 | 292 | logging.debug(' ** Training complete **') 293 | 294 | return train_acc, test_acc 295 | 296 | 297 | def load_new_state(nets, new_weights): 298 | 299 | for netid, net in nets.items(): 300 | 301 | statedict = net.state_dict() 302 | weights = new_weights[netid] 303 | 304 | # Load weight into the network 305 | i = 0 306 | layer_i = 0 307 | 308 | while i < len(weights): 309 | weight = weights[i] 310 | i += 1 311 | bias = weights[i] 312 | i += 1 313 | 314 | statedict['layers.%d.weight' % layer_i] = torch.from_numpy(weight.T) 315 | statedict['layers.%d.bias' % layer_i] = torch.from_numpy(bias) 316 | layer_i += 1 317 | 318 | net.load_state_dict(statedict) 319 | 320 | return nets 321 | 322 | def run_exp(): 323 | 324 | parser = get_parser() 325 | args = parser.parse_args() 326 | 327 | mkdirs(args.logdir) 328 | with 
open(os.path.join(args.logdir, 'experiment_arguments.json'), 'w') as f: 329 | json.dump(str(args), f) 330 | 331 | logging.basicConfig( 332 | filename=os.path.join(args.logdir, 'experiment_log-%d-%d.log' % (args.init_seed, args.trials)), 333 | format='%(asctime)s %(levelname)-8s %(message)s', 334 | datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w') 335 | 336 | logging.debug("Experiment arguments: %s" % str(args)) 337 | 338 | for trial in range(args.trials): 339 | 340 | seed = trial + args.init_seed 341 | 342 | print("Executing Trial %d " % trial) 343 | logging.debug("#" * 100) 344 | logging.debug("Executing Trial %d with seed %d" % (trial, seed)) 345 | 346 | np.random.seed(seed) 347 | torch.manual_seed(seed) 348 | 349 | print("Partitioning data") 350 | X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data( 351 | args.dataset, args.datadir, args.logdir, args.partition, args.n_nets, args.alpha) 352 | 353 | n_classes = len(np.unique(y_train)) 354 | 355 | print("Initializing nets") 356 | nets = init_nets(args.net_config, args.dropout_p, args.n_nets) 357 | 358 | local_train_accs = [] 359 | local_test_accs = [] 360 | for net_id, net in nets.items(): 361 | dataidxs = net_dataidx_map[net_id] 362 | print("Training network %s. n_training: %d" % (str(net_id), len(dataidxs))) 363 | 364 | train_dl, test_dl = get_dataloader(args.dataset, args.datadir, 32, 32, dataidxs) 365 | trainacc, testacc = train_net(net_id, net, train_dl, test_dl, args.epochs, args.lr, args.reg) 366 | 367 | local_train_accs.append(trainacc) 368 | local_test_accs.append(testacc) 369 | 370 | train_dl, test_dl = get_dataloader(args.dataset, args.datadir, 32, 32) 371 | 372 | logging.debug("*"*50) 373 | logging.debug("Running experiments \n") 374 | 375 | nets_list = list(nets.values()) 376 | 377 | if ("u-ensemble" in args.experiment) or ("all" in args.experiment): 378 | print("Computing Uniform ensemble accuracy") 379 | uens_train_acc, _ = compute_ensemble_accuracy(nets_list, train_dl, n_classes, uniform_weights=True) 380 | uens_test_acc, _ = compute_ensemble_accuracy(nets_list, test_dl, n_classes, uniform_weights=True) 381 | 382 | logging.debug("Uniform ensemble (Train acc): %f" % uens_train_acc) 383 | logging.debug("Uniform ensemble (Test acc): %f" % uens_test_acc) 384 | 385 | if ("pdm" in args.experiment) or ("all" in args.experiment): 386 | print("Computing hungarian matching") 387 | best_sigma0, best_sigma, best_gamma, best_test_acc, best_train_acc, best_weights, res = compute_pdm_matching_multilayer( 388 | nets_list, train_dl, test_dl, traindata_cls_counts, args.net_config[-1], it=5, sigma=args.pdm_sig, sigma0=args.pdm_sig0, gamma=args.pdm_gamma 389 | ) 390 | 391 | logging.debug("****** PDM matching ******** ") 392 | logging.debug("Best Sigma0: %s. Best sigma: %s Best gamma: %s. Best Test accuracy: %s. Train acc: %s. 
\n" 393 | % (str(best_sigma0), str(best_sigma), str(best_gamma), str(best_test_acc), str(best_train_acc))) 394 | 395 | logging.debug("PDM log: %s " % str(res)) 396 | 397 | if ("pdm_iterative" in args.experiment) or ("all" in args.experiment): 398 | print("Running Iterative PDM matching procedure") 399 | logging.debug("Running Iterative PDM matching procedure") 400 | 401 | sigma0s = [1.0] 402 | sigmas = [1.0] 403 | gammas = [1.0] 404 | 405 | for (sigma0, sigma, gamma) in product(sigma0s, sigmas, gammas): 406 | logging.debug("Parameter setting: sigma0 = %f, sigma = %f, gamma = %f" % (sigma0, sigma, gamma)) 407 | 408 | iter_nets = copy.deepcopy(nets) 409 | assignment = None 410 | lr_iter = args.lr 411 | reg_iter = args.reg 412 | 413 | # Run for communication rounds iterations 414 | for i, comm_round in enumerate(range(args.communication_rounds)): 415 | 416 | it = 3 417 | 418 | iter_nets_list = list(iter_nets.values()) 419 | 420 | net_weights_new, train_acc, test_acc, new_shape, assignment, hungarian_weights, \ 421 | conf_matrix_train, conf_matrix_test = compute_iterative_pdm_matching( 422 | iter_nets_list, train_dl, test_dl, traindata_cls_counts, args.net_config[-1], 423 | sigma, sigma0, gamma, it, old_assignment=assignment 424 | ) 425 | 426 | logging.debug("Communication: %d, Train acc: %f, Test acc: %f, Shapes: %s" % (comm_round, train_acc, test_acc, str(new_shape))) 427 | logging.debug('CENTRAL MODEL CONFUSION MATRIX') 428 | logging.debug('Train data confusion matrix: \n %s' % str(conf_matrix_train)) 429 | logging.debug('Test data confusion matrix: \n %s' % str(conf_matrix_test)) 430 | 431 | iter_nets = load_new_state(iter_nets, net_weights_new) 432 | 433 | expepochs = args.iter_epochs if args.iter_epochs is not None else args.epochs 434 | 435 | # Train these networks again 436 | for net_id, net in iter_nets.items(): 437 | dataidxs = net_dataidx_map[net_id] 438 | print("Training network %s. 
n_training: %d" % (str(net_id), len(dataidxs))) 439 | 440 | net_train_dl, net_test_dl = get_dataloader(args.dataset, args.datadir, 32, 32, dataidxs) 441 | train_net(net_id, net, net_train_dl, net_test_dl, expepochs, lr_iter, reg_iter, net_weights_new[net_id]) 442 | 443 | lr_iter *= args.lr_decay 444 | reg_iter *= args.reg_fac 445 | 446 | logging.debug("Trial %d completed" % trial) 447 | logging.debug("#"*100) 448 | 449 | if __name__ == "__main__": 450 | run_exp() 451 | -------------------------------------------------------------------------------- /logs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/probabilistic-federated-neural-matching/f44cf4281944fae46cdce1b8bc7cde3e7c44bd70/logs/.DS_Store -------------------------------------------------------------------------------- /matching/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/probabilistic-federated-neural-matching/f44cf4281944fae46cdce1b8bc7cde3e7c44bd70/matching/.DS_Store -------------------------------------------------------------------------------- /matching/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/probabilistic-federated-neural-matching/f44cf4281944fae46cdce1b8bc7cde3e7c44bd70/matching/__init__.py -------------------------------------------------------------------------------- /matching/pfnm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import linear_sum_assignment 3 | 4 | 5 | def row_param_cost(global_weights, weights_j_l, global_sigmas, sigma_inv_j): 6 | 7 | match_norms = ((weights_j_l + global_weights) ** 2 / (sigma_inv_j + global_sigmas)).sum(axis=1) - ( 8 | global_weights ** 2 / global_sigmas).sum(axis=1) 9 | 10 | return match_norms 11 | 12 | 13 | def compute_cost(global_weights, weights_j, global_sigmas, sigma_inv_j, prior_mean_norm, prior_inv_sigma, 14 | popularity_counts, gamma, J): 15 | 16 | Lj = weights_j.shape[0] 17 | counts = np.minimum(np.array(popularity_counts), 10) 18 | param_cost = np.array([row_param_cost(global_weights, weights_j[l], global_sigmas, sigma_inv_j) for l in range(Lj)]) 19 | param_cost += np.log(counts / (J - counts)) 20 | 21 | ## Nonparametric cost 22 | L = global_weights.shape[0] 23 | max_added = min(Lj, max(700 - L, 1)) 24 | nonparam_cost = np.outer((((weights_j + prior_mean_norm) ** 2 / (prior_inv_sigma + sigma_inv_j)).sum(axis=1) - ( 25 | prior_mean_norm ** 2 / prior_inv_sigma).sum()), np.ones(max_added)) 26 | cost_pois = 2 * np.log(np.arange(1, max_added + 1)) 27 | nonparam_cost -= cost_pois 28 | nonparam_cost += 2 * np.log(gamma / J) 29 | 30 | full_cost = np.hstack((param_cost, nonparam_cost)) 31 | return full_cost 32 | 33 | 34 | def matching_upd_j(weights_j, global_weights, sigma_inv_j, global_sigmas, prior_mean_norm, prior_inv_sigma, 35 | popularity_counts, gamma, J): 36 | 37 | L = global_weights.shape[0] 38 | 39 | full_cost = compute_cost(global_weights, weights_j, global_sigmas, sigma_inv_j, prior_mean_norm, prior_inv_sigma, 40 | popularity_counts, gamma, J) 41 | 42 | row_ind, col_ind = linear_sum_assignment(-full_cost) 43 | 44 | assignment_j = [] 45 | 46 | new_L = L 47 | 48 | for l, i in zip(row_ind, col_ind): 49 | if i < L: 50 | popularity_counts[i] += 1 51 | assignment_j.append(i) 52 | global_weights[i] += weights_j[l] 53 | global_sigmas[i] += 
sigma_inv_j 54 | else: # new neuron 55 | popularity_counts += [1] 56 | assignment_j.append(new_L) 57 | new_L += 1 58 | global_weights = np.vstack((global_weights, prior_mean_norm + weights_j[l])) 59 | global_sigmas = np.vstack((global_sigmas, prior_inv_sigma + sigma_inv_j)) 60 | 61 | return global_weights, global_sigmas, popularity_counts, assignment_j 62 | 63 | 64 | def objective(global_weights, global_sigmas): 65 | obj = ((global_weights) ** 2 / global_sigmas).sum() 66 | return obj 67 | 68 | 69 | def patch_weights(w_j, L_next, assignment_j_c): 70 | if assignment_j_c is None: 71 | return w_j 72 | new_w_j = np.zeros((w_j.shape[0], L_next)) 73 | new_w_j[:, assignment_j_c] = w_j 74 | return new_w_j 75 | 76 | 77 | def process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0): 78 | J = len(batch_weights) 79 | sigma_bias = sigma 80 | sigma0_bias = sigma0 81 | mu0_bias = 0.1 82 | softmax_bias = [batch_weights[j][-1] for j in range(J)] 83 | softmax_inv_sigma = [s / sigma_bias for s in last_layer_const] 84 | softmax_bias = sum([b * s for b, s in zip(softmax_bias, softmax_inv_sigma)]) + mu0_bias / sigma0_bias 85 | softmax_inv_sigma = 1 / sigma0_bias + sum(softmax_inv_sigma) 86 | return softmax_bias, softmax_inv_sigma 87 | 88 | 89 | def match_layer(weights_bias, sigma_inv_layer, mean_prior, sigma_inv_prior, gamma, it): 90 | J = len(weights_bias) 91 | 92 | group_order = sorted(range(J), key=lambda x: -weights_bias[x].shape[0]) 93 | 94 | batch_weights_norm = [w * s for w, s in zip(weights_bias, sigma_inv_layer)] 95 | prior_mean_norm = mean_prior * sigma_inv_prior 96 | 97 | global_weights = prior_mean_norm + batch_weights_norm[group_order[0]] 98 | global_sigmas = np.outer(np.ones(global_weights.shape[0]), sigma_inv_prior + sigma_inv_layer[group_order[0]]) 99 | 100 | popularity_counts = [1] * global_weights.shape[0] 101 | 102 | assignment = [[] for _ in range(J)] 103 | 104 | assignment[group_order[0]] = list(range(global_weights.shape[0])) 105 | 106 | ## Initialize 107 | for j in group_order[1:]: 108 | global_weights, global_sigmas, popularity_counts, assignment_j = matching_upd_j(batch_weights_norm[j], 109 | global_weights, 110 | sigma_inv_layer[j], 111 | global_sigmas, prior_mean_norm, 112 | sigma_inv_prior, 113 | popularity_counts, gamma, J) 114 | assignment[j] = assignment_j 115 | 116 | ## Iterate over groups 117 | for iteration in range(it): 118 | random_order = np.random.permutation(J) 119 | for j in random_order: # random_order: 120 | to_delete = [] 121 | ## Remove j 122 | Lj = len(assignment[j]) 123 | for l, i in sorted(zip(range(Lj), assignment[j]), key=lambda x: -x[1]): 124 | popularity_counts[i] -= 1 125 | if popularity_counts[i] == 0: 126 | del popularity_counts[i] 127 | to_delete.append(i) 128 | for j_clean in range(J): 129 | for idx, l_ind in enumerate(assignment[j_clean]): 130 | if i < l_ind and j_clean != j: 131 | assignment[j_clean][idx] -= 1 132 | elif i == l_ind and j_clean != j: 133 | print('Warning - weird unmatching') 134 | else: 135 | global_weights[i] = global_weights[i] - batch_weights_norm[j][l] 136 | global_sigmas[i] -= sigma_inv_layer[j] 137 | 138 | global_weights = np.delete(global_weights, to_delete, axis=0) 139 | global_sigmas = np.delete(global_sigmas, to_delete, axis=0) 140 | 141 | ## Match j 142 | global_weights, global_sigmas, popularity_counts, assignment_j = matching_upd_j(batch_weights_norm[j], 143 | global_weights, 144 | sigma_inv_layer[j], 145 | global_sigmas, 146 | prior_mean_norm, 147 | sigma_inv_prior, 148 | popularity_counts, gamma, J) 149 | 
assignment[j] = assignment_j 150 | 151 | print('Number of global neurons is %d, gamma %f' % (global_weights.shape[0], gamma)) 152 | 153 | return assignment, global_weights, global_sigmas 154 | 155 | 156 | def layer_group_descent(batch_weights, batch_frequencies, sigma_layers, sigma0_layers, gamma_layers, it): 157 | 158 | n_layers = int(len(batch_weights[0]) / 2) 159 | 160 | if type(sigma_layers) is not list: 161 | sigma_layers = (n_layers - 1) * [sigma_layers] 162 | if type(sigma0_layers) is not list: 163 | sigma0_layers = (n_layers - 1) * [sigma0_layers] 164 | if type(gamma_layers) is not list: 165 | gamma_layers = (n_layers - 1) * [gamma_layers] 166 | 167 | last_layer_const = [] 168 | total_freq = sum(batch_frequencies) 169 | for f in batch_frequencies: 170 | last_layer_const.append(f / total_freq) 171 | 172 | J = len(batch_weights) 173 | D = batch_weights[0][0].shape[0] 174 | sigma_bias_layers = sigma_layers 175 | sigma0_bias_layers = sigma0_layers 176 | mu0 = 0. 177 | mu0_bias = 0.1 178 | assignment_c = [None for j in range(J)] 179 | L_next = None 180 | 181 | ## Group descent for layer 182 | for c in range(1, n_layers)[::-1]: 183 | sigma = sigma_layers[c - 1] 184 | sigma_bias = sigma_bias_layers[c - 1] 185 | gamma = gamma_layers[c - 1] 186 | sigma0 = sigma0_layers[c - 1] 187 | sigma0_bias = sigma0_bias_layers[c - 1] 188 | if c == (n_layers - 1) and n_layers > 2: 189 | weights_bias = [np.hstack((batch_weights[j][c * 2 - 1].reshape(-1, 1), batch_weights[j][c * 2])) for j in 190 | range(J)] 191 | sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0]) 192 | mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0]) 193 | sigma_inv_layer = [np.array([1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in range(J)] 194 | elif c > 1: 195 | weights_bias = [np.hstack((batch_weights[j][c * 2 - 1].reshape(-1, 1), 196 | patch_weights(batch_weights[j][c * 2], L_next, assignment_c[j]))) for j in 197 | range(J)] 198 | sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0]) 199 | mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0]) 200 | sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in 201 | range(J)] 202 | else: 203 | weights_bias = [np.hstack((batch_weights[j][0].T, batch_weights[j][c * 2 - 1].reshape(-1, 1), 204 | patch_weights(batch_weights[j][c * 2], L_next, assignment_c[j]))) for j in 205 | range(J)] 206 | sigma_inv_prior = np.array( 207 | D * [1 / sigma0] + [1 / sigma0_bias] + (weights_bias[0].shape[1] - 1 - D) * [1 / sigma0]) 208 | mean_prior = np.array(D * [mu0] + [mu0_bias] + (weights_bias[0].shape[1] - 1 - D) * [mu0]) 209 | if n_layers == 2: 210 | sigma_inv_layer = [ 211 | np.array(D * [1 / sigma] + [1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in 212 | range(J)] 213 | else: 214 | sigma_inv_layer = [ 215 | np.array(D * [1 / sigma] + [1 / sigma_bias] + (weights_bias[j].shape[1] - 1 - D) * [1 / sigma]) for 216 | j in range(J)] 217 | 218 | assignment_c, global_weights_c, global_sigmas_c = match_layer(weights_bias, sigma_inv_layer, mean_prior, 219 | sigma_inv_prior, gamma, it) 220 | L_next = global_weights_c.shape[0] 221 | 222 | if c == (n_layers - 1) and n_layers > 2: 223 | softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0) 224 | global_weights_out = [global_weights_c[:, 0], global_weights_c[:, 1:], softmax_bias] 225 | global_inv_sigmas_out = 
[global_sigmas_c[:, 0], global_sigmas_c[:, 1:], softmax_inv_sigma] 226 | elif c > 1: 227 | global_weights_out = [global_weights_c[:, 0], global_weights_c[:, 1:]] + global_weights_out 228 | global_inv_sigmas_out = [global_sigmas_c[:, 0], global_sigmas_c[:, 1:]] + global_inv_sigmas_out 229 | else: 230 | if n_layers == 2: 231 | softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0) 232 | global_weights_out = [softmax_bias] 233 | global_inv_sigmas_out = [softmax_inv_sigma] 234 | global_weights_out = [global_weights_c[:, :D].T, global_weights_c[:, D], 235 | global_weights_c[:, (D + 1):]] + global_weights_out 236 | global_inv_sigmas_out = [global_sigmas_c[:, :D].T, global_sigmas_c[:, D], 237 | global_sigmas_c[:, (D + 1):]] + global_inv_sigmas_out 238 | 239 | map_out = [g_w / g_s for g_w, g_s in zip(global_weights_out, global_inv_sigmas_out)] 240 | 241 | return map_out -------------------------------------------------------------------------------- /matching/pfnm_communication.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import linear_sum_assignment 3 | 4 | 5 | def row_param_cost(global_weights, weights_j_l, global_sigmas, sigma_inv_j): 6 | 7 | match_norms = ((weights_j_l + global_weights) ** 2 / (sigma_inv_j + global_sigmas)).sum(axis=1) - ( 8 | global_weights ** 2 / global_sigmas).sum(axis=1) 9 | 10 | return match_norms 11 | 12 | 13 | def compute_cost(global_weights, weights_j, global_sigmas, sigma_inv_j, prior_mean_norm, prior_inv_sigma, 14 | popularity_counts, gamma, J): 15 | 16 | Lj = weights_j.shape[0] 17 | counts = np.minimum(np.array(popularity_counts), 10) 18 | param_cost = np.array([row_param_cost(global_weights, weights_j[l], global_sigmas, sigma_inv_j) for l in range(Lj)]) 19 | param_cost += np.log(counts / (J - counts)) 20 | 21 | ## Nonparametric cost 22 | L = global_weights.shape[0] 23 | max_added = min(Lj, max(700 - L, 1)) 24 | 25 | nonparam_cost = np.outer((((weights_j + prior_mean_norm) ** 2 / (prior_inv_sigma + sigma_inv_j)).sum(axis=1) - ( 26 | prior_mean_norm ** 2 / prior_inv_sigma).sum()), np.ones(max_added)) 27 | cost_pois = 2 * np.log(np.arange(1, max_added + 1)) 28 | nonparam_cost -= cost_pois 29 | nonparam_cost += 2 * np.log(gamma / J) 30 | 31 | full_cost = np.hstack((param_cost, nonparam_cost)) 32 | 33 | return full_cost 34 | 35 | 36 | def matching_upd_j(weights_j, global_weights, sigma_inv_j, global_sigmas, prior_mean_norm, prior_inv_sigma, 37 | popularity_counts, gamma, J): 38 | 39 | L = global_weights.shape[0] 40 | 41 | full_cost = compute_cost(global_weights, weights_j, global_sigmas, sigma_inv_j, prior_mean_norm, prior_inv_sigma, 42 | popularity_counts, gamma, J) 43 | 44 | row_ind, col_ind = linear_sum_assignment(-full_cost) 45 | 46 | assignment_j = [] 47 | 48 | new_L = L 49 | 50 | for l, i in zip(row_ind, col_ind): 51 | if i < L: 52 | popularity_counts[i] += 1 53 | assignment_j.append(i) 54 | global_weights[i] += weights_j[l] 55 | global_sigmas[i] += sigma_inv_j 56 | else: # new neuron 57 | popularity_counts += [1] 58 | assignment_j.append(new_L) 59 | new_L += 1 60 | global_weights = np.vstack((global_weights, prior_mean_norm + weights_j[l])) 61 | global_sigmas = np.vstack((global_sigmas, prior_inv_sigma + sigma_inv_j)) 62 | 63 | return global_weights, global_sigmas, popularity_counts, assignment_j 64 | 65 | 66 | def objective(global_weights, global_sigmas): 67 | obj = ((global_weights) ** 2 / global_sigmas).sum() 68 | return obj 69 
| 70 | 71 | def patch_weights(w_j, L_next, assignment_j_c): 72 | if assignment_j_c is None: 73 | return w_j 74 | new_w_j = np.zeros((w_j.shape[0], L_next)) 75 | new_w_j[:, assignment_j_c] = w_j 76 | return new_w_j 77 | 78 | 79 | def process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0): 80 | J = len(batch_weights) 81 | sigma_bias = sigma 82 | sigma0_bias = sigma0 83 | mu0_bias = 0.1 84 | softmax_bias = [batch_weights[j][-1] for j in range(J)] 85 | softmax_inv_sigma = [s / sigma_bias for s in last_layer_const] 86 | softmax_bias = sum([b * s for b, s in zip(softmax_bias, softmax_inv_sigma)]) + mu0_bias / sigma0_bias 87 | softmax_inv_sigma = 1 / sigma0_bias + sum(softmax_inv_sigma) 88 | return softmax_bias, softmax_inv_sigma 89 | 90 | 91 | def init_from_assignments(batch_weights_norm, sigma_inv_layer, prior_mean_norm, sigma_inv_prior, assignment): 92 | L = int(max([max(a_j) for a_j in assignment])) + 1 93 | popularity_counts = [0] * L 94 | 95 | global_weights = np.outer(np.ones(L), prior_mean_norm) 96 | global_sigmas = np.outer(np.ones(L), sigma_inv_prior) 97 | 98 | for j, a_j in enumerate(assignment): 99 | for l, i in enumerate(a_j): 100 | popularity_counts[i] += 1 101 | global_weights[i] += batch_weights_norm[j][l] 102 | global_sigmas[i] += sigma_inv_layer[j] 103 | 104 | return popularity_counts, global_weights, global_sigmas 105 | 106 | 107 | def match_layer(weights_bias, sigma_inv_layer, mean_prior, sigma_inv_prior, gamma, it, assignment=None): 108 | J = len(weights_bias) 109 | 110 | group_order = sorted(range(J), key=lambda x: -weights_bias[x].shape[0]) 111 | 112 | batch_weights_norm = [w * s for w, s in zip(weights_bias, sigma_inv_layer)] 113 | prior_mean_norm = mean_prior * sigma_inv_prior 114 | 115 | if assignment is None: 116 | global_weights = prior_mean_norm + batch_weights_norm[group_order[0]] 117 | global_sigmas = np.outer(np.ones(global_weights.shape[0]), sigma_inv_prior + sigma_inv_layer[group_order[0]]) 118 | 119 | popularity_counts = [1] * global_weights.shape[0] 120 | 121 | assignment = [[] for _ in range(J)] 122 | 123 | assignment[group_order[0]] = list(range(global_weights.shape[0])) 124 | 125 | ## Initialize 126 | for j in group_order[1:]: 127 | global_weights, global_sigmas, popularity_counts, assignment_j = matching_upd_j(batch_weights_norm[j], 128 | global_weights, 129 | sigma_inv_layer[j], 130 | global_sigmas, 131 | prior_mean_norm, 132 | sigma_inv_prior, 133 | popularity_counts, gamma, J) 134 | assignment[j] = assignment_j 135 | else: 136 | popularity_counts, global_weights, global_sigmas = init_from_assignments(batch_weights_norm, sigma_inv_layer, 137 | mean_prior, sigma_inv_prior, 138 | assignment) 139 | 140 | ## Iterate over groups 141 | for iteration in range(it): 142 | random_order = np.random.permutation(J) 143 | for j in random_order: # random_order: 144 | to_delete = [] 145 | ## Remove j 146 | Lj = len(assignment[j]) 147 | for l, i in sorted(zip(range(Lj), assignment[j]), key=lambda x: -x[1]): 148 | popularity_counts[i] -= 1 149 | if popularity_counts[i] == 0: 150 | del popularity_counts[i] 151 | to_delete.append(i) 152 | for j_clean in range(J): 153 | for idx, l_ind in enumerate(assignment[j_clean]): 154 | if i < l_ind and j_clean != j: 155 | assignment[j_clean][idx] -= 1 156 | elif i == l_ind and j_clean != j: 157 | print('Warning - weird unmatching') 158 | else: 159 | global_weights[i] = global_weights[i] - batch_weights_norm[j][l] 160 | global_sigmas[i] -= sigma_inv_layer[j] 161 | 162 | global_weights = np.delete(global_weights, to_delete, 
axis=0) 163 | global_sigmas = np.delete(global_sigmas, to_delete, axis=0) 164 | 165 | ## Match j 166 | global_weights, global_sigmas, popularity_counts, assignment_j = matching_upd_j(batch_weights_norm[j], 167 | global_weights, 168 | sigma_inv_layer[j], 169 | global_sigmas, 170 | prior_mean_norm, 171 | sigma_inv_prior, 172 | popularity_counts, gamma, J) 173 | assignment[j] = assignment_j 174 | 175 | print('Number of global neurons is %d, gamma %f' % (global_weights.shape[0], gamma)) 176 | 177 | return assignment, global_weights, global_sigmas 178 | 179 | 180 | def layer_group_descent(batch_weights, batch_frequencies, sigma_layers, sigma0_layers, gamma_layers, it, 181 | assignments_old=None): 182 | 183 | n_layers = int(len(batch_weights[0]) / 2) 184 | J = len(batch_weights) 185 | D = batch_weights[0][0].shape[0] 186 | K = batch_weights[0][-1].shape[0] 187 | 188 | if assignments_old is None: 189 | assignments_old = (n_layers - 1) * [None] 190 | if type(sigma_layers) is not list: 191 | sigma_layers = (n_layers - 1) * [sigma_layers] 192 | if type(sigma0_layers) is not list: 193 | sigma0_layers = (n_layers - 1) * [sigma0_layers] 194 | if type(gamma_layers) is not list: 195 | gamma_layers = (n_layers - 1) * [gamma_layers] 196 | 197 | if batch_frequencies is None: 198 | last_layer_const = [np.ones(K) for _ in range(J)] 199 | else: 200 | last_layer_const = [] 201 | total_freq = sum(batch_frequencies) 202 | for f in batch_frequencies: 203 | last_layer_const.append(f / total_freq) 204 | 205 | sigma_bias_layers = sigma_layers 206 | sigma0_bias_layers = sigma0_layers 207 | mu0 = 0. 208 | mu0_bias = 0.1 209 | assignment_c = [None for j in range(J)] 210 | L_next = None 211 | assignment_all = [] 212 | 213 | ## Group descent for layer 214 | for c in range(1, n_layers)[::-1]: 215 | sigma = sigma_layers[c - 1] 216 | sigma_bias = sigma_bias_layers[c - 1] 217 | gamma = gamma_layers[c - 1] 218 | sigma0 = sigma0_layers[c - 1] 219 | sigma0_bias = sigma0_bias_layers[c - 1] 220 | if c == (n_layers - 1) and n_layers > 2: 221 | weights_bias = [np.hstack((batch_weights[j][c * 2 - 1].reshape(-1, 1), batch_weights[j][c * 2])) for j in 222 | range(J)] 223 | sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0]) 224 | mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0]) 225 | sigma_inv_layer = [np.array([1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in range(J)] 226 | elif c > 1: 227 | weights_bias = [np.hstack((batch_weights[j][c * 2 - 1].reshape(-1, 1), 228 | patch_weights(batch_weights[j][c * 2], L_next, assignment_c[j]))) for j in 229 | range(J)] 230 | sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0]) 231 | mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0]) 232 | sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in 233 | range(J)] 234 | else: 235 | weights_bias = [np.hstack((batch_weights[j][0].T, batch_weights[j][c * 2 - 1].reshape(-1, 1), 236 | patch_weights(batch_weights[j][c * 2], L_next, assignment_c[j]))) for j in 237 | range(J)] 238 | sigma_inv_prior = np.array( 239 | D * [1 / sigma0] + [1 / sigma0_bias] + (weights_bias[0].shape[1] - 1 - D) * [1 / sigma0]) 240 | mean_prior = np.array(D * [mu0] + [mu0_bias] + (weights_bias[0].shape[1] - 1 - D) * [mu0]) 241 | if n_layers == 2: 242 | sigma_inv_layer = [ 243 | np.array(D * [1 / sigma] + [1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in 244 | 
range(J)] 245 | else: 246 | sigma_inv_layer = [ 247 | np.array(D * [1 / sigma] + [1 / sigma_bias] + (weights_bias[j].shape[1] - 1 - D) * [1 / sigma]) for 248 | j in range(J)] 249 | 250 | assignment_c, global_weights_c, global_sigmas_c = match_layer(weights_bias, sigma_inv_layer, mean_prior, 251 | sigma_inv_prior, gamma, it, 252 | assignment=assignments_old[c - 1]) 253 | L_next = global_weights_c.shape[0] 254 | assignment_all = [assignment_c] + assignment_all 255 | 256 | if c == (n_layers - 1) and n_layers > 2: 257 | softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0) 258 | global_weights_out = [global_weights_c[:, 0], global_weights_c[:, 1:], softmax_bias] 259 | global_inv_sigmas_out = [global_sigmas_c[:, 0], global_sigmas_c[:, 1:], softmax_inv_sigma] 260 | elif c > 1: 261 | global_weights_out = [global_weights_c[:, 0], global_weights_c[:, 1:]] + global_weights_out 262 | global_inv_sigmas_out = [global_sigmas_c[:, 0], global_sigmas_c[:, 1:]] + global_inv_sigmas_out 263 | else: 264 | if n_layers == 2: 265 | softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0) 266 | global_weights_out = [softmax_bias] 267 | global_inv_sigmas_out = [softmax_inv_sigma] 268 | global_weights_out = [global_weights_c[:, :D].T, global_weights_c[:, D], 269 | global_weights_c[:, (D + 1):]] + global_weights_out 270 | global_inv_sigmas_out = [global_sigmas_c[:, :D].T, global_sigmas_c[:, D], 271 | global_sigmas_c[:, (D + 1):]] + global_inv_sigmas_out 272 | 273 | map_out = [g_w / g_s for g_w, g_s in zip(global_weights_out, global_inv_sigmas_out)] 274 | 275 | return map_out, assignment_all 276 | 277 | 278 | def build_init(hungarian_weights, assignments, j): 279 | batch_init = [] 280 | C = len(assignments) 281 | if len(hungarian_weights) == 4: 282 | batch_init.append(hungarian_weights[0][:, assignments[0][j]]) 283 | batch_init.append(hungarian_weights[1][assignments[0][j]]) 284 | batch_init.append(hungarian_weights[2][assignments[0][j]]) 285 | batch_init.append(hungarian_weights[3]) 286 | return batch_init 287 | for c in range(C): 288 | if c == 0: 289 | batch_init.append(hungarian_weights[c][:, assignments[c][j]]) 290 | batch_init.append(hungarian_weights[c + 1][assignments[c][j]]) 291 | else: 292 | batch_init.append(hungarian_weights[2 * c][assignments[c - 1][j]][:, assignments[c][j]]) 293 | batch_init.append(hungarian_weights[2 * c + 1][assignments[c][j]]) 294 | if c == C - 1: 295 | batch_init.append(hungarian_weights[2 * c + 2][assignments[c][j]]) 296 | batch_init.append(hungarian_weights[-1]) 297 | return batch_init 298 | 299 | 300 | def gaus_init(n_units, D, K, seed=None, mu=0, sd=0.1, bias=0.1): 301 | if seed is not None: 302 | np.random.seed(seed) 303 | 304 | batch_init = [] 305 | C = len(n_units) 306 | for c in range(C): 307 | if c == 0: 308 | batch_init.append(np.random.normal(mu, sd, (D, n_units[c]))) 309 | batch_init.append(np.repeat(bias, n_units[c])) 310 | else: 311 | if C != 1: 312 | batch_init.append(np.random.normal(mu, sd, (n_units[c - 1], n_units[c]))) 313 | batch_init.append(np.repeat(bias, n_units[c])) 314 | if c == C - 1: 315 | batch_init.append(np.random.normal(mu, sd, (n_units[c], K))) 316 | batch_init.append(np.repeat(bias, K)) 317 | return batch_init -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 
| 5 | class FcNet(nn.Module): 6 | """ 7 | Fully connected network for MNIST classification 8 | """ 9 | 10 | def __init__(self, input_dim, hidden_dims, output_dim, dropout_p=0.0): 11 | 12 | super().__init__() 13 | 14 | self.input_dim = input_dim 15 | self.hidden_dims = hidden_dims 16 | self.output_dim = output_dim 17 | self.dropout_p = dropout_p 18 | 19 | self.dims = [self.input_dim] 20 | self.dims.extend(hidden_dims) 21 | self.dims.append(self.output_dim) 22 | 23 | self.layers = nn.ModuleList([]) 24 | 25 | for i in range(len(self.dims)-1): 26 | ip_dim = self.dims[i] 27 | op_dim = self.dims[i+1] 28 | self.layers.append( 29 | nn.Linear(ip_dim, op_dim, bias=True) 30 | ) 31 | 32 | self.__init_net_weights__() 33 | 34 | def __init_net_weights__(self): 35 | 36 | for m in self.layers: 37 | m.weight.data.normal_(0.0, 0.1) 38 | m.bias.data.fill_(0.1) 39 | 40 | def forward(self, x): 41 | 42 | x = x.view(-1, self.input_dim) 43 | 44 | for i, layer in enumerate(self.layers): 45 | x = layer(x) 46 | 47 | # Do not apply ReLU on the final layer 48 | if i < (len(self.layers) - 1): 49 | x = F.relu(x) 50 | 51 | if i < (len(self.layers) - 1): # No dropout on output layer 52 | x = F.dropout(x, p=self.dropout_p, training=self.training) 53 | 54 | return x --------------------------------------------------------------------------------
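
Usage sketch (editor's addition, not a file in this repository): a minimal example of how matching/pfnm_communication.py and model.py might be wired together for a single matching round over two clients. The helper net_to_batch_weights and all hyperparameter values (sigma_layers, sigma0_layers, gamma_layers, it) are illustrative assumptions, not values taken from this code.

from matching.pfnm_communication import layer_group_descent, build_init
from model import FcNet


def net_to_batch_weights(net):
    # Flatten an FcNet into the [W0, b0, W1, b1, ...] list that
    # layer_group_descent expects; nn.Linear stores weights as (out, in),
    # so each weight matrix is transposed to (in, out).
    params = []
    for layer in net.layers:
        params.append(layer.weight.detach().cpu().numpy().T)
        params.append(layer.bias.detach().cpu().numpy())
    return params


# Two locally "trained" clients with one hidden layer each (sizes are arbitrary).
local_nets = [FcNet(input_dim=784, hidden_dims=[100], output_dim=10) for _ in range(2)]
batch_weights = [net_to_batch_weights(net) for net in local_nets]

# One round of layer-wise matching; batch_frequencies=None makes the
# last-layer constants uniform across output classes.
global_weights, assignments = layer_group_descent(
    batch_weights,
    batch_frequencies=None,
    sigma_layers=1.0,
    sigma0_layers=1.0,
    gamma_layers=7.0,
    it=5,
)

# Re-initialise client 0 from the matched global model for the next round.
client0_init = build_init(global_weights, assignments, j=0)
print([w.shape for w in client0_init])

The matched global model (global_weights) and the per-client assignments can then be fed back as assignments_old in a subsequent call, which is how the communication variant reuses earlier matchings across rounds.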