├── .gitignore ├── KE_model.py ├── LICENSE ├── README.md ├── configs └── base_config.py ├── constants.py ├── data ├── __init__.py ├── aircraft.py ├── custom_dataset.py ├── flower.py └── py_transform.py ├── imgs ├── cvpr_2021.gif └── dense_slim.jpg ├── layers ├── CS_KD.py ├── __init__.py ├── bn_type.py ├── conv_type.py ├── linear_type.py └── normalize_layer.py ├── models ├── __init__.py ├── builder.py ├── common.py ├── split_densenet.py ├── split_googlenet.py └── split_resnet.py ├── sample_runs └── aircraft │ ├── SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3.csv │ └── train_log.txt ├── train_KE_cls.py ├── trainers ├── __init__.py └── default_cls.py └── utils ├── __init__.py ├── csv_utils.py ├── eval_utils.py ├── gpu_utils.py ├── log_utils.py ├── logging.py ├── model_profile.py ├── net_utils.py ├── os_utils.py ├── path_utils.py └── schedulers.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .idea/ 3 | 4 | *.pyc 5 | 6 | honda_test.py 7 | -------------------------------------------------------------------------------- /KE_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import data 4 | import torch 5 | import random 6 | import importlib 7 | import torch.optim 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.utils.data 11 | import torch.nn.parallel 12 | from utils import net_utils 13 | from utils import csv_utils 14 | from utils import gpu_utils 15 | from utils import path_utils 16 | # from layers import ml_losses 17 | from datetime import timedelta 18 | import torch.utils.data.distributed 19 | from utils.schedulers import get_policy 20 | from torch.utils.tensorboard import SummaryWriter 21 | from utils.logging import AverageMeter, ProgressMeter 22 | 23 | 24 | 25 | def ke_cls_train(cfg, model,generation): 26 | cfg.logger.info(cfg) 27 | if cfg.seed is 
not None: 28 | random.seed(cfg.seed) 29 | torch.manual_seed(cfg.seed) 30 | torch.cuda.manual_seed(cfg.seed) 31 | torch.cuda.manual_seed_all(cfg.seed) 32 | 33 | # train, validate, modifier = get_trainer(cfg) 34 | train, validate = get_trainer(cfg) 35 | 36 | if cfg.gpu is not None: 37 | cfg.logger.info("Use GPU: {} for training".format(cfg.gpu)) 38 | 39 | # if cfg.pretrained: 40 | # net_utils.load_pretrained(cfg.pretrained,cfg.multigpu[0], model) 41 | 42 | optimizer = get_optimizer(cfg, model) 43 | cfg.logger.info(f"=> Getting {cfg.set} dataset") 44 | dataset = getattr(data, cfg.set)(cfg) 45 | 46 | lr_policy = get_policy(cfg.lr_policy)(optimizer, cfg) 47 | 48 | if cfg.label_smoothing is None: 49 | softmax_criterion = nn.CrossEntropyLoss().cuda() 50 | else: 51 | softmax_criterion = net_utils.LabelSmoothing(smoothing=cfg.label_smoothing).cuda() 52 | 53 | 54 | criterion = lambda output,target: softmax_criterion(output, target) 55 | 56 | 57 | # optionally resume from a checkpoint 58 | best_val_acc1 = 0.0 59 | best_val_acc5 = 0.0 60 | best_train_acc1 = 0.0 61 | best_train_acc5 = 0.0 62 | 63 | if cfg.resume: 64 | best_val_acc1 = resume(cfg, model, optimizer) 65 | 66 | # Data loading code 67 | # if cfg.evaluate: 68 | # last_val_acc1, last_val_acc5 = validate( 69 | # dataset.val_loader, model, criterion, cfg, writer=None, epoch=cfg.start_epoch 70 | # ) 71 | # 72 | # return 73 | 74 | run_base_dir, ckpt_base_dir, log_base_dir = path_utils.get_directories(cfg,generation) 75 | cfg.ckpt_base_dir = ckpt_base_dir 76 | 77 | writer = SummaryWriter(log_dir=log_base_dir) 78 | epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False) 79 | validation_time = AverageMeter("validation_time", ":.4f", write_avg=False) 80 | train_time = AverageMeter("train_time", ":.4f", write_avg=False) 81 | progress_overall = ProgressMeter( 82 | 1, [epoch_time, validation_time, train_time],cfg, prefix="Overall Timing" 83 | ) 84 | 85 | end_epoch = time.time() 86 | cfg.start_epoch = cfg.start_epoch or 0 
87 | last_val_acc1 = None 88 | 89 | 90 | start_time = time.time() 91 | gpu_info = gpu_utils.GPU_Utils(gpu_index=cfg.gpu) 92 | 93 | 94 | # Start training 95 | for epoch in range(cfg.start_epoch, cfg.epochs): 96 | lr_policy(epoch, iteration=None) 97 | # if epoch == cfg.start_epoch: 98 | # modifier(cfg, epoch, model) 99 | 100 | cur_lr = net_utils.get_lr(optimizer) 101 | # print(cur_lr) 102 | # train for one epoch 103 | start_train = time.time() 104 | train_acc1, train_acc5 = train( 105 | dataset.train_loader, model, criterion, optimizer, epoch, cfg, writer=writer 106 | ) 107 | train_time.update((time.time() - start_train) / 60) 108 | 109 | if (epoch+1) % cfg.test_interval == 0: 110 | # evaluate on validation set 111 | start_validation = time.time() 112 | last_val_acc1, last_val_acc5 = validate(dataset.val_loader, model, criterion, cfg, writer, epoch) 113 | validation_time.update((time.time() - start_validation) / 60) 114 | 115 | # remember best acc@1 and save checkpoint 116 | is_best = last_val_acc1 > best_val_acc1 117 | best_val_acc1 = max(last_val_acc1, best_val_acc1) 118 | best_val_acc5 = max(last_val_acc5, best_val_acc5) 119 | best_train_acc1 = max(train_acc1, best_train_acc1) 120 | best_train_acc5 = max(train_acc5, best_train_acc5) 121 | 122 | save = ((epoch % cfg.save_every) == 0) and cfg.save_every > 0 123 | 124 | elapsed_time = time.time() - start_time 125 | seconds_todo = (cfg.epochs - epoch) * (elapsed_time/cfg.test_interval) 126 | estimated_time_complete = timedelta(seconds=int(seconds_todo)) 127 | start_time = time.time() 128 | cfg.logger.info(f"==> ETA: {estimated_time_complete}\tGPU-M: {gpu_info.gpu_mem_usage()}\tGPU-U: {gpu_info.gpu_utilization()}") 129 | if is_best or save or epoch == cfg.epochs - 1: 130 | if is_best: 131 | cfg.logger.info(f"==> best {last_val_acc1:.02f} saving at {ckpt_base_dir / 'model_best.pth'}") 132 | 133 | net_utils.save_checkpoint( 134 | { 135 | "epoch": epoch + 1, 136 | "arch": cfg.arch, 137 | "state_dict": model.state_dict(), 
138 | "best_acc1": best_val_acc1, 139 | "best_acc5": best_val_acc5, 140 | "best_train_acc1": best_train_acc1, 141 | "best_train_acc5": best_train_acc5, 142 | "optimizer": optimizer.state_dict(), 143 | "curr_acc1": last_val_acc1, 144 | "curr_acc5": last_val_acc5, 145 | }, 146 | is_best, 147 | filename=ckpt_base_dir / f"epoch_{epoch}.state", 148 | save=save or epoch == cfg.epochs - 1, 149 | ) 150 | 151 | epoch_time.update((time.time() - end_epoch) / 60) 152 | progress_overall.display(epoch) 153 | progress_overall.write_to_tensorboard( 154 | writer, prefix="diagnostics", global_step=epoch 155 | ) 156 | 157 | 158 | writer.add_scalar("test/lr", cur_lr, epoch) 159 | end_epoch = time.time() 160 | 161 | 162 | if cfg.eval_tst: 163 | last_tst_acc1, last_tst_acc5 = validate(dataset.tst_loader, model, criterion, cfg, writer, 0) 164 | best_tst_acc1 = 0 165 | best_tst_acc5 = 0 166 | # net_utils.load_pretrained(ckpt_base_dir / 'model_best.pth',cfg.multigpu[0],model) 167 | # best_tst_acc1, best_tst_acc5 = validate(dataset.tst_loader, model, criterion, cfg, writer, 0) 168 | else: 169 | last_tst_acc1 = 0 170 | last_tst_acc5 = 0 171 | best_tst_acc1 = 0 172 | best_tst_acc5 = 0 173 | 174 | 175 | csv_utils.write_cls_result_to_csv( 176 | ## Validation 177 | curr_acc1=last_val_acc1, 178 | curr_acc5=last_val_acc5, 179 | best_acc1=best_val_acc1, 180 | best_acc5=best_val_acc5, 181 | 182 | ## Test 183 | last_tst_acc1=last_tst_acc1, 184 | last_tst_acc5=last_tst_acc5, 185 | best_tst_acc1=best_tst_acc1, 186 | best_tst_acc5=best_tst_acc5, 187 | 188 | ## Train 189 | best_train_acc1=best_train_acc1, 190 | best_train_acc5=best_train_acc5, 191 | 192 | 193 | split_rate=cfg.split_rate, 194 | bias_split_rate=cfg.bias_split_rate, 195 | 196 | base_config=cfg.name, 197 | name=cfg.name, 198 | ) 199 | 200 | cfg.logger.info(f"==> Final Best {best_val_acc1:.02f}, saving at {ckpt_base_dir / 'model_best.pth'}") 201 | return ckpt_base_dir ## Do Not return the model because you just reloaded the best one 202 | 203 
| 204 | def get_trainer(args): 205 | args.logger.info(f"=> Using trainer from trainers.{args.trainer}") 206 | trainer = importlib.import_module(f"trainers.{args.trainer}") 207 | 208 | return trainer.train, trainer.validate #, trainer.modifier 209 | 210 | 211 | def resume(args, model, optimizer): 212 | if os.path.isfile(args.resume): 213 | args.logger.info(f"=> Loading checkpoint '{args.resume}'") 214 | 215 | checkpoint = torch.load(args.resume, map_location=f"cuda:{args.gpu}") 216 | if args.start_epoch is None: 217 | args.logger.info(f"=> Setting new start epoch at {checkpoint['epoch']}") 218 | args.start_epoch = checkpoint["epoch"] 219 | 220 | best_acc1 = checkpoint["best_acc1"] 221 | 222 | model.load_state_dict(checkpoint["state_dict"]) 223 | 224 | optimizer.load_state_dict(checkpoint["optimizer"]) 225 | 226 | args.logger.info(f"=> Loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 227 | 228 | return best_acc1 229 | else: 230 | args.logger.info(f"=> No checkpoint found at '{args.resume}'") 231 | 232 | 233 | 234 | 235 | def get_optimizer(args, model,fine_tune=False,criterion=None): 236 | for n, v in model.named_parameters(): 237 | if v.requires_grad: 238 | args.logger.info(" gradient to {}".format(n)) 239 | 240 | if not v.requires_grad: 241 | args.logger.info(" no gradient to {}".format(n)) 242 | 243 | param_groups = model.parameters() 244 | if fine_tune: 245 | # Train Parameters 246 | param_groups = [ 247 | {'params': list( 248 | set(model.parameters()).difference(set(model.model.embedding.parameters()))) if args.gpu != -1 else 249 | list(set(model.module.parameters()).difference(set(model.module.model.embedding.parameters())))}, 250 | { 251 | 'params': model.model.embedding.parameters() if args.gpu != -1 else model.module.model.embedding.parameters(), 252 | 'lr': float(args.lr) * 1}, 253 | ] 254 | if args.ml_loss == 'Proxy_Anchor': 255 | param_groups.append({'params': criterion.proxies, 'lr': float(args.lr) * 100}) 256 | 257 | if args.optimizer == 
"sgd": 258 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 259 | momentum=args.momentum, weight_decay=args.weight_decay) 260 | # parameters = list(model.named_parameters()) 261 | # bn_params = [v for n, v in parameters if ("bn" in n) and v.requires_grad] 262 | # rest_params = [v for n, v in parameters if ("bn" not in n) and v.requires_grad] 263 | # optimizer = torch.optim.SGD( 264 | # [ 265 | # { 266 | # "params": bn_params, 267 | # "weight_decay": 0 if args.no_bn_decay else args.weight_decay, 268 | # }, 269 | # {"params": rest_params, "weight_decay": args.weight_decay}, 270 | # ], 271 | # args.lr, 272 | # momentum=args.momentum, 273 | # weight_decay=args.weight_decay, 274 | # nesterov=args.nesterov, 275 | # ) 276 | elif args.optimizer == "adam": 277 | optimizer = torch.optim.Adam( 278 | filter(lambda p: p.requires_grad, param_groups), lr=args.lr 279 | ) 280 | elif args.optimizer == 'rmsprop': 281 | optimizer = torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9, weight_decay = args.weight_decay, momentum = 0.9) 282 | elif args.optimizer == 'adamw': 283 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay = args.weight_decay) 284 | else: 285 | raise NotImplemented('Invalid Optimizer {}'.format(args.optimizer)) 286 | 287 | return optimizer 288 | 289 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # knowledge_evolution 2 | The official PyTorch implementation of [Knowledge Evolution in Neural Networks](https://arxiv.org/abs/2103.05152) -- CVPR 2021 Oral. 3 | 4 | ### TL;DR 5 | We subclass neural layers and define the mask inside the subclass. When creating a new network, we simply use our [SplitConv](https://github.com/ahmdtaha/knowledge_evolution/blob/main/layers/conv_type.py) and [SplitLinear](https://github.com/ahmdtaha/knowledge_evolution/blob/main/layers/linear_type.py) instead of the standard nn.Conv2d and nn.Linear. 
6 | 7 | * To extract a slim model, I create a dummy slim network with `slim_factor=split_rate` ([such as](https://github.com/ahmdtaha/knowledge_evolution/blob/b389041ff4dd308dd8e35ebb2b01e5863d7a6924/train_KE_cls.py#L80)), then I call `utils.net_utils.extract_slim` ([such as](https://github.com/ahmdtaha/knowledge_evolution/blob/b389041ff4dd308dd8e35ebb2b01e5863d7a6924/train_KE_cls.py#L90)) to copy the fit-hypothesis weights from the dense network into the slim network. 8 | 9 | ![Knowledge Evolution GIF](./imgs/cvpr_2021.gif) 10 | 11 | ## Requirements 12 | 13 | * Python 3+ [Tested on 3.7] 14 | * PyTorch 1.X [Tested on torch 1.6.0 and torchvision 0.6.0] 15 | 16 | [//]: # "## ImageNet Pretrained Models" 17 | 18 | 19 | 20 | ## Usage example 21 | 22 | First, update `constants.py` with your dataset dir and checkpoints dir. 23 | 24 | To train a model, run `python train_KE_cls.py`. 25 | 26 | The default hyperparameters are already hard-coded in the Python script. However, these hyperparameters can be overridden by providing at least one parameter when running the script (e.g., `python train_KE_cls.py --name exp_name`). 27 | 28 | 29 | The Flower102Pytorch loader (`data/flower.py`) works directly with this [Flower102 dataset](https://drive.google.com/file/d/1KU0SWPYRFk8SAY1IF-8JvkBnzC_PBUrw/view?usp=sharing). This is the original Flower102, but with an extra `list` directory that contains csv files for the trn, val and tst splits. Feel free to download the flower dataset from the [Oxford website](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/); just update `data/flower.py` accordingly. 30 | 31 | The following table shows knowledge evolution in both the dense (even rows) and slim (odd rows) networks using Flower102 on ResNet18. 32 | As the number of generations increases, the performance of both the dense and slim networks increases.
33 | 34 | ![Our implementation performance](./imgs/dense_slim.jpg) 35 | 36 | ### TODO LIST 37 | * Document the important flags 38 | * Add TL;DR 39 | 40 | Contributor list 41 | ---------------- 42 | 1. [Ahmed Taha](http://www.cs.umd.edu/~ahmdtaha/) 43 | 44 | I want to give credit to [Vivek Ramanujan and Mitchell Wortsman's repos](https://github.com/allenai/hidden-networks). My implementation uses a lot of utils and functions from their code. 45 | 46 | ### Further Contributions 47 | 1. It would be great if someone re-implemented this in TensorFlow. Let me know and I will add a link to your TensorFlow implementation here. 48 | 2. Pull requests to support more architectures (e.g., WideResNet, MobileNet, etc.) are welcome. 49 | 3. I welcome coding tips to improve the code quality. 50 | 4. I will add the names/links of those who contribute to this repository to the contributor list. 51 | 52 | 53 | ### MISC Notes 54 | * This repository delivers a knowledge evolution implementation in a _**minimal**_ form. Accordingly, I disabled the CS_KD baseline because it requires a specific sampling implementation. 55 | * This implementation _looks_ complex because it supports concatenation layers in DenseNet and GoogLeNet. For those interested in ResNets only, please use Algorithm 1 presented in the paper. 56 | * This repository is a "clean" version of my original implementation. This repo focuses on the core idea and omits any research-exploration snippets. If a function is missing, please let me know.
57 | 58 | 59 | ## Release History 60 | * 1.0.0 61 | * First code commit on 10 Dec 2020 62 | * Submit/Test Split_googlenet code commit on 19 Dec 2020 63 | * Submit/Test Split_densenet code commit on 20 Dec 2020 64 | * Repository made public on 1 Mar 2021 65 | * Add Aircarfts100 dataset loader 31 Jul 2021 -- the lists dir is available [here](https://drive.google.com/file/d/1lyaro579PRFAIyBZoPpLBjCzKdvsfEH2/view?usp=sharing) 66 | 67 | 68 | ### Citation 69 | ``` 70 | @inproceedings{taha2021knowledge, 71 | title={Knowledge Evolution in Neural Networks}, 72 | author={Taha, Ahmed and Shrivastava, Abhinav and Davis, Larry}, 73 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 74 | year={2021} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | import argparse 5 | import os.path as osp 6 | import logging.config 7 | from utils import os_utils 8 | from utils import log_utils 9 | from utils import path_utils 10 | # from configs import parser as _parser 11 | 12 | args = None 13 | 14 | class Config: 15 | def __init__(self): 16 | parser = argparse.ArgumentParser(description="Knowledge Evolution Training Approach") 17 | 18 | # General Config 19 | parser.add_argument( 20 | "--data", help="path to dataset base directory", default="/mnt/disk1/datasets" 21 | ) 22 | 23 | parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd") 24 | parser.add_argument("--set", help="only Flower102Pytorch is currently supported", 25 | type=str, default="Flower102Pytorch", 26 | choices=['Flower102Pytorch','Aircraft100Pytorch']) 27 | 28 | parser.add_argument( 29 | "-a", "--arch", metavar="ARCH", default="Split_ResNet18", help="model architecture", 30 | 
choices=['Split_ResNet18','Split_ResNet34','Split_ResNet50','Split_ResNet101', 31 | 'Split_googlenet', 32 | 'Split_densenet121', 'Split_densenet161', 'Split_densenet169', 'Split_densenet201', 33 | ] 34 | ) 35 | parser.add_argument( 36 | "--config_file", help="Config file to use (see configs dir)", default=None 37 | ) 38 | parser.add_argument( 39 | "--log-dir", help="Where to save the runs. If None use ./runs", default=None 40 | ) 41 | 42 | parser.add_argument( 43 | '--evolve_mode', default='rand', choices=['rand'], 44 | help='How to initialize the reset-hypothesis.') 45 | 46 | parser.add_argument( 47 | "-t", 48 | "--num_threads", 49 | default=8, 50 | type=int, 51 | metavar="N", 52 | help="number of data loading workers (default: 20)", 53 | ) 54 | parser.add_argument( 55 | "--epochs", 56 | default=90, 57 | type=int, 58 | metavar="N", 59 | help="number of total epochs to run", 60 | ) 61 | parser.add_argument( 62 | "--start-epoch", 63 | default=None, 64 | type=int, 65 | metavar="N", 66 | help="manual epoch number (useful on restarts)", 67 | ) 68 | parser.add_argument( 69 | "-b", 70 | "--batch_size", 71 | default=256, 72 | type=int, 73 | metavar="N", 74 | help="mini-batch size (default: 256), this is the total " 75 | "batch size of all GPUs on the current node when " 76 | "using Data Parallel or Distributed Data Parallel", 77 | ) 78 | parser.add_argument( 79 | "--lr", 80 | "--learning-rate", 81 | default=0.1, 82 | type=float, 83 | metavar="LR", 84 | help="initial learning rate", 85 | dest="lr", 86 | ) 87 | parser.add_argument( 88 | "--warmup_length", default=0, type=int, help="Number of warmup iterations" 89 | ) 90 | parser.add_argument( 91 | "--momentum", default=0.9, type=float, metavar="M", help="momentum" 92 | ) 93 | parser.add_argument( 94 | "--wd", 95 | "--weight_decay", 96 | default=1e-4, 97 | type=float, 98 | metavar="W", 99 | help="weight decay (default: 1e-4)", 100 | dest="weight_decay", 101 | ) 102 | parser.add_argument( 103 | "-p", 104 | "--print-freq", 105 
| default=10000, 106 | type=int, 107 | metavar="N", 108 | help="print frequency (default: 10)", 109 | ) 110 | parser.add_argument('--bn_freeze', default=1, type=int, 111 | help='Batch normalization parameter freeze' 112 | ) 113 | parser.add_argument('--samples_per_class', default=1, type=int, 114 | help='Number of samples per class inside a mini-batch.' 115 | ) 116 | parser.add_argument('--alpha', default=32, type=float, 117 | help='Scaling Parameter setting' 118 | ) 119 | parser.add_argument('--warm', default=1, type=int, 120 | help='Warmup training epochs' 121 | ) 122 | parser.add_argument( 123 | "--resume", 124 | default="", 125 | type=str, 126 | metavar="PATH", 127 | help="path to latest checkpoint (default: none)", 128 | ) 129 | 130 | parser.add_argument( 131 | "--pretrained", 132 | dest="pretrained", 133 | default=None, 134 | type=str, 135 | help="use pre-trained model", 136 | ) 137 | parser.add_argument( 138 | "--seed", default=None, type=int, help="seed for initializing training. " 139 | ) 140 | 141 | 142 | parser.add_argument( 143 | "--world_size", 144 | default=1, 145 | type=int, 146 | help="Pytorch DDP world size", 147 | ) 148 | 149 | 150 | parser.add_argument( 151 | "--gpu", 152 | default='0', 153 | type=int, 154 | help="Which GPUs to use?", 155 | ) 156 | parser.add_argument( 157 | "--test_interval", default=10, type=int, help="Eval on tst/val split every ? epochs" 158 | ) 159 | 160 | # Learning Rate Policy Specific 161 | parser.add_argument( 162 | "--lr_policy", default="constant_lr", help="Policy for the learning rate." 
163 | ) 164 | parser.add_argument( 165 | "--multistep-lr-adjust", default=30, type=int, help="Interval to drop lr" 166 | ) 167 | parser.add_argument("--multistep-lr-gamma", default=0.1, type=int, help="Multistep multiplier") 168 | parser.add_argument( 169 | "--name", default=None, type=str, help="Experiment name to append to filepath" 170 | ) 171 | parser.add_argument( 172 | "--log_file", default='train_log.txt', type=str, help="Experiment name to append to filepath" 173 | ) 174 | parser.add_argument( 175 | "--save_every", default=-1, type=int, help="Save every ___ epochs" 176 | ) 177 | parser.add_argument( 178 | "--num_generations", default=100, type=int, help="Task Mask number of generations" 179 | ) 180 | parser.add_argument('--lr-decay-step', default=10, type=int,help='Learning decay step setting') 181 | parser.add_argument('--lr-decay-gamma', default=0.5, type=float,help='Learning decay gamma setting') 182 | parser.add_argument( 183 | "--split_rate", 184 | default=1.0, 185 | help="What is the split-rate for the split-network weights?", 186 | type=float, 187 | ) 188 | parser.add_argument( 189 | "--bias_split_rate", 190 | default=1.0, 191 | help="What is the bias split-rate for the split-network weights?", 192 | type=float, 193 | ) 194 | 195 | parser.add_argument( 196 | "--slim_factor", 197 | default=1.0, 198 | help="This variable is used to extract a slim network from a dense network. 
" 199 | "It is initialized using the split_rate of the trained dense network.", 200 | type=float, 201 | ) 202 | parser.add_argument( 203 | "--split_mode", 204 | default="kels", 205 | choices=['kels','wels'], 206 | help="how to split the binary mask", 207 | ) 208 | parser.add_argument( 209 | "--conv_type", type=str, default='SplitConv', help="SplitConv | DenseConv" 210 | ) 211 | parser.add_argument( 212 | "--linear_type", type=str, default='SplitLinear', help="SplitLinear | DenseLinear" 213 | ) 214 | parser.add_argument("--mode", default="fan_in", help="Weight initialization mode") 215 | parser.add_argument( 216 | "--nonlinearity", default="relu", help="Nonlinearity used by initialization" 217 | ) 218 | parser.add_argument("--bn_type", default='SplitBatchNorm', help="BatchNorm type", 219 | choices=['NormalBatchNorm','NonAffineBatchNorm','SplitBatchNorm']) 220 | parser.add_argument( 221 | "--init", default="kaiming_normal", help="Weight initialization modifications" 222 | ) 223 | parser.add_argument( 224 | "--no-bn-decay", action="store_true", default=False, help="No batchnorm decay" 225 | ) 226 | parser.add_argument( 227 | "--scale-fan", action="store_true", default=False, help="scale fan" 228 | ) 229 | 230 | parser.add_argument("--cs_kd", action="store_true", default=False, help="Enable Cls_KD") 231 | parser.add_argument("--reset_mask", action="store_true", default=False, help="Reset mask?") 232 | parser.add_argument("--reset_hypothesis", action="store_true", default=False, help="Reset hypothesis across generations") 233 | 234 | parser.add_argument( 235 | "--label_smoothing", 236 | type=float, 237 | help="Label smoothing to use, default 0.0", 238 | default=None, 239 | ) 240 | 241 | parser.add_argument( 242 | "--trainer", type=str, default="default", help="cs, ss, or standard training" 243 | ) 244 | 245 | 246 | self.parser = parser 247 | 248 | def parse(self,args): 249 | self.cfg = self.parser.parse_args(args) 250 | 251 | # Allow for use from notebook without config 
file 252 | # self.read_config_file() 253 | # self.read_cmd_args() 254 | 255 | if self.cfg.set == 'Flower102' or self.cfg.set == 'Flower102Pytorch': 256 | self.cfg.num_cls = 102 257 | self.cfg.eval_tst = True 258 | elif self.cfg.set == 'CUB200': 259 | self.cfg.num_cls = 200 260 | self.cfg.eval_tst = False 261 | elif self.cfg.set == 'ImageNet': 262 | self.cfg.num_cls = 1000 263 | self.cfg.eval_tst = False 264 | elif self.cfg.set == 'FCAMD': 265 | self.cfg.num_cls = 250 266 | self.cfg.eval_tst = False 267 | elif self.cfg.set == 'CUB200_RET': 268 | self.cfg.num_cls = self.cfg.emb_dim 269 | self.cfg.eval_tst = False 270 | elif self.cfg.set == 'CARS_RET': 271 | self.cfg.num_cls = self.cfg.emb_dim 272 | self.cfg.eval_tst = False 273 | elif self.cfg.set == 'Dog120': 274 | self.cfg.num_cls = 120 275 | self.cfg.eval_tst = False 276 | elif self.cfg.set in ['MIT67']: 277 | self.cfg.num_cls = 67 278 | self.cfg.eval_tst = False 279 | elif self.cfg.set == 'Aircraft100' or self.cfg.set == 'Aircraft100Pytorch': 280 | self.cfg.num_cls = 100 281 | self.cfg.eval_tst = True 282 | else: 283 | raise NotImplementedError('Invalid dataset {}'.format(self.cfg.set)) 284 | 285 | if self.cfg.cs_kd: 286 | self.cfg.samples_per_class = 2 287 | 288 | self.cfg.exp_dir = osp.join(path_utils.get_checkpoint_dir() , self.cfg.name) 289 | 290 | os_utils.touch_dir(self.cfg.exp_dir) 291 | log_file = os.path.join(self.cfg.exp_dir, self.cfg.log_file) 292 | logging.config.dictConfig(log_utils.get_logging_dict(log_file)) 293 | self.cfg.logger = logging.getLogger('KE') 294 | 295 | return self.cfg 296 | 297 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | datasets_dir = '/mnt/data/datasets' 2 | checkpoints_dir = '/mnt/data/checkpoints' -------------------------------------------------------------------------------- /data/__init__.py: 
-------------------------------------------------------------------------------- 1 | from data.flower import Flower102Pytorch 2 | from data.aircraft import Aircraft100Pytorch 3 | -------------------------------------------------------------------------------- /data/aircraft.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import os.path as osp 6 | from utils import path_utils 7 | from data.custom_dataset import CustomDataset 8 | 9 | 10 | class Aircraft100Pytorch: 11 | 12 | def __init__(self, cfg): 13 | 14 | db_path = path_utils.get_datasets_dir(cfg.set) 15 | self.img_path = db_path + '/images/' 16 | 17 | csv_file = '/lists/trn.csv' 18 | trn_data_df = pd.read_csv(db_path + csv_file) 19 | 20 | lbls = trn_data_df['label'] 21 | lbl2idx = np.sort(np.unique(lbls)) 22 | self.lbl2idx_dict = {k: v for v, k in enumerate(lbl2idx)} 23 | self.final_lbls = [self.lbl2idx_dict[x] for x in list(lbls.values)] 24 | 25 | self.num_classes = len(self.lbl2idx_dict.keys()) 26 | 27 | self.train_loader = self.create_loader(csv_file, cfg, is_training=True) 28 | 29 | csv_file = '/lists/tst.csv' 30 | self.tst_loader = self.create_loader(csv_file, cfg, is_training=False) 31 | 32 | csv_file = '/lists/val.csv' 33 | self.val_loader = self.create_loader(csv_file, cfg, is_training=False) 34 | 35 | def create_loader(self, imgs_lst, cfg, is_training): 36 | db_path = path_utils.get_datasets_dir(cfg.set) 37 | if osp.exists(db_path + imgs_lst): 38 | data_df = pd.read_csv(db_path + imgs_lst) 39 | imgs, lbls = self.imgs_and_lbls(data_df) 40 | epoch_size = len(imgs) 41 | loader = torch.utils.data.DataLoader(CustomDataset(imgs, lbls, is_training=is_training), 42 | batch_size=cfg.batch_size, shuffle=is_training, 43 | num_workers=cfg.num_threads) 44 | 45 | loader.num_batches = math.ceil(epoch_size / cfg.batch_size) 46 | loader.num_files = epoch_size 47 | else: 48 | loader = None 49 | 50 | 
return loader 51 | 52 | def imgs_and_lbls(self, data_df): 53 | """ 54 | Load images' paths and int32 labels 55 | :param repeat: This is similar to TF.data.Dataset repeat. I use TF dataset repeat and no longer user this params. 56 | So its default is False 57 | 58 | :return: a list of images' paths and their corresponding int32 labels 59 | """ 60 | 61 | imgs = data_df 62 | ## Faster way to read data 63 | images = imgs['file_name'].tolist() 64 | lbls = imgs['label'].tolist() 65 | for img_idx in range(imgs.shape[0]): 66 | images[img_idx] = self.img_path + images[img_idx] 67 | lbls[img_idx] = self.lbl2idx_dict[lbls[img_idx]] 68 | 69 | return images, lbls -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torchvision import transforms 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class CustomDataset(Dataset): 8 | """Face Landmarks dataset.""" 9 | 10 | def __init__(self, imgs_path, lbls,is_training): 11 | self.imgs_path = imgs_path 12 | self.lbls = lbls 13 | #self.idx = list(range(0,len(lbls))) 14 | if is_training: 15 | self.transform = transforms.Compose([ 16 | # transforms.Resize(256), 17 | transforms.RandomResizedCrop(224), 18 | transforms.RandomHorizontalFlip(), 19 | transforms.ToTensor(), 20 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 21 | ]) 22 | else: 23 | self.transform = transforms.Compose([ 24 | transforms.Resize(256), 25 | transforms.CenterCrop(224), 26 | transforms.ToTensor(), 27 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | 31 | def __len__(self): 32 | return len(self.imgs_path) 33 | 34 | def __getitem__(self, idx): 35 | if torch.is_tensor(idx): 36 | idx = idx.tolist() 37 | if self.transform: 38 | img = Image.open(self.imgs_path[idx]) 39 | if len(img.mode) != 3 or len(img.getbands())!=3: 40 | # 
print(len(img.mode),len(img.getbands())) 41 | # print(self.imgs_path[idx]) 42 | img = img.convert('RGB') 43 | 44 | img = self.transform(img) 45 | # sample = {'image': self.transform(img), 'label': self.lbls[idx],'index': idx} 46 | else: 47 | img = Image.open(self.imgs_path[idx]) 48 | # sample = {'image': Image.open(self.imgs_path[idx]), 'label': self.lbls[idx],'index': idx} 49 | 50 | return img,self.lbls[idx] 51 | -------------------------------------------------------------------------------- /data/flower.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import os.path as osp 6 | from utils import path_utils 7 | from data.custom_dataset import CustomDataset 8 | 9 | 10 | class Flower102Pytorch: 11 | 12 | 13 | def __init__(self, cfg): 14 | 15 | db_path = path_utils.get_datasets_dir(cfg.set) 16 | self.img_path = db_path + '/jpg/' 17 | 18 | csv_file = '/lists/trn.csv' 19 | trn_data_df = pd.read_csv(db_path + csv_file) 20 | 21 | lbls = trn_data_df['label'] 22 | lbl2idx = np.sort(np.unique(lbls)) 23 | self.lbl2idx_dict = {k: v for v, k in enumerate(lbl2idx)} 24 | self.final_lbls = [self.lbl2idx_dict[x] for x in list(lbls.values)] 25 | 26 | self.num_classes = len(self.lbl2idx_dict.keys()) 27 | 28 | 29 | 30 | self.train_loader = self.create_loader(csv_file, cfg,is_training=True) 31 | 32 | csv_file = '/lists/tst.csv' 33 | self.tst_loader = self.create_loader(csv_file,cfg,is_training=False) 34 | 35 | csv_file = '/lists/val.csv' 36 | self.val_loader = self.create_loader(csv_file,cfg,is_training=False) 37 | 38 | 39 | def create_loader(self,imgs_lst,cfg,is_training): 40 | db_path = path_utils.get_datasets_dir(cfg.set) 41 | if osp.exists(db_path + imgs_lst): 42 | data_df = pd.read_csv(db_path + imgs_lst) 43 | imgs, lbls = self.imgs_and_lbls(data_df) 44 | epoch_size = len(imgs) 45 | loader = torch.utils.data.DataLoader(CustomDataset(imgs, lbls, 
is_training=is_training), 46 | batch_size=cfg.batch_size, shuffle=is_training, 47 | num_workers=cfg.num_threads) 48 | 49 | loader.num_batches = math.ceil(epoch_size / cfg.batch_size) 50 | loader.num_files = epoch_size 51 | else: 52 | loader = None 53 | 54 | return loader 55 | 56 | def imgs_and_lbls(self,data_df): 57 | """ 58 | Load images' paths and int32 labels 59 | :param repeat: This is similar to TF.data.Dataset repeat. I use TF dataset repeat and no longer user this params. 60 | So its default is False 61 | 62 | :return: a list of images' paths and their corresponding int32 labels 63 | """ 64 | 65 | imgs = data_df 66 | ## Faster way to read data 67 | images = imgs['file_name'].tolist() 68 | lbls = imgs['label'].tolist() 69 | for img_idx in range(imgs.shape[0]): 70 | images[img_idx] = self.img_path + images[img_idx] 71 | lbls[img_idx] = self.lbl2idx_dict[lbls[img_idx]] 72 | 73 | 74 | return images, lbls 75 | 76 | -------------------------------------------------------------------------------- /data/py_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageOps 2 | import torchvision.transforms as T 3 | 4 | class Transform_single(): 5 | def __init__(self, cfg ,image_size, train, mean_std): 6 | if train == True: 7 | self.transform = T.Compose([ 8 | T.RandomResizedCrop(image_size, scale=(0.2, 1.0)), 9 | T.RandomHorizontalFlip(), 10 | T.ToTensor(), 11 | T.Normalize(*mean_std) 12 | ]) 13 | else: 14 | self.transform = T.Compose([ 15 | T.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256 16 | T.Resize(int(image_size*(8/7))), # 224 -> 256 17 | T.CenterCrop(image_size), 18 | T.ToTensor(), 19 | T.Normalize(*mean_std) 20 | ]) 21 | def __call__(self, x): 22 | return self.transform(x) -------------------------------------------------------------------------------- /imgs/cvpr_2021.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ahmdtaha/knowledge_evolution/a3f2eb2eed7accb86ad1af2a15c13e4a9654fe16/imgs/cvpr_2021.gif -------------------------------------------------------------------------------- /imgs/dense_slim.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmdtaha/knowledge_evolution/a3f2eb2eed7accb86ad1af2a15c13e4a9654fe16/imgs/dense_slim.jpg -------------------------------------------------------------------------------- /layers/CS_KD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class KDLoss(nn.Module): 5 | def __init__(self, temp_factor): 6 | super(KDLoss, self).__init__() 7 | self.temp_factor = temp_factor 8 | self.kl_div = nn.KLDivLoss(reduction="sum") 9 | 10 | def forward(self, input, target): 11 | log_p = torch.log_softmax(input/self.temp_factor, dim=1) 12 | q = torch.softmax(target/self.temp_factor, dim=1) 13 | loss = self.kl_div(log_p, q)*(self.temp_factor**2)/input.size(0) 14 | return loss -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmdtaha/knowledge_evolution/a3f2eb2eed7accb86ad1af2a15c13e4a9654fe16/layers/__init__.py -------------------------------------------------------------------------------- /layers/bn_type.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | NormalBatchNorm = nn.BatchNorm2d 7 | 8 | 9 | class NonAffineBatchNorm(nn.BatchNorm2d): 10 | def __init__(self, dim,**kwargs): 11 | super(NonAffineBatchNorm, self).__init__(dim, affine=False,**kwargs) 12 | 13 | class SplitBatchNorm(nn.BatchNorm2d): 14 | def __init__(self, dim,**kwargs): 15 | self.in_channels_order = 
kwargs.pop('in_channels_order', None) 16 | split_rate = kwargs.pop('split_rate', None) 17 | super(SplitBatchNorm, self).__init__(dim, affine=True, **kwargs) 18 | 19 | if self.in_channels_order is not None: 20 | assert split_rate is not None, 'Should not be none if in_channels_order is not None' 21 | mask = np.zeros(self.weight.size()[0]) 22 | conv_concat = self.in_channels_order.split(',') 23 | start_ch = 0 24 | for conv in conv_concat: 25 | mask[start_ch:start_ch + math.ceil(int(conv) * split_rate)] = 1 26 | start_ch += int(conv) 27 | 28 | self.bn_mask = nn.Parameter(torch.Tensor(mask), requires_grad=False) 29 | 30 | def extract_slim(self,dst_m): 31 | c_out = self.weight.size()[0] 32 | d_out = dst_m.weight.size()[0] 33 | if self.in_channels_order is None: 34 | assert dst_m.weight.shape == self.weight[:d_out].shape 35 | dst_m.weight.data = self.weight[:d_out] 36 | dst_m.bias.data = self.bias[:d_out] 37 | dst_m.running_mean.data = self.running_mean[:d_out] 38 | dst_m.running_var.data = self.running_var[:d_out] 39 | else: 40 | assert dst_m.weight.shape == self.weight[self.bn_mask == 1].shape 41 | dst_m.weight.data = self.weight[self.bn_mask == 1] 42 | dst_m.bias.data = self.bias.data[self.bn_mask == 1] 43 | dst_m.running_mean.data = self.running_mean[self.bn_mask == 1] 44 | dst_m.running_var.data = self.running_var[self.bn_mask == 1] 45 | -------------------------------------------------------------------------------- /layers/conv_type.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.autograd as autograd 8 | from configs.base_config import args as parser_args 9 | 10 | 11 | DenseConv = nn.Conv2d 12 | 13 | 14 | # Not learning weights, finding subnet 15 | class SplitConv(nn.Conv2d): 16 | def __init__(self, *args, **kwargs): 17 | self.split_mode = kwargs.pop('split_mode', None) 18 | 
self.split_rate = kwargs.pop('split_rate', None) 19 | self.in_channels_order = kwargs.pop('in_channels_order', None) 20 | # self.keep_rate = keep_rate 21 | super().__init__(*args, **kwargs) 22 | 23 | if self.split_mode == 'kels': 24 | if self.in_channels_order is None: 25 | mask = np.zeros((self.weight.size())) 26 | if self.weight.size()[1] == 3: ## This is the first conv 27 | mask[:math.ceil(self.weight.size()[0] * self.split_rate), :, :, :] = 1 28 | else: 29 | mask[:math.ceil(self.weight.size()[0] * self.split_rate), :math.ceil(self.weight.size()[1] * self.split_rate), :, :] = 1 30 | else: 31 | 32 | mask = np.zeros((self.weight.size())) 33 | conv_concat = [int(chs) for chs in self.in_channels_order.split(',')] 34 | # assert sum(conv_concat) == self.weight.size()[1],'In channels {} should be equal to sum(concat) {}'.format(self.weight.size()[1],conv_concat) 35 | start_ch = 0 36 | for conv in conv_concat: 37 | mask[:math.ceil(self.weight.size()[0] * self.split_rate), start_ch:start_ch + math.ceil(conv * self.split_rate), 38 | :, :] = 1 39 | start_ch += conv 40 | 41 | elif self.split_mode == 'wels': 42 | mask = np.random.rand(*list(self.weight.shape)) 43 | # threshold = np.percentile(scores, (1-self.keep_rate)*100) 44 | threshold = 1 - self.split_rate 45 | mask[mask < threshold] = 0 46 | mask[mask >= threshold] = 1 47 | if self.split_rate != 1: 48 | assert len(np.unique(mask)) == 2,'Something is wrong with the mask {}'.format(np.unique(mask)) 49 | else: 50 | raise NotImplemented('Invalid split_mode {}'.format(self.split_mode)) 51 | 52 | self.mask = nn.Parameter(torch.Tensor(mask), requires_grad=False) 53 | 54 | def extract_slim(self,dst_m,src_name,dst_name): 55 | c_out, c_in, _, _, = self.weight.size() 56 | d_out, d_in, _, _ = dst_m.weight.size() 57 | if self.in_channels_order is None: 58 | if c_in == 3: 59 | selected_convs = self.weight[:d_out] 60 | # is_first_conv = False 61 | else: 62 | selected_convs = self.weight[:d_out][:, :d_in, :, :] 63 | 64 | assert 
selected_convs.shape == dst_m.weight.shape 65 | dst_m.weight.data = selected_convs 66 | else: 67 | selected_convs = self.weight[:d_out, self.mask[0, :, 0, 0] == 1, :, :] 68 | assert selected_convs.shape == dst_m.weight.shape, '{} {} {} {}'.format(dst_name, src_name, dst_m.weight.shape, 69 | selected_convs.shape) 70 | dst_m.weight.data = selected_convs 71 | 72 | # def reset_scores(self): 73 | # if self.split_mode == 'wels': 74 | # mask = np.random.rand(*list(self.weight.shape)) 75 | # threshold = 1 - self.split_rate 76 | # mask[mask < threshold] = 0 77 | # mask[mask >= threshold] = 1 78 | # if self.split_rate != 1: 79 | # assert len(np.unique(mask)) == 2,'Something is wrong with the score {}'.format(np.unique(mask)) 80 | # else: 81 | # raise NotImplemented('Reset score randomly only with WELS. The current mode is '.format(self.split_mode)) 82 | # # scores = np.zeros((self.weight.size())) 83 | # # rand_sub = random.randint(0, self.weight.size()[0] - math.ceil(self.weight.size()[0] * self.keep_rate)) 84 | # # if self.weight.size()[1] == 3: ## This is the first conv 85 | # # scores[rand_sub:rand_sub+math.ceil(self.weight.size()[0] * self.keep_rate), :, :, :] = 1 86 | # # else: 87 | # # scores[rand_sub:rand_sub+math.ceil(self.weight.size()[0] * self.keep_rate), :math.ceil(self.weight.size()[1] * self.keep_rate), :, 88 | # # :] = 1 89 | # 90 | # self.mask.data = torch.Tensor(mask).cuda() 91 | # # raise NotImplemented('Not implemented yet') 92 | # # nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) 93 | 94 | # def reset_bias_scores(self): 95 | # pass 96 | 97 | # def set_split_rate(self, split_rate, bias_split_rate): 98 | # self.split_rate = split_rate 99 | # if self.bias is not None: 100 | # self.bias_split_rate = bias_split_rate 101 | # else: 102 | # self.bias_split_rate = 1.0 103 | 104 | def split_reinitialize(self,cfg): 105 | if cfg.evolve_mode == 'rand': 106 | rand_tensor = torch.zeros_like(self.weight).cuda() 107 | nn.init.kaiming_uniform_(rand_tensor, 
a=math.sqrt(5)) 108 | self.weight.data = torch.where(self.mask.type(torch.bool), self.weight.data, rand_tensor) 109 | else: 110 | raise NotImplemented('Invalid KE mode {}'.format(cfg.evolve_mode)) 111 | 112 | if hasattr(self, "bias") and self.bias is not None and self.bias_split_rate < 1.0: 113 | bias_mask = self.mask[:, 0, 0, 0] ## Same conv mask is used for bias terms 114 | if cfg.evolve_mode == 'rand': 115 | rand_tensor = torch.zeros_like(self.bias) 116 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 117 | bound = 1 / math.sqrt(fan_in) 118 | nn.init.uniform_(rand_tensor, -bound, bound) 119 | self.bias.data = torch.where(bias_mask.type(torch.bool), self.bias.data, rand_tensor) 120 | else: 121 | raise NotImplemented('Invalid KE mode {}'.format(cfg.evolve_mode)) 122 | 123 | def forward(self, x): 124 | ## Debugging reasons only 125 | # if self.split_rate < 1: 126 | # w = self.mask * self.weight 127 | # if self.bias_split_rate < 1: 128 | # # bias_subnet = GetSubnet.apply(self.clamped_bias_scores, self.bias_keep_rate) 129 | # b = self.bias * self.mask[:, 0, 0, 0] 130 | # else: 131 | # b = self.bias 132 | # else: 133 | # w = self.weight 134 | # b = self.bias 135 | 136 | w = self.weight 137 | b = self.bias 138 | x = F.conv2d( 139 | x, w, b, self.stride, self.padding, self.dilation, self.groups 140 | ) 141 | return x 142 | 143 | -------------------------------------------------------------------------------- /layers/linear_type.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.autograd as autograd 7 | 8 | DenseLinear = nn.Linear 9 | 10 | 11 | 12 | class SplitLinear(nn.Linear): 13 | def __init__(self, *args, **kwargs): 14 | self.split_mode = kwargs.pop('split_mode', None) 15 | split_rate = kwargs.pop('split_rate', None) 16 | last_layer = kwargs.pop('last_layer', None) 17 | 
self.in_channels_order = kwargs.pop('in_channels_order', None) 18 | 19 | self.split_rate = split_rate 20 | self.bias_split_rate = self.split_rate 21 | super().__init__(*args, **kwargs) 22 | 23 | ## AT : I am assuming a single FC layer in the network. Typical for most CNNs 24 | if self.split_mode == 'kels': 25 | if self.in_channels_order is None: 26 | if last_layer: 27 | active_in_dim = math.ceil(self.weight.size()[1] * split_rate) 28 | mask = np.zeros((self.weight.size()[0],self.weight.size()[1])) 29 | mask[:,:active_in_dim] = 1 30 | else: 31 | active_in_dim = math.ceil(self.weight.size()[1] * split_rate) 32 | active_out_dim = math.ceil(self.weight.size()[0] * split_rate) 33 | mask = np.zeros((self.weight.size()[0], self.weight.size()[1])) 34 | mask[:active_out_dim, :active_in_dim] = 1 35 | else: 36 | mask = np.zeros((self.weight.size()[0], self.weight.size()[1])) 37 | conv_concat = self.in_channels_order.split(',') 38 | start_ch = 0 39 | for conv in conv_concat: 40 | mask[:,start_ch:start_ch + math.ceil(int(conv) * split_rate)] = 1 41 | start_ch += int(conv) 42 | 43 | elif self.split_mode == 'wels': 44 | mask = np.random.rand(*list(self.weight.shape)) 45 | # threshold = np.percentile(scores, (1 - self.keep_rate) * 100) 46 | threshold = 1 - self.split_rate 47 | mask[mask < threshold] = 0 48 | mask[mask >= threshold] = 1 49 | if self.split_rate != 1: 50 | assert len(np.unique(mask)) == 2, 'Something is wrong with the mask {}'.format(np.unique(mask)) 51 | else: 52 | raise NotImplemented('Invalid split_mode {}'.format(self.split_mode)) 53 | 54 | self.mask = nn.Parameter(torch.Tensor(mask), requires_grad=False) 55 | 56 | # self.reset_scores() 57 | 58 | # def set_keep_rate(self, keep_rate, bias_keep_rate): 59 | # self.split_rate = keep_rate 60 | # self.bias_keep_rate = bias_keep_rate 61 | 62 | # def reset_scores(self): 63 | # if self.split_mode == 'wels': 64 | # scores = np.random.rand(*list(self.weight.shape)) 65 | # # threshold = np.percentile(scores, (1 - 
self.keep_rate) * 100) 66 | # threshold = 1 - self.split_rate 67 | # scores[scores < threshold] = 0 68 | # scores[scores >= threshold] = 1 69 | # if self.split_rate != 1: 70 | # assert len(np.unique(scores)) == 2, 'Something is wrong with the score {}'.format(np.unique(scores)) 71 | # else: 72 | # raise NotImplemented('Reset score randomly only with WELS. The current mode is '.format(self.split_mode)) 73 | # # active_in_dim = math.ceil(self.weight.size()[1] * self.keep_rate) 74 | # # rand_sub = random.randint(0, self.weight.size()[1] - active_in_dim) 75 | # # scores = np.zeros((self.weight.size()[0], self.weight.size()[1])) 76 | # # scores[:, rand_sub:rand_sub+active_in_dim] = 1 77 | # self.scores.data = torch.Tensor(scores).cuda() 78 | 79 | 80 | # def reset_bias_scores(self): 81 | # pass 82 | 83 | 84 | def extract_slim(self,dst_m,src_name,dst_name): 85 | c_out, c_in = self.weight.size() 86 | d_out, d_in = dst_m.weight.size() 87 | 88 | if self.in_channels_order is None: 89 | assert dst_m.weight.shape == self.weight[:d_out, :d_in].shape 90 | dst_m.weight.data = self.weight.data[:d_out, :d_in] 91 | assert dst_m.bias.data.shape == self.bias.data[:d_out].shape 92 | dst_m.bias.data = self.bias.data[:d_out] 93 | else: 94 | dst_m.weight.data = self.weight[:d_out, self.mask[0, :] == 1] 95 | dst_m.bias.data = self.bias.data[:d_out] 96 | 97 | def split_reinitialize(self, cfg): 98 | if cfg.evolve_mode == 'rand': 99 | rand_tensor = torch.zeros_like(self.weight).cuda() 100 | nn.init.kaiming_uniform_(rand_tensor, a=math.sqrt(5)) 101 | self.weight.data = torch.where(self.mask.type(torch.bool), self.weight.data, rand_tensor) 102 | else: 103 | raise NotImplemented('Invalid KE mode {}'.format(cfg.evolve_mode)) 104 | 105 | def forward(self, x): 106 | ## Debugging purpose 107 | # if self.split_rate < 1: 108 | # # subnet = GetSubnet.apply(self.clamped_scores, self.keep_rate) 109 | # w = self.weight * self.scores 110 | # else: 111 | # w = self.weight 112 | 113 | w = self.weight 114 | b 
= self.bias 115 | 116 | x = F.linear(x, w, b) 117 | 118 | return x 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /layers/normalize_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def l2_norm(input): 4 | input_size = input.size() 5 | buffer = torch.pow(input, 2) 6 | normp = torch.sum(buffer, 1).add_(1e-5) 7 | norm = torch.sqrt(normp) 8 | _output = torch.div(input, norm.view(-1, 1).expand_as(input)) 9 | output = _output.view(input_size) 10 | 11 | return output -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.split_googlenet import Split_googlenet 2 | from models.split_resnet import Split_ResNet18,Split_ResNet34,Split_ResNet50,Split_ResNet101 3 | from models.split_densenet import Split_densenet121,Split_densenet161,Split_densenet169,Split_densenet201 4 | 5 | 6 | __all__ = [ 7 | 8 | "Split_ResNet18", 9 | "Split_ResNet34", 10 | "Split_ResNet50", 11 | "Split_ResNet101", 12 | 13 | "Split_googlenet", 14 | 15 | "Split_densenet121", 16 | "Split_densenet161", 17 | "Split_densenet169", 18 | "Split_densenet201", 19 | ] -------------------------------------------------------------------------------- /models/builder.py: -------------------------------------------------------------------------------- 1 | #from args import args 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import layers.conv_type 6 | import layers.bn_type 7 | import layers.linear_type 8 | 9 | 10 | class Builder(object): 11 | def __init__(self, conv_layer, bn_layer,linear_layer,cfg=None): 12 | self.conv_layer = conv_layer 13 | self.bn_layer = bn_layer 14 | self.linear_layer = linear_layer 15 | self.first_layer = conv_layer 16 | self.cfg = cfg 17 | 18 | 19 | def 
linear(self,in_feat,out_feat,last_layer=False,in_channels_order=None): 20 | if self.linear_layer == nn.Linear: 21 | linear_layer = self.linear_layer(in_feat, out_feat) 22 | else: 23 | linear_layer = self.linear_layer(in_feat,out_feat,split_mode=self.cfg.split_mode, 24 | split_rate=self.cfg.split_rate,last_layer=last_layer,in_channels_order=in_channels_order) 25 | self._init_linear(linear_layer) 26 | return linear_layer 27 | 28 | def conv(self, kernel_size, in_planes, out_planes, stride=1, first_layer=False,bias=False,in_channels_order=None): 29 | conv_layer = self.first_layer if first_layer else self.conv_layer 30 | 31 | if first_layer: 32 | self.cfg.logger.info(f"==> Building first layer with {str(self.first_layer)}") 33 | 34 | if kernel_size == 3: 35 | if conv_layer == nn.Conv2d: 36 | conv = conv_layer( 37 | in_planes, 38 | out_planes, 39 | kernel_size=3, 40 | stride=stride, 41 | padding=1, 42 | bias=bias, 43 | ) 44 | else: 45 | conv = conv_layer( 46 | in_planes, 47 | out_planes, 48 | kernel_size=3, 49 | stride=stride, 50 | padding=1, 51 | bias=bias, 52 | split_mode=self.cfg.split_mode, 53 | split_rate=self.cfg.split_rate, 54 | in_channels_order=in_channels_order, 55 | ) 56 | elif kernel_size == 1: 57 | if conv_layer == nn.Conv2d: 58 | conv = conv_layer( 59 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False, 60 | ) 61 | else: 62 | conv = conv_layer( 63 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False, 64 | split_mode=self.cfg.split_mode, 65 | split_rate=self.cfg.split_rate, 66 | in_channels_order=in_channels_order, 67 | ) 68 | elif kernel_size == 5: 69 | if conv_layer == nn.Conv2d: 70 | conv = conv_layer( 71 | in_planes, 72 | out_planes, 73 | kernel_size=5, 74 | stride=stride, 75 | padding=2, 76 | bias=bias, 77 | ) 78 | else: 79 | conv = conv_layer( 80 | in_planes, 81 | out_planes, 82 | kernel_size=5, 83 | stride=stride, 84 | padding=2, 85 | bias=bias, 86 | split_mode=self.cfg.split_mode, 87 | split_rate=self.cfg.split_rate, 88 | 
in_channels_order=in_channels_order, 89 | ) 90 | elif kernel_size == 7: 91 | if conv_layer == nn.Conv2d: 92 | conv = conv_layer( 93 | in_planes, 94 | out_planes, 95 | kernel_size=7, 96 | stride=stride, 97 | padding=3, 98 | bias=bias, 99 | ) 100 | else: 101 | conv = conv_layer( 102 | in_planes, 103 | out_planes, 104 | kernel_size=7, 105 | stride=stride, 106 | padding=3, 107 | bias=bias, 108 | split_mode=self.cfg.split_mode, 109 | split_rate=self.cfg.split_rate, 110 | in_channels_order=in_channels_order, 111 | ) 112 | elif kernel_size == 11: 113 | if conv_layer == nn.Conv2d: 114 | conv = conv_layer( 115 | in_planes, 116 | out_planes, 117 | kernel_size=11, 118 | stride=stride, 119 | padding=2, 120 | bias=bias, 121 | ) 122 | else: 123 | conv = conv_layer( 124 | in_planes, 125 | out_planes, 126 | kernel_size=11, 127 | stride=stride, 128 | padding=2, 129 | bias=bias, 130 | split_mode=self.cfg.split_mode, 131 | split_rate=self.cfg.split_rate, 132 | in_channels_order=in_channels_order, 133 | ) 134 | else: 135 | return None 136 | 137 | self._init_conv(conv) 138 | 139 | return conv 140 | 141 | def conv3x3(self, in_planes, out_planes, stride=1, first_layer=False,bias=False,in_channels_order=None): 142 | """3x3 convolution with padding""" 143 | c = self.conv(3, in_planes, out_planes, stride=stride, first_layer=first_layer,bias=bias,in_channels_order=in_channels_order) 144 | return c 145 | 146 | def conv1x1(self, in_planes, out_planes, stride=1, first_layer=False,bias=False,in_channels_order=None): 147 | """1x1 convolution with padding""" 148 | c = self.conv(1, in_planes, out_planes, stride=stride, first_layer=first_layer,bias=bias,in_channels_order=in_channels_order) 149 | return c 150 | 151 | def conv7x7(self, in_planes, out_planes, stride=1, first_layer=False,bias=False,in_channels_order=None): 152 | """7x7 convolution with padding""" 153 | c = self.conv(7, in_planes, out_planes, stride=stride, first_layer=first_layer,bias=bias,in_channels_order=in_channels_order) 154 | 
return c 155 | 156 | def conv5x5(self, in_planes, out_planes, stride=1, first_layer=False,bias=False,in_channels_order=None): 157 | """5x5 convolution with padding""" 158 | c = self.conv(5, in_planes, out_planes, stride=stride, first_layer=first_layer,bias=bias,in_channels_order=in_channels_order) 159 | return c 160 | 161 | def conv11x11(self, in_planes, out_planes, stride=1, first_layer=False,bias=False,in_channels_order=None): 162 | """5x5 convolution with padding""" 163 | c = self.conv(11, in_planes, out_planes, stride=stride, first_layer=first_layer,bias=bias,in_channels_order=in_channels_order) 164 | return c 165 | 166 | def batchnorm(self, planes, last_bn=False, first_layer=False,in_channels_order=None,**kwargs): 167 | if self.bn_layer == nn.BatchNorm2d: 168 | return self.bn_layer(planes, **kwargs) 169 | else: 170 | return self.bn_layer(planes,in_channels_order=in_channels_order,split_rate=self.cfg.split_rate,**kwargs) 171 | 172 | def activation(self): 173 | if self.cfg.nonlinearity == "relu": 174 | return (lambda: nn.ReLU(inplace=True))() 175 | else: 176 | raise ValueError(f"{self.cfg.nonlinearity} is not an initialization option!") 177 | 178 | def _init_linear(self, linear): 179 | if self.cfg.init == "signed_constant": 180 | 181 | fan = nn.init._calculate_correct_fan(linear.weight, self.cfg.mode) 182 | if self.cfg.scale_fan: 183 | fan = fan * (1 - self.cfg.prune_rate) 184 | gain = nn.init.calculate_gain(self.cfg.nonlinearity) 185 | std = gain / math.sqrt(fan) 186 | linear.weight.data = linear.weight.data.sign() * std 187 | 188 | elif self.cfg.init == "unsigned_constant": 189 | 190 | fan = nn.init._calculate_correct_fan(linear.weight, self.cfg.mode) 191 | if self.cfg.scale_fan: 192 | fan = fan * (1 - self.cfg.prune_rate) 193 | 194 | gain = nn.init.calculate_gain(self.cfg.nonlinearity) 195 | std = gain / math.sqrt(fan) 196 | linear.weight.data = torch.ones_like(linear.weight.data) * std 197 | 198 | elif self.cfg.init == "kaiming_normal": 199 | 200 | if 
self.cfg.scale_fan: 201 | fan = nn.init._calculate_correct_fan(linear.weight, self.cfg.mode) 202 | fan = fan * (1 - self.cfg.prune_rate) 203 | gain = nn.init.calculate_gain(self.cfg.nonlinearity) 204 | std = gain / math.sqrt(fan) 205 | with torch.no_grad(): 206 | linear.weight.data.normal_(0, std) 207 | else: 208 | nn.init.kaiming_normal_( 209 | linear.weight, mode=self.cfg.mode, nonlinearity=self.cfg.nonlinearity 210 | ) 211 | 212 | elif self.cfg.init == "kaiming_uniform": 213 | nn.init.kaiming_uniform_( 214 | linear.weight, mode=self.cfg.mode, nonlinearity=self.cfg.nonlinearity 215 | ) 216 | elif self.cfg.init == "xavier_normal": 217 | nn.init.xavier_normal_(linear.weight) 218 | elif self.cfg.init == "xavier_constant": 219 | 220 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(linear.weight) 221 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 222 | linear.weight.data = linear.weight.data.sign() * std 223 | 224 | elif self.cfg.init == "standard": 225 | nn.init.kaiming_uniform_(linear.weight, a=math.sqrt(5)) 226 | else: 227 | raise ValueError(f"{self.cfg.init} is not an initialization option!") 228 | 229 | 230 | def _init_conv(self, conv): 231 | if self.cfg.init == "signed_constant": 232 | 233 | fan = nn.init._calculate_correct_fan(conv.weight, self.cfg.mode) 234 | if self.cfg.scale_fan: 235 | fan = fan * (1 - self.cfg.prune_rate) 236 | gain = nn.init.calculate_gain(self.cfg.nonlinearity) 237 | std = gain / math.sqrt(fan) 238 | conv.weight.data = conv.weight.data.sign() * std 239 | 240 | elif self.cfg.init == "unsigned_constant": 241 | 242 | fan = nn.init._calculate_correct_fan(conv.weight, self.cfg.mode) 243 | if self.cfg.scale_fan: 244 | fan = fan * (1 - self.cfg.prune_rate) 245 | 246 | gain = nn.init.calculate_gain(self.cfg.nonlinearity) 247 | std = gain / math.sqrt(fan) 248 | conv.weight.data = torch.ones_like(conv.weight.data) * std 249 | 250 | elif self.cfg.init == "kaiming_normal": 251 | 252 | if self.cfg.scale_fan: 253 | fan = 
nn.init._calculate_correct_fan(conv.weight, self.cfg.mode) 254 | fan = fan * (1 - self.cfg.prune_rate) 255 | gain = nn.init.calculate_gain(self.cfg.nonlinearity) 256 | std = gain / math.sqrt(fan) 257 | with torch.no_grad(): 258 | conv.weight.data.normal_(0, std) 259 | else: 260 | nn.init.kaiming_normal_( 261 | conv.weight, mode=self.cfg.mode, nonlinearity=self.cfg.nonlinearity 262 | ) 263 | 264 | elif self.cfg.init == "kaiming_uniform": 265 | nn.init.kaiming_uniform_( 266 | conv.weight, mode=self.cfg.mode, nonlinearity=self.cfg.nonlinearity 267 | ) 268 | elif self.cfg.init == "xavier_normal": 269 | nn.init.xavier_normal_(conv.weight) 270 | elif self.cfg.init == "xavier_constant": 271 | 272 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(conv.weight) 273 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 274 | conv.weight.data = conv.weight.data.sign() * std 275 | 276 | elif self.cfg.init == "standard": 277 | nn.init.kaiming_uniform_(conv.weight, a=math.sqrt(5)) 278 | else: 279 | raise ValueError(f"{self.cfg.init} is not an initialization option!") 280 | 281 | 282 | def get_builder(cfg): 283 | 284 | cfg.logger.info("==> Conv Type: {}".format(cfg.conv_type)) 285 | cfg.logger.info("==> BN Type: {}".format(cfg.bn_type)) 286 | 287 | conv_layer = getattr(layers.conv_type, cfg.conv_type) 288 | bn_layer = getattr(layers.bn_type, cfg.bn_type) 289 | linear_layer = getattr(layers.linear_type, cfg.linear_type) 290 | 291 | 292 | builder = Builder(conv_layer=conv_layer, bn_layer=bn_layer,linear_layer=linear_layer,cfg=cfg) 293 | 294 | return builder 295 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, namedtuple 2 | 3 | class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): 4 | def __repr__(self): 5 | if not self.missing_keys and not self.unexpected_keys: 6 | 
return '' 7 | return super(_IncompatibleKeys, self).__repr__() 8 | 9 | __str__ = __repr__ 10 | 11 | def load_state_dict(model, state_dict, 12 | strict: bool = True): 13 | r"""Copies parameters and buffers from :attr:`state_dict` into 14 | this module and its descendants. If :attr:`strict` is ``True``, then 15 | the keys of :attr:`state_dict` must exactly match the keys returned 16 | by this module's :meth:`~torch.nn.Module.state_dict` function. 17 | Arguments: 18 | state_dict (dict): a dict containing parameters and 19 | persistent buffers. 20 | strict (bool, optional): whether to strictly enforce that the keys 21 | in :attr:`state_dict` match the keys returned by this module's 22 | :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` 23 | Returns: 24 | ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: 25 | * **missing_keys** is a list of str containing the missing keys 26 | * **unexpected_keys** is a list of str containing the unexpected keys 27 | """ 28 | missing_keys = [] 29 | unexpected_keys = [] 30 | error_msgs = [] 31 | 32 | # copy state_dict so _load_from_state_dict can modify it 33 | metadata = getattr(state_dict, '_metadata', None) 34 | state_dict = state_dict.copy() 35 | if metadata is not None: 36 | state_dict._metadata = metadata 37 | 38 | def load(module, prefix=''): 39 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 40 | module._load_from_state_dict( 41 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 42 | for name, child in module._modules.items(): 43 | if child is not None and not (child.__class__.__name__ == 'SplitLinear' or child.__class__.__name__ == 'Linear'): 44 | load(child, prefix + name + '.') 45 | 46 | load(model) 47 | load = None # break load->load reference cycle 48 | print('WARNING: num unexpected_keys {} , num missing keys {}'.format(len(unexpected_keys),len(missing_keys))) 49 | if strict: 50 | if len(unexpected_keys) > 0: 51 | 
error_msgs.insert( 52 | 0, 'Unexpected key(s) in state_dict: {}. '.format( 53 | ', '.join('"{}"'.format(k) for k in unexpected_keys))) 54 | if len(missing_keys) > 0: 55 | error_msgs.insert( 56 | 0, 'Missing key(s) in state_dict: {}. '.format( 57 | ', '.join('"{}"'.format(k) for k in missing_keys))) 58 | 59 | if len(error_msgs) > 0: 60 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 61 | model.__class__.__name__, "\n\t".join(error_msgs))) 62 | return _IncompatibleKeys(missing_keys, unexpected_keys) 63 | -------------------------------------------------------------------------------- /models/split_densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from models import common 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint as cp 9 | from collections import OrderedDict 10 | from torch.jit.annotations import List 11 | from models.builder import get_builder 12 | 13 | try: 14 | from torch.hub import load_state_dict_from_url 15 | except ImportError: 16 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 17 | 18 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 19 | 20 | model_urls = { 21 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 22 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 23 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 24 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 25 | } 26 | 27 | 28 | class _DenseLayer(nn.Module): 29 | def __init__(self,builder, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False,in_channels_order=None): 30 | super(_DenseLayer, self).__init__() 31 | # self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 32 | # 
self.add_module('relu1', nn.ReLU(inplace=True)), 33 | # self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 34 | # growth_rate, kernel_size=1, stride=1, 35 | # bias=False)), 36 | # self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 37 | # self.add_module('relu2', nn.ReLU(inplace=True)), 38 | # self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 39 | # kernel_size=3, stride=1, padding=1, 40 | # bias=False)), 41 | 42 | self.add_module('norm1', builder.batchnorm(num_input_features,in_channels_order=in_channels_order)), 43 | self.add_module('relu1', nn.ReLU(inplace=True)), 44 | self.add_module('conv1', builder.conv1x1(num_input_features, bn_size * 45 | growth_rate, stride=1,in_channels_order=in_channels_order)), 46 | self.add_module('norm2', builder.batchnorm(bn_size * growth_rate)), 47 | self.add_module('relu2', nn.ReLU(inplace=True)), 48 | self.add_module('conv2', builder.conv3x3(bn_size * growth_rate, growth_rate,stride=1)), 49 | 50 | self.drop_rate = float(drop_rate) 51 | self.memory_efficient = memory_efficient 52 | 53 | def bn_function(self, inputs): 54 | # type: (List[Tensor]) -> Tensor 55 | concated_features = torch.cat(inputs, 1) 56 | bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 57 | return bottleneck_output 58 | 59 | # todo: rewrite when torchscript supports any 60 | def any_requires_grad(self, input): 61 | # type: (List[Tensor]) -> bool 62 | for tensor in input: 63 | if tensor.requires_grad: 64 | return True 65 | return False 66 | 67 | @torch.jit.unused # noqa: T484 68 | def call_checkpoint_bottleneck(self, input): 69 | # type: (List[Tensor]) -> Tensor 70 | def closure(*inputs): 71 | return self.bn_function(inputs) 72 | 73 | return cp.checkpoint(closure, *input) 74 | 75 | @torch.jit._overload_method # noqa: F811 76 | def forward(self, input): 77 | # type: (List[Tensor]) -> (Tensor) 78 | pass 79 | 80 | @torch.jit._overload_method # noqa: F811 81 | def forward(self, input): 
82 | # type: (Tensor) -> (Tensor) 83 | pass 84 | 85 | # torchscript does not yet support *args, so we overload method 86 | # allowing it to take either a List[Tensor] or single Tensor 87 | def forward(self, input): # noqa: F811 88 | if isinstance(input, Tensor): 89 | prev_features = [input] 90 | else: 91 | prev_features = input 92 | 93 | if self.memory_efficient and self.any_requires_grad(prev_features): 94 | if torch.jit.is_scripting(): 95 | raise Exception("Memory Efficient not supported in JIT") 96 | 97 | bottleneck_output = self.call_checkpoint_bottleneck(prev_features) 98 | else: 99 | bottleneck_output = self.bn_function(prev_features) 100 | 101 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 102 | if self.drop_rate > 0: 103 | new_features = F.dropout(new_features, p=self.drop_rate, 104 | training=self.training) 105 | return new_features 106 | 107 | 108 | class _DenseBlock(nn.ModuleDict): 109 | _version = 2 110 | 111 | def __init__(self,builder, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False,in_channels_order=None): 112 | super(_DenseBlock, self).__init__() 113 | for i in range(num_layers): 114 | layer = _DenseLayer(builder, 115 | num_input_features + i * growth_rate, 116 | growth_rate=growth_rate, 117 | bn_size=bn_size, 118 | drop_rate=drop_rate, 119 | memory_efficient=memory_efficient, 120 | in_channels_order=in_channels_order 121 | ) 122 | in_channels_order += ',{}'.format(growth_rate) 123 | self.add_module('denselayer%d' % (i + 1), layer) 124 | self.out_channels_order = in_channels_order 125 | 126 | def forward(self, init_features): 127 | features = [init_features] 128 | for name, layer in self.items(): 129 | new_features = layer(features) 130 | features.append(new_features) 131 | return torch.cat(features, 1) 132 | 133 | 134 | class _Transition(nn.Sequential): 135 | def __init__(self,builder, num_input_features, num_output_features,in_channels_order=None): 136 | super(_Transition, 
self).__init__() 137 | # self.add_module('norm', nn.BatchNorm2d(num_input_features)) 138 | # self.add_module('relu', nn.ReLU(inplace=True)) 139 | # self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 140 | # kernel_size=1, stride=1, bias=False)) 141 | # self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 142 | # 143 | self.add_module('norm', builder.batchnorm(num_input_features,in_channels_order=in_channels_order)) 144 | self.add_module('relu', nn.ReLU(inplace=True)) 145 | self.add_module('conv', builder.conv1x1(num_input_features, num_output_features,stride=1,in_channels_order=in_channels_order)) 146 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 147 | 148 | 149 | class DenseNet(nn.Module): 150 | r"""Densenet-BC model class, based on 151 | `"Densely Connected Convolutional Networks" `_ 152 | 153 | Args: 154 | growth_rate (int) - how many filters to add each layer (`k` in paper) 155 | block_config (list of 4 ints) - how many layers in each pooling block 156 | num_init_features (int) - the number of filters to learn in the first convolution layer 157 | bn_size (int) - multiplicative factor for number of bottle neck layers 158 | (i.e. bn_size * k features in the bottleneck layer) 159 | drop_rate (float) - dropout rate after each dense layer 160 | num_classes (int) - number of classification classes 161 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 162 | but slower. Default: *False*. 
See `"paper" `_ 163 | """ 164 | 165 | def __init__(self,cfg,builder, growth_rate=32, block_config=(6, 12, 24, 16), 166 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False): 167 | 168 | super(DenseNet, self).__init__() 169 | slim_factor = cfg.slim_factor 170 | if slim_factor < 1: 171 | cfg.logger.info('WARNING: You are using a slim network') 172 | 173 | # num_init_features = math.ceil(num_init_features * slim_factor) 174 | # growth_rate = math.ceil(growth_rate * slim_factor) 175 | 176 | slim = lambda x: math.ceil(x * slim_factor) 177 | 178 | self.features = nn.Sequential(OrderedDict([ 179 | ('conv0', builder.conv7x7(3, slim(num_init_features), stride=2)), 180 | ('norm0', builder.batchnorm(slim(num_init_features))), 181 | ('relu0', nn.ReLU(inplace=True)), 182 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 183 | ])) 184 | 185 | # Each denseblock 186 | num_features = num_init_features 187 | # in_channels_order = '{}'.format(num_features) 188 | for i, num_layers in enumerate(block_config): 189 | in_channels_order = '{}'.format(num_features) 190 | block = _DenseBlock(builder, 191 | num_layers=num_layers, 192 | num_input_features=slim(num_features), 193 | bn_size=bn_size, 194 | growth_rate=slim(growth_rate), 195 | drop_rate=drop_rate, 196 | memory_efficient=memory_efficient, 197 | in_channels_order=in_channels_order, 198 | ) 199 | self.features.add_module('denseblock%d' % (i + 1), block) 200 | pre_num_features = num_features 201 | 202 | if i != len(block_config) - 1: 203 | num_features = num_features + num_layers * growth_rate 204 | trans = _Transition(builder,num_input_features=slim(pre_num_features) + num_layers * slim(growth_rate), 205 | num_output_features=slim(num_features // 2),in_channels_order=block.out_channels_order) 206 | self.features.add_module('transition%d' % (i + 1), trans) 207 | num_features = num_features // 2 208 | else: 209 | num_features = slim(num_features) + num_layers * slim(growth_rate) 210 | 
211 | 212 | # else: 213 | # num_features = slim(num_features) + num_layers * slim(growth_rate) 214 | 215 | 216 | 217 | # Final batch norm 218 | # self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 219 | self.features.add_module('norm5', builder.batchnorm(num_features,in_channels_order=block.out_channels_order)) 220 | 221 | # Linear layer 222 | # self.classifier = nn.Linear(num_features, num_classes) 223 | self.classifier = builder.linear(num_features, cfg.num_cls, last_layer=True, 224 | in_channels_order=block.out_channels_order) 225 | 226 | # Official init from torch repo. 227 | for m in self.modules(): 228 | if isinstance(m, nn.Conv2d): 229 | nn.init.kaiming_normal_(m.weight) 230 | elif isinstance(m, nn.BatchNorm2d): 231 | nn.init.constant_(m.weight, 1) 232 | nn.init.constant_(m.bias, 0) 233 | elif isinstance(m, nn.Linear): 234 | nn.init.constant_(m.bias, 0) 235 | 236 | def forward(self, x): 237 | features = self.features(x) 238 | out = F.relu(features, inplace=True) 239 | out = F.adaptive_avg_pool2d(out, (1, 1)) 240 | out = torch.flatten(out, 1) 241 | out = self.classifier(out) 242 | return out 243 | 244 | 245 | def _load_state_dict(model, model_url, progress): 246 | # '.'s are no longer allowed in module names, but previous _DenseLayer 247 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 248 | # They are also in the checkpoints in model_urls. This pattern is used 249 | # to find such keys. 
250 | pattern = re.compile( 251 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 252 | 253 | state_dict = load_state_dict_from_url(model_url, progress=progress) 254 | for key in list(state_dict.keys()): 255 | res = pattern.match(key) 256 | if res: 257 | new_key = res.group(1) + res.group(2) 258 | state_dict[new_key] = state_dict[key] 259 | del state_dict[key] 260 | 261 | common.load_state_dict(model, state_dict, strict=False) 262 | # model.load_state_dict(state_dict) 263 | 264 | 265 | def _densenet(cfg,builder,arch, growth_rate, block_config, num_init_features, pretrained, progress, 266 | **kwargs): 267 | model = DenseNet(cfg,builder,growth_rate, block_config, num_init_features, **kwargs) 268 | if cfg.pretrained == 'imagenet': 269 | _load_state_dict(model, model_urls[arch], progress) 270 | return model 271 | 272 | 273 | def Split_densenet121(cfg,pretrained=False, progress=True, **kwargs): 274 | r"""Densenet-121 model from 275 | `"Densely Connected Convolutional Networks" `_ 276 | 277 | Args: 278 | pretrained (bool): If True, returns a model pre-trained on ImageNet 279 | progress (bool): If True, displays a progress bar of the download to stderr 280 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 281 | but slower. Default: *False*. See `"paper" `_ 282 | """ 283 | return _densenet(cfg,get_builder(cfg),'densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def Split_densenet161(cfg,pretrained=False, progress=True, **kwargs): 288 | r"""Densenet-161 model from 289 | `"Densely Connected Convolutional Networks" `_ 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 295 | but slower. Default: *False*. 
See `"paper" `_ 296 | """ 297 | return _densenet(cfg,get_builder(cfg),'densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 298 | **kwargs) 299 | 300 | 301 | def Split_densenet169(cfg,pretrained=False, progress=True, **kwargs): 302 | r"""Densenet-169 model from 303 | `"Densely Connected Convolutional Networks" `_ 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 309 | but slower. Default: *False*. See `"paper" `_ 310 | """ 311 | return _densenet(cfg,get_builder(cfg),'densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 312 | **kwargs) 313 | 314 | 315 | def Split_densenet201(cfg,pretrained=False, progress=True, **kwargs): 316 | r"""Densenet-201 model from 317 | `"Densely Connected Convolutional Networks" `_ 318 | 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | progress (bool): If True, displays a progress bar of the download to stderr 322 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 323 | but slower. Default: *False*. 
See `"paper" `_ 324 | """ 325 | return _densenet(cfg,get_builder(cfg),'densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 326 | **kwargs) 327 | -------------------------------------------------------------------------------- /models/split_googlenet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import namedtuple 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.jit.annotations import Optional, Tuple 8 | from torch import Tensor 9 | from models.builder import get_builder 10 | 11 | try: 12 | from torch.hub import load_state_dict_from_url 13 | except ImportError: 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | 16 | __all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"] 17 | 18 | model_urls = { 19 | # GoogLeNet ported from TensorFlow 20 | 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', 21 | } 22 | 23 | GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) 24 | GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor], 25 | 'aux_logits1': Optional[Tensor]} 26 | 27 | # Script annotations failed with _GoogleNetOutputs = namedtuple ... 28 | # _GoogLeNetOutputs set here for backwards compat 29 | _GoogLeNetOutputs = GoogLeNetOutputs 30 | 31 | 32 | def Split_googlenet(cfg, pretrained=False, progress=True, **kwargs): 33 | r"""GoogLeNet (Inception v1) model architecture from 34 | `"Going Deeper with Convolutions" `_. 35 | 36 | Args: 37 | pretrained (bool): If True, returns a model pre-trained on ImageNet 38 | progress (bool): If True, displays a progress bar of the download to stderr 39 | aux_logits (bool): If True, adds two auxiliary branches that can improve training. 
40 | Default: *False* when pretrained is True otherwise *True* 41 | transform_input (bool): If True, preprocesses the input according to the method with which it 42 | was trained on ImageNet. Default: *False* 43 | """ 44 | if pretrained: 45 | if 'transform_input' not in kwargs: 46 | kwargs['transform_input'] = True 47 | if 'aux_logits' not in kwargs: 48 | kwargs['aux_logits'] = False 49 | if kwargs['aux_logits']: 50 | warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' 51 | 'so make sure to train them') 52 | original_aux_logits = kwargs['aux_logits'] 53 | kwargs['aux_logits'] = True 54 | kwargs['init_weights'] = False 55 | model = GoogLeNet(**kwargs) 56 | state_dict = load_state_dict_from_url(model_urls['googlenet'], 57 | progress=progress) 58 | model.load_state_dict(state_dict) 59 | if not original_aux_logits: 60 | model.aux_logits = False 61 | model.aux1 = None 62 | model.aux2 = None 63 | return model 64 | 65 | return GoogLeNet(cfg,num_classes=cfg.num_cls,**kwargs) 66 | 67 | 68 | class GoogLeNet(nn.Module): 69 | __constants__ = ['aux_logits', 'transform_input'] 70 | 71 | def __init__(self,cfg, num_classes=1000, aux_logits=False, transform_input=False, init_weights=None, 72 | blocks=None): # AT : I disabled the aux_logits 73 | super(GoogLeNet, self).__init__() 74 | if blocks is None: 75 | blocks = [BasicConv2d, Inception, InceptionAux] 76 | if init_weights is None: 77 | warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of ' 78 | 'torchvision. 
If you wish to keep the old behavior (which leads to long initialization times' 79 | ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) 80 | init_weights = True 81 | assert len(blocks) == 3 82 | 83 | builder = get_builder(cfg) 84 | slim_factor = cfg.slim_factor 85 | if slim_factor < 1: 86 | cfg.logger.info('WARNING: You are using a slim network') 87 | 88 | conv_block = blocks[0] 89 | inception_block = blocks[1] 90 | inception_aux_block = blocks[2] 91 | 92 | self.aux_logits = aux_logits 93 | self.transform_input = transform_input 94 | 95 | slim = lambda x: math.ceil(x * slim_factor) 96 | self.conv1 = conv_block(builder,3, slim(64), kernel_size=7, stride=2) # , padding=3 97 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 98 | self.conv2 = conv_block(builder,slim(64), slim(64), kernel_size=1) 99 | self.conv3 = conv_block(builder,slim(64), slim(192), kernel_size=3) # padding=1 100 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 101 | 102 | self.inception3a = inception_block(builder, slim(192), slim(64) 103 | , slim(96), slim(128), slim(16), slim(32), slim(32)) 104 | prev_out_channels = self.inception3a.out_channels # This is 256 105 | concat_order = self.inception3a.concat_order 106 | 107 | self.inception3b = inception_block(builder,prev_out_channels, slim(128), slim(128), slim(192), slim(32), slim(96), slim(64),in_channels_order=concat_order) 108 | prev_out_channels = self.inception3b.out_channels # This is 480 109 | concat_order = self.inception3b.concat_order 110 | 111 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 112 | 113 | self.inception4a = inception_block(builder,prev_out_channels, slim(192), slim(96), slim(208), slim(16), slim(48), slim(64),in_channels_order=concat_order) 114 | prev_out_channels = self.inception4a.out_channels # This is 512 115 | concat_order = self.inception4a.concat_order 116 | 117 | self.inception4b = inception_block(builder,prev_out_channels, slim(160), slim(112), slim(224), 
slim(24), slim(64), slim(64),in_channels_order=concat_order) 118 | prev_out_channels = self.inception4b.out_channels # This is 512 119 | concat_order = self.inception4b.concat_order 120 | 121 | self.inception4c = inception_block(builder,prev_out_channels, slim(128), slim(128), slim(256), slim(24), slim(64), slim(64),in_channels_order=concat_order) 122 | prev_out_channels = self.inception4c.out_channels # This is 512 123 | concat_order = self.inception4c.concat_order 124 | 125 | self.inception4d = inception_block(builder,prev_out_channels, slim(112), slim(144), slim(288), slim(32), slim(64), slim(64),in_channels_order=concat_order) 126 | prev_out_channels = self.inception4d.out_channels # This is 528 127 | concat_order = self.inception4d.concat_order 128 | 129 | self.inception4e = inception_block(builder,prev_out_channels, slim(256), slim(160), slim(320), slim(32), slim(128), slim(128),in_channels_order=concat_order) 130 | prev_out_channels = self.inception4e.out_channels # This is 832 131 | concat_order = self.inception4e.concat_order 132 | 133 | self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 134 | 135 | self.inception5a = inception_block(builder,prev_out_channels, slim(256), slim(160), slim(320), slim(32), slim(128), slim(128),in_channels_order=concat_order) 136 | prev_out_channels = self.inception5a.out_channels # This is 832 137 | concat_order = self.inception5a.concat_order 138 | 139 | self.inception5b = inception_block(builder,prev_out_channels, slim(384), slim(192), slim(384), slim(48), slim(128), slim(128),in_channels_order=concat_order) 140 | prev_out_channels = self.inception5b.out_channels # This is 1024 141 | concat_order = self.inception5b.concat_order 142 | 143 | if aux_logits: 144 | self.aux1 = inception_aux_block(builder,512, num_classes) 145 | self.aux2 = inception_aux_block(builder,528, num_classes) 146 | else: 147 | self.aux1 = None 148 | self.aux2 = None 149 | 150 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 151 | self.dropout = 
nn.Dropout(0.2) 152 | # self.fc = nn.Linear(1024, num_classes) 153 | self.fc = builder.linear(prev_out_channels, num_classes,last_layer=True,in_channels_order=concat_order) 154 | 155 | if init_weights: 156 | self._initialize_weights() 157 | 158 | def _initialize_weights(self): 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 161 | import scipy.stats as stats 162 | X = stats.truncnorm(-2, 2, scale=0.01) 163 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 164 | values = values.view(m.weight.size()) 165 | with torch.no_grad(): 166 | m.weight.copy_(values) 167 | elif isinstance(m, nn.BatchNorm2d): 168 | nn.init.constant_(m.weight, 1) 169 | nn.init.constant_(m.bias, 0) 170 | 171 | def _transform_input(self, x): 172 | # type: (Tensor) -> Tensor 173 | if self.transform_input: 174 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 175 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 176 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 177 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 178 | return x 179 | 180 | def _forward(self, x): 181 | # type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]] 182 | # N x 3 x 224 x 224 183 | x = self.conv1(x) 184 | # N x 64 x 112 x 112 185 | x = self.maxpool1(x) 186 | # N x 64 x 56 x 56 187 | x = self.conv2(x) 188 | # N x 64 x 56 x 56 189 | x = self.conv3(x) 190 | # N x 192 x 56 x 56 191 | x = self.maxpool2(x) 192 | 193 | # N x 192 x 28 x 28 194 | x = self.inception3a(x) 195 | 196 | # N x 256 x 28 x 28 197 | x = self.inception3b(x) 198 | # N x 480 x 28 x 28 199 | x = self.maxpool3(x) 200 | # N x 480 x 14 x 14 201 | x = self.inception4a(x) 202 | # N x 512 x 14 x 14 203 | aux1 = torch.jit.annotate(Optional[Tensor], None) 204 | if self.aux1 is not None: 205 | if self.training: 206 | aux1 = self.aux1(x) 207 | 208 | x = self.inception4b(x) 209 | # N x 512 x 14 x 14 210 | x = 
self.inception4c(x) 211 | # N x 512 x 14 x 14 212 | x = self.inception4d(x) 213 | # N x 528 x 14 x 14 214 | aux2 = torch.jit.annotate(Optional[Tensor], None) 215 | if self.aux2 is not None: 216 | if self.training: 217 | aux2 = self.aux2(x) 218 | 219 | x = self.inception4e(x) 220 | # N x 832 x 14 x 14 221 | x = self.maxpool4(x) 222 | # N x 832 x 7 x 7 223 | x = self.inception5a(x) 224 | # N x 832 x 7 x 7 225 | x = self.inception5b(x) 226 | # N x 1024 x 7 x 7 227 | 228 | x = self.avgpool(x) 229 | # N x 1024 x 1 x 1 230 | x = torch.flatten(x, 1) 231 | # N x 1024 232 | x = self.dropout(x) 233 | x = self.fc(x) 234 | # N x 1000 (num_classes) 235 | 236 | return x, aux2, aux1 237 | # return x 238 | 239 | @torch.jit.unused 240 | def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs: 241 | if self.training and self.aux_logits: 242 | return _GoogLeNetOutputs(x, aux2, aux1) 243 | else: 244 | return x # type: ignore[return-value] 245 | 246 | def forward(self, x): 247 | # type: (Tensor) -> GoogLeNetOutputs 248 | x = self._transform_input(x) 249 | 250 | # x = self._forward(x) 251 | # return x 252 | 253 | x, aux1, aux2 = self._forward(x) 254 | aux_defined = self.training and self.aux_logits 255 | if torch.jit.is_scripting(): 256 | if not aux_defined: 257 | warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple") 258 | return GoogLeNetOutputs(x, aux2, aux1) 259 | else: 260 | return self.eager_outputs(x, aux2, aux1) 261 | 262 | 263 | class Inception(nn.Module): 264 | 265 | def __init__(self,builder ,in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, 266 | in_channels_order=None, 267 | conv_block=None): 268 | super(Inception, self).__init__() 269 | if conv_block is None: 270 | conv_block = BasicConv2d 271 | self.branch1 = conv_block(builder,in_channels, ch1x1,in_channels_order=in_channels_order, kernel_size=1) 272 | 273 | self.branch2 = nn.Sequential( 274 | conv_block(builder, in_channels, 
ch3x3red,in_channels_order=in_channels_order, kernel_size=1), 275 | conv_block(builder, ch3x3red, ch3x3, kernel_size=3) # padding=1 276 | ) 277 | 278 | self.branch3 = nn.Sequential( 279 | conv_block(builder, in_channels, ch5x5red,in_channels_order=in_channels_order, kernel_size=1), 280 | # Here, kernel_size=3 instead of kernel_size=5 is a known bug. 281 | # Please see https://github.com/pytorch/vision/issues/906 for details. 282 | conv_block(builder, ch5x5red, ch5x5, kernel_size=3) # padding=1 283 | ) 284 | 285 | self.branch4 = nn.Sequential( 286 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 287 | conv_block(builder, in_channels, pool_proj,in_channels_order=in_channels_order, kernel_size=1) 288 | ) 289 | self.out_channels = ch1x1 + ch3x3 + ch5x5 + pool_proj 290 | self.concat_order = '{},{},{},{}'.format(ch1x1,ch3x3,ch5x5,pool_proj) 291 | 292 | def _forward(self, x): 293 | branch1 = self.branch1(x) 294 | branch2 = self.branch2(x) 295 | branch3 = self.branch3(x) 296 | branch4 = self.branch4(x) 297 | 298 | outputs = [branch1, branch2, branch3, branch4] 299 | return outputs 300 | 301 | def forward(self, x): 302 | outputs = self._forward(x) 303 | _res = torch.cat(outputs, 1) 304 | # print(_res.shape) 305 | return _res 306 | 307 | 308 | class InceptionAux(nn.Module): 309 | 310 | def __init__(self, builder, in_channels, num_classes, slim_factor, conv_block=None): 311 | super(InceptionAux, self).__init__() 312 | if conv_block is None: 313 | conv_block = BasicConv2d 314 | self.conv = conv_block(builder,in_channels, 128, kernel_size=1) 315 | 316 | self.fc1 = nn.Linear(2048, 1024) 317 | self.fc2 = nn.Linear(1024, num_classes) 318 | 319 | def forward(self, x): 320 | # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 321 | x = F.adaptive_avg_pool2d(x, (4, 4)) 322 | # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 323 | x = self.conv(x) 324 | # N x 128 x 4 x 4 325 | x = torch.flatten(x, 1) 326 | # N x 2048 327 | x = F.relu(self.fc1(x), inplace=True) 328 | # N 
x 1024 329 | x = F.dropout(x, 0.7, training=self.training) 330 | # N x 1024 331 | x = self.fc2(x) 332 | # N x 1000 (num_classes) 333 | 334 | return x 335 | 336 | 337 | class BasicConv2d(nn.Module): 338 | 339 | def __init__(self,builder, in_channels, out_channels,in_channels_order=None, **kwargs): 340 | super(BasicConv2d, self).__init__() 341 | kernel_size = kwargs.pop('kernel_size', None) 342 | # self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 343 | # self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 344 | if kernel_size == 3: 345 | self.conv = builder.conv3x3( in_channels, out_channels , bias=False,in_channels_order=in_channels_order, **kwargs) 346 | elif kernel_size == 1: 347 | self.conv = builder.conv1x1(in_channels, out_channels, bias=False,in_channels_order=in_channels_order, **kwargs) 348 | elif kernel_size == 7: 349 | self.conv = builder.conv7x7(in_channels, out_channels, bias=False,in_channels_order=in_channels_order, **kwargs) 350 | elif kernel_size == 11: 351 | self.conv = builder.conv11x11(in_channels, out_channels, bias=False,in_channels_order=in_channels_order, **kwargs) 352 | elif kernel_size == 5: 353 | self.conv = builder.conv5x5(in_channels, out_channels, bias=False,in_channels_order=in_channels_order, **kwargs) 354 | else: 355 | raise NotImplemented("Invalid kernel size {}".format(kernel_size)) 356 | 357 | self.bn = builder.batchnorm(out_channels , eps=0.001) 358 | 359 | def forward(self, x): 360 | x = self.conv(x) 361 | x = self.bn(x) 362 | return F.relu(x, inplace=True) 363 | -------------------------------------------------------------------------------- /models/split_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from layers import conv_type 5 | from models.builder import get_builder 6 | 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import 
load_url as load_state_dict_from_url 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | # BasicBlock {{{ 25 | class BasicBlock(nn.Module): 26 | M = 2 27 | expansion = 1 28 | 29 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None, base_width=64, slim_factor=1): 30 | super(BasicBlock, self).__init__() 31 | if base_width / 64 > 1: 32 | raise ValueError("Base width >64 does not work for BasicBlock") 33 | 34 | self.conv1 = builder.conv3x3(math.ceil(inplanes * slim_factor), math.ceil(planes * slim_factor), stride) ## Avoid residual links 35 | self.bn1 = builder.batchnorm(math.ceil(planes * slim_factor)) 36 | self.relu = builder.activation() 37 | 38 | self.conv2 = builder.conv3x3(math.ceil(planes * slim_factor), 39 | math.ceil(planes * slim_factor)) ## Avoid residual links 40 | self.bn2 = builder.batchnorm(math.ceil(planes * slim_factor), last_bn=True) ## Avoid residual links 41 | 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | 47 | residual = x 48 | # print('1: ',torch.norm(residual[:,:residual.shape[1]//self.split])) 49 | out = self.conv1(x) 50 | # print('1.5 conv ',self.conv1.weight.shape) 51 | # print('1.5 conv 
',torch.norm(self.conv1.weight[:self.conv1.weight.shape[0]//self.split,:self.conv1.weight.shape[1]//self.split])) 52 | # print('1.5: ', torch.norm(out[:, :out.shape[1] // self.split])) 53 | if self.bn1 is not None: 54 | out = self.bn1(out) 55 | # print('2: ', torch.norm(out[:,:out.shape[1]//self.split])) 56 | out = self.relu(out) 57 | out = self.conv2(out) 58 | # print('3: ', torch.norm(out[:,:out.shape[1]//self.split])) 59 | if self.bn2 is not None: 60 | out = self.bn2(out) 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | # print('4: ', torch.norm(residual[:,:residual.shape[1]//self.split])) 64 | out += residual 65 | out = self.relu(out) 66 | # print('5: ', torch.norm(out[:,:out.shape[1]//self.split])) 67 | return out 68 | 69 | 70 | # BasicBlock }}} 71 | 72 | # Bottleneck {{{ 73 | class Bottleneck(nn.Module): 74 | M = 3 75 | expansion = 4 76 | 77 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None, base_width=64, slim_factor=1, is_last_conv=False): 78 | super(Bottleneck, self).__init__() 79 | width = int(planes * base_width / 64) 80 | self.conv1 = builder.conv1x1(math.ceil(inplanes * slim_factor), math.ceil(width * slim_factor)) 81 | self.bn1 = builder.batchnorm(math.ceil(width * slim_factor)) 82 | self.conv2 = builder.conv3x3(math.ceil(width * slim_factor), math.ceil(width * slim_factor), stride=stride) 83 | self.bn2 = builder.batchnorm(math.ceil(width * slim_factor)) 84 | self.conv3 = builder.conv1x1(math.ceil(width * slim_factor), math.ceil(planes * self.expansion * slim_factor)) 85 | self.bn3 = builder.batchnorm(math.ceil(planes * self.expansion * slim_factor)) 86 | self.relu = builder.activation() 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | 92 | residual = x 93 | 94 | out = self.conv1(x) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv3(out) 103 
| out = self.bn3(out) 104 | 105 | if self.downsample is not None: 106 | residual = self.downsample(x) 107 | 108 | out += residual 109 | 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | # Bottleneck }}} 116 | 117 | # ResNet {{{ 118 | class ResNet(nn.Module): 119 | def __init__(self,cfg, builder, block, layers, base_width=64): 120 | 121 | super(ResNet, self).__init__() 122 | self.inplanes = 64 123 | slim_factor = cfg.slim_factor 124 | if slim_factor < 1: 125 | cfg.logger.info('WARNING: You are using a slim network') 126 | 127 | self.base_width = base_width 128 | if self.base_width // 64 > 1: 129 | print(f"==> Using {self.base_width // 64}x wide model") 130 | 131 | 132 | self.conv1 = builder.conv7x7(3, math.ceil(64*slim_factor), stride=2, first_layer=True) 133 | 134 | self.bn1 = builder.batchnorm(math.ceil(64*slim_factor)) 135 | self.relu = builder.activation() 136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(builder, block, 64, layers[0], slim_factor=slim_factor) 138 | 139 | self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2, slim_factor=slim_factor) 140 | 141 | self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2, slim_factor=slim_factor) 142 | 143 | self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2, slim_factor=slim_factor) 144 | 145 | self.avgpool = nn.AdaptiveAvgPool2d(1) 146 | 147 | self.fc = builder.linear(math.ceil(512 * block.expansion * slim_factor), cfg.num_cls, last_layer=True) 148 | 149 | 150 | def _make_layer(self, builder, block, planes, blocks, stride=1, slim_factor=1): 151 | downsample = None 152 | if stride != 1 or self.inplanes != planes * block.expansion: 153 | dconv = builder.conv1x1(math.ceil(self.inplanes * slim_factor), 154 | math.ceil(planes * block.expansion * slim_factor), stride=stride) ## Going into a residual link 155 | dbn = builder.batchnorm(math.ceil(planes * block.expansion * slim_factor)) 156 | if 
dbn is not None: 157 | downsample = nn.Sequential(dconv, dbn) 158 | else: 159 | downsample = dconv 160 | 161 | layers = [] 162 | layers.append(block(builder, self.inplanes, planes, stride, downsample, base_width=self.base_width, slim_factor=slim_factor)) 163 | self.inplanes = planes * block.expansion 164 | for i in range(1, blocks): 165 | layers.append(block(builder, self.inplanes, planes, base_width=self.base_width, 166 | slim_factor=slim_factor)) 167 | 168 | return nn.Sequential(*layers) 169 | 170 | def forward(self, x): 171 | # features = [] 172 | x = self.conv1(x) 173 | # features.append(x) 174 | 175 | if self.bn1 is not None: 176 | x = self.bn1(x) 177 | x = self.relu(x) 178 | x = self.maxpool(x) 179 | # features.append(x) 180 | # print('resnet 1 ',torch.norm(x[:,:x.shape[1]//split])) 181 | # self.layer1[0].split = split 182 | # self.layer1[1].split = split 183 | x = self.layer1(x) 184 | # features.append(x) 185 | 186 | x = self.layer2(x) 187 | # features.append(x) 188 | x = self.layer3(x) 189 | # features.append(x) 190 | x = self.layer4(x) 191 | # features.append(x) 192 | x = self.avgpool(x) 193 | x = torch.flatten(x, 1) 194 | x = self.fc(x) 195 | x = x.view(x.size(0), -1) 196 | # features.append(x) 197 | return x 198 | 199 | from collections import OrderedDict, namedtuple 200 | 201 | class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): 202 | def __repr__(self): 203 | if not self.missing_keys and not self.unexpected_keys: 204 | return '' 205 | return super(_IncompatibleKeys, self).__repr__() 206 | 207 | __str__ = __repr__ 208 | 209 | def load_state_dict(model, state_dict, 210 | strict: bool = True): 211 | r"""Copies parameters and buffers from :attr:`state_dict` into 212 | this module and its descendants. If :attr:`strict` is ``True``, then 213 | the keys of :attr:`state_dict` must exactly match the keys returned 214 | by this module's :meth:`~torch.nn.Module.state_dict` function. 
215 | Arguments: 216 | state_dict (dict): a dict containing parameters and 217 | persistent buffers. 218 | strict (bool, optional): whether to strictly enforce that the keys 219 | in :attr:`state_dict` match the keys returned by this module's 220 | :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` 221 | Returns: 222 | ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: 223 | * **missing_keys** is a list of str containing the missing keys 224 | * **unexpected_keys** is a list of str containing the unexpected keys 225 | """ 226 | missing_keys = [] 227 | unexpected_keys = [] 228 | error_msgs = [] 229 | 230 | # copy state_dict so _load_from_state_dict can modify it 231 | metadata = getattr(state_dict, '_metadata', None) 232 | state_dict = state_dict.copy() 233 | if metadata is not None: 234 | state_dict._metadata = metadata 235 | 236 | def load(module, prefix=''): 237 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 238 | module._load_from_state_dict( 239 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 240 | for name, child in module._modules.items(): 241 | if child is not None and not (child.__class__.__name__ == 'SplitLinear' or child.__class__.__name__ == 'Linear'): 242 | load(child, prefix + name + '.') 243 | 244 | load(model) 245 | load = None # break load->load reference cycle 246 | 247 | if strict: 248 | if len(unexpected_keys) > 0: 249 | error_msgs.insert( 250 | 0, 'Unexpected key(s) in state_dict: {}. '.format( 251 | ', '.join('"{}"'.format(k) for k in unexpected_keys))) 252 | if len(missing_keys) > 0: 253 | error_msgs.insert( 254 | 0, 'Missing key(s) in state_dict: {}. 
'.format( 255 | ', '.join('"{}"'.format(k) for k in missing_keys))) 256 | 257 | if len(error_msgs) > 0: 258 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 259 | model.__class__.__name__, "\n\t".join(error_msgs))) 260 | return _IncompatibleKeys(missing_keys, unexpected_keys) 261 | 262 | # ResNet }}} 263 | def Split_ResNet18(cfg, progress=True): 264 | model = ResNet(cfg,get_builder(cfg), BasicBlock, [2, 2, 2, 2]) 265 | if cfg.pretrained == 'imagenet': 266 | arch = 'resnet18' 267 | state_dict = load_state_dict_from_url(model_urls[arch], 268 | progress=progress) 269 | load_state_dict(model,state_dict,strict=False) 270 | return model 271 | 272 | def Split_ResNet34(cfg, progress=True): 273 | model = ResNet(cfg,get_builder(cfg), BasicBlock, [3, 4, 6, 3]) 274 | if cfg.pretrained == 'imagenet': 275 | arch = 'resnet34' 276 | state_dict = load_state_dict_from_url(model_urls[arch], 277 | progress=progress) 278 | load_state_dict(model,state_dict,strict=False) 279 | return model 280 | 281 | def Split_ResNet50(cfg,progress=True): 282 | model = ResNet(cfg,get_builder(cfg), Bottleneck, [3, 4, 6, 3]) 283 | if cfg.pretrained == 'imagenet': 284 | arch = 'resnet50' 285 | state_dict = load_state_dict_from_url(model_urls[arch], 286 | progress=progress) 287 | load_state_dict(model,state_dict,strict=False) 288 | return model 289 | 290 | 291 | def Split_ResNet101(cfg,progress=True): 292 | model = ResNet(cfg,get_builder(cfg), Bottleneck, [3, 4, 23, 3]) 293 | if cfg.pretrained == 'imagenet': 294 | arch = 'resnet101' 295 | state_dict = load_state_dict_from_url(model_urls[arch], 296 | progress=progress) 297 | load_state_dict(model,state_dict,strict=False) 298 | return model 299 | 300 | 301 | # def WideResNet50_2(cfg,pretrained=False): 302 | # return ResNet(cfg, 303 | # get_builder(cfg), Bottleneck, [3, 4, 6, 3], base_width=64 * 2 304 | # ) 305 | # 306 | # 307 | # def WideResNet101_2(cfg,pretrained=False): 308 | # return ResNet(cfg, 309 | # get_builder(cfg), 
Bottleneck, [3, 4, 23, 3], base_width=64 * 2 310 | # ) 311 | 312 | -------------------------------------------------------------------------------- /sample_runs/aircraft/SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3.csv: -------------------------------------------------------------------------------- 1 | Date Finished, Name, Split Rate, Bias Split Rate, Current Val Top 1, Current Val Top 5, Best Val Top 1, Best Val Top 5, Current Tst Top 1, Current Tst Top 5, Best Tst Top 1, Best Tst Top 5, Best Trn Top 1, Best Trn Top 5 2 | 07-31-21_15:12:03, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, 0.5, 0.5, 56.23, 79.72, 56.65, 79.81, 56.95, 79.75, 0.00, 0.00, 94.93, 97.90 3 | 07-31-21_15:12:16, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, slim, slim, 0.00, 0.00, 0.00, 0.00, 1.08, 5.04, 0.00, 0.00, 0.00, 0.00 4 | 07-31-21_16:18:12, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, 0.5, 0.5, 58.00, 80.32, 58.06, 80.80, 57.37, 80.56, 0.00, 0.00, 95.11, 98.08 5 | 07-31-21_16:18:25, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, slim, slim, 0.00, 0.00, 0.00, 0.00, 12.27, 29.70, 0.00, 0.00, 0.00, 0.00 6 | 07-31-21_17:21:53, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, 0.5, 0.5, 57.37, 79.66, 57.37, 80.08, 57.85, 79.87, 0.00, 0.00, 94.78, 97.90 7 | 07-31-21_17:22:05, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, slim, slim, 0.00, 0.00, 0.00, 0.00, 46.20, 72.25, 0.00, 0.00, 0.00, 0.00 8 | 07-31-21_18:25:55, 
SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, 0.5, 0.5, 57.61, 79.27, 57.61, 79.45, 57.70, 80.62, 0.00, 0.00, 95.47, 97.96 9 | 07-31-21_18:26:07, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, slim, slim, 0.00, 0.00, 0.00, 0.00, 56.14, 80.29, 0.00, 0.00, 0.00, 0.00 10 | 07-31-21_19:29:54, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, 0.5, 0.5, 57.40, 79.84, 57.70, 79.84, 57.91, 80.23, 0.00, 0.00, 94.96, 97.90 11 | 07-31-21_19:30:08, SPLT_CLS_Aircraft100Pytorch_Split_ResNet18_cskdFalse_smthTrue_k0.5_G5_e200_evrand_hResetFalse_smkels_single_gpu_test3/, slim, slim, 0.00, 0.00, 0.00, 0.00, 58.18, 81.40, 0.00, 0.00, 0.00, 0.00 12 | -------------------------------------------------------------------------------- /train_KE_cls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import data 4 | import torch 5 | import getpass 6 | import KE_model 7 | import importlib 8 | import os.path as osp 9 | import torch.nn as nn 10 | from utils import os_utils 11 | from utils import net_utils 12 | from utils import csv_utils 13 | from layers import conv_type 14 | from utils import path_utils 15 | from utils import model_profile 16 | from configs.base_config import Config 17 | 18 | 19 | def get_trainer(args): 20 | print(f"=> Using trainer from trainers.{args.trainer}") 21 | trainer = importlib.import_module(f"trainers.{args.trainer}") 22 | 23 | return trainer.train, trainer.validate 24 | 25 | 26 | def train_dense(cfg, generation): 27 | 28 | model = net_utils.get_model(cfg) 29 | 30 | if cfg.pretrained and cfg.pretrained != "imagenet": 31 | net_utils.load_pretrained(cfg.pretrained, cfg.gpu, model, cfg) 32 | model = net_utils.move_model_to_gpu(cfg, model) 33 | net_utils.split_reinitialize(cfg, model, 
reset_hypothesis=cfg.reset_hypothesis) 34 | else: 35 | model = net_utils.move_model_to_gpu(cfg, model) 36 | 37 | cfg.trainer = "default_cls" 38 | cfg.pretrained = None 39 | ckpt_path = KE_model.ke_cls_train(cfg, model, generation) 40 | 41 | return ckpt_path 42 | 43 | 44 | def eval_slim(cfg, generation): 45 | original_num_epos = cfg.epochs 46 | # cfg.epochs = 0 47 | softmax_criterion = nn.CrossEntropyLoss().cuda() 48 | epoch = 1 49 | writer = None 50 | model = net_utils.get_model(cfg) 51 | net_utils.load_pretrained(cfg.pretrained, cfg.gpu, model, cfg) 52 | # if cfg.reset_mask: 53 | # net_utils.reset_mask(cfg, model) 54 | model = net_utils.move_model_to_gpu(cfg, model) 55 | 56 | save_filter_stats = cfg.arch in ["split_alexnet", "split_vgg11_bn"] 57 | if save_filter_stats: 58 | for n, m in model.named_modules(): 59 | if hasattr(m, "weight") and m.weight is not None: 60 | if hasattr(m, "mask"): 61 | layer_mask = m.mask 62 | if m.__class__ == conv_type.SplitConv: 63 | # filter_state = [''.join(map(str, ((score_mask == True).type(torch.int).squeeze().tolist())))] 64 | filter_mag = [ 65 | "{},{}".format( 66 | float( 67 | torch.mean( 68 | torch.abs(m.weight[layer_mask.type(torch.bool)]) 69 | ) 70 | ), 71 | float( 72 | torch.mean( 73 | torch.abs( 74 | m.weight[(1 - layer_mask).type(torch.bool)] 75 | ) 76 | ) 77 | ), 78 | ) 79 | ] 80 | os_utils.txt_write( 81 | osp.join( 82 | cfg.exp_dir, n.replace(".", "_") + "_mean_magnitude.txt" 83 | ), 84 | filter_mag, 85 | mode="a+", 86 | ) 87 | 88 | dummy_input_tensor = torch.zeros((1, 3, 224, 224)).cuda() 89 | total_ops, total_params = model_profile.profile(model, dummy_input_tensor) 90 | cfg.logger.info("Dense #Ops: %f GOps" % (total_ops / 1e9)) 91 | cfg.logger.info( 92 | "Dense #Parameters: %f M (Split-Mask included)" % (total_params / 1e6) 93 | ) 94 | 95 | original_split_rate = cfg.split_rate 96 | original_bias_split_rate = cfg.bias_split_rate 97 | 98 | if cfg.split_mode == "kels": 99 | cfg.slim_factor = cfg.split_rate 100 | 
cfg.split_rate = 1.0 101 | cfg.bias_split_rate = 1.0 102 | split_model = net_utils.get_model(cfg) 103 | split_model = net_utils.move_model_to_gpu(cfg, split_model) 104 | 105 | total_ops, total_params = model_profile.profile(split_model, dummy_input_tensor) 106 | cfg.logger.info("Split #Ops: %f GOps" % (total_ops / 1e9)) 107 | cfg.logger.info( 108 | "Split #Parameters: %f M (Split-Mask included)" % (total_params / 1e6) 109 | ) 110 | 111 | net_utils.extract_slim(split_model, model) 112 | dataset = getattr(data, cfg.set)(cfg) 113 | train, validate = get_trainer(cfg) 114 | last_val_acc1, last_val_acc5 = validate( 115 | dataset.tst_loader, split_model, softmax_criterion, cfg, writer, epoch 116 | ) 117 | cfg.logger.info("Split Model : {} , {}".format(last_val_acc1, last_val_acc5)) 118 | else: 119 | last_val_acc1 = 0 120 | last_val_acc5 = 0 121 | 122 | csv_utils.write_cls_result_to_csv( 123 | ## Validation 124 | curr_acc1=0, 125 | curr_acc5=0, 126 | best_acc1=0, 127 | best_acc5=0, 128 | ## Test 129 | last_tst_acc1=last_val_acc1, 130 | last_tst_acc5=last_val_acc5, 131 | best_tst_acc1=0, 132 | best_tst_acc5=0, 133 | ## Train 134 | best_train_acc1=0, 135 | best_train_acc5=0, 136 | split_rate="slim", 137 | bias_split_rate="slim", 138 | base_config=cfg.name, 139 | name=cfg.name, 140 | ) 141 | 142 | cfg.epochs = original_num_epos 143 | 144 | cfg.slim_factor = 1 145 | cfg.split_rate = original_split_rate 146 | cfg.bias_split_rate = original_bias_split_rate 147 | 148 | 149 | def clean_dir(ckpt_dir, num_epochs): 150 | # print(ckpt_dir) 151 | if "0000" in str( 152 | ckpt_dir 153 | ): ## Always keep the first model -- Help reproduce results 154 | return 155 | rm_path = ckpt_dir / "model_best.pth" 156 | if rm_path.exists(): 157 | os.remove(rm_path) 158 | 159 | rm_path = ckpt_dir / "epoch_{}.state".format(num_epochs - 1) 160 | if rm_path.exists(): 161 | os.remove(rm_path) 162 | 163 | rm_path = ckpt_dir / "initial.state" 164 | if rm_path.exists(): 165 | os.remove(rm_path) 166 | 167 | 
def start_KE(cfg):
    """Run the Knowledge Evolution (KE) loop for ``cfg.num_generations`` generations.

    Each generation trains a dense model with ``train_dense`` and then points
    ``cfg.pretrained`` at that generation's final-epoch state file
    (``epoch_{cfg.epochs - 1}.state`` inside the returned checkpoint directory).
    ``eval_slim`` and the next generation's ``train_dense`` load from it.
    With more than one generation, the slim network is evaluated after every
    generation. Only the four most recent checkpoint directories are left
    untouched; the oldest one beyond that is pruned with ``clean_dir``.

    Args:
        cfg: Parsed experiment configuration (as produced by ``Config().parse``).
            This function mutates ``cfg.start_epoch`` and ``cfg.pretrained``.

    Raises:
        ValueError: If ``cfg.cs_kd`` is enabled; CS-KD needs a different data
            loader that this repository does not provide.
    """
    # Explicit check rather than ``assert`` so the guard survives ``python -O``.
    if cfg.cs_kd:
        raise ValueError(
            "CS-KD requires a different data loader, not available in this repos"
        )

    ckpt_queue = []

    for gen in range(cfg.num_generations):
        cfg.start_epoch = 0

        # Assumes train_dense returns a pathlib-style directory path (the `/` below
        # relies on that) -- TODO confirm in KE_model.ke_cls_train.
        task_ckpt = train_dense(cfg, gen)
        ckpt_queue.append(task_ckpt)

        # eval_slim and the next generation both load from cfg.pretrained.
        cfg.pretrained = task_ckpt / "epoch_{}.state".format(cfg.epochs - 1)

        if cfg.num_generations == 1:
            # Single-generation run: plain dense training, no slim eval or pruning.
            break

        eval_slim(cfg, gen)

        # Defensive re-assignment before the next generation's train_dense reads it.
        cfg.pretrained = task_ckpt / "epoch_{}.state".format(cfg.epochs - 1)

        # Keep at most four un-pruned checkpoint directories on disk.
        if len(ckpt_queue) > 4:
            oldest_ckpt = ckpt_queue.pop(0)
            clean_dir(oldest_ckpt, cfg.epochs)


def main(arg_num_threads=16):
    """Launch the hard-coded KE experiment used when the script gets no CLI arguments.

    For each dataset in the loop, builds an argv list for ``Config().parse``
    (Split_ResNet18, 200 epochs, 5 generations, split rate 0.5, "kels" split
    mode, cosine LR with lr 0.256, label smoothing 0.1) and runs ``start_KE``.

    Args:
        arg_num_threads: Value forwarded to the config as ``--num_threads``.
    """
    print("Starting with {} threads".format(arg_num_threads))
    # Other dataset names: Flower102, CUB200, HAM, Dog120, MIT67, Aircraft100, MINI_MIT67, FCAM
    for arg_dataset in ["Flower102Pytorch"]:
        arg_epochs = str(200)
        arg_evolve_mode = "rand"
        arg_reset_hypothesis = False
        arg_enable_cs_kd = False
        arg_enable_label_smoothing = True
        # Other archs: Split_ResNet34, Split_ResNet50, split_googlenet,
        # split_densenet169, split_vgg11_bn, split_densenet121
        arg_arch = "Split_ResNet18"
        arg_split_top = "0.5"
        arg_bias_split_top = arg_split_top
        arg_num_generations = "5"
        arg_split_mode = "kels"  # wels , kels

        exp_name_suffix = "single_gpu_test3"
        arg_exp_name = (
            "SPLT_CLS_{}_{}_cskd{}_smth{}_k{}_G{}_e{}_ev{}_hReset{}_sm{}_{}/".format(
                arg_dataset,
                arg_arch,
                arg_enable_cs_kd,
                arg_enable_label_smoothing,
                arg_split_top,
                arg_num_generations,
                arg_epochs,
                arg_evolve_mode,
                arg_reset_hypothesis,
                arg_split_mode,
                exp_name_suffix,
            )
        )

        # AlexNet/VGG variants use a larger weight decay; init is the same for all archs.
        if arg_arch in ["split_alexnet", "split_vgg11", "split_vgg11_bn"]:
            arg_weight_decay = "5e-4"
        else:
            arg_weight_decay = "1e-4"
        arg_init = "kaiming_normal"  # xavier_normal, kaiming_normal

        argv = [
            "--name", arg_exp_name,
            "--evolve_mode", arg_evolve_mode,
            # Previously hard-coded to "16", which silently ignored arg_num_threads.
            "--num_threads", str(arg_num_threads),
            "--gpu", "0",
            "--epochs", arg_epochs,
            "--arch", arg_arch,
            "--data", "/mnt/data/datasets/",
            "--set", arg_dataset,
            "--optimizer", "sgd",
            "--lr_policy", "cosine_lr",
            "--warmup_length", "5",
            "--weight_decay", arg_weight_decay,
            "--momentum", "0.9",
            "--batch_size", "32",
            "--conv_type", "SplitConv",
            "--bn_type", "SplitBatchNorm",
            "--linear_type", "SplitLinear",
            "--split_rate", arg_split_top,
            "--bias_split_rate", arg_bias_split_top,
            "--init", arg_init,
            "--mode", "fan_in",
            "--nonlinearity", "relu",
            "--num_generations", arg_num_generations,
            "--split_mode", arg_split_mode,
        ]

        if arg_enable_cs_kd:
            argv.append("--cs_kd")

        if arg_enable_label_smoothing:
            argv.extend(["--label_smoothing", "0.1"])

        argv.extend(["--lr", "0.256"])

        if arg_reset_hypothesis:
            argv.append("--reset_hypothesis")

        cfg = Config().parse(argv)

        start_KE(cfg)


if __name__ == "__main__":
    # No CLI arguments: run the hard-coded experiment in main().
    # Otherwise Config().parse(None) is used -- presumably it falls back to
    # sys.argv (TODO confirm in configs/base_config.py).
    if len(sys.argv) == 1:
        main()
    else:
        cfg = Config().parse(None)
        start_KE(cfg)
-------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmdtaha/knowledge_evolution/a3f2eb2eed7accb86ad1af2a15c13e4a9654fe16/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/default_cls.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from utils import net_utils 6 | from layers.CS_KD import KDLoss 7 | from utils.eval_utils import accuracy 8 | from utils.logging import AverageMeter, ProgressMeter 9 | 10 | 11 | 12 | __all__ = ["train", "validate"] 13 | 14 | 15 | 16 | kdloss = KDLoss(4).cuda() 17 | 18 | def train(train_loader, model, criterion, optimizer, epoch, cfg, writer): 19 | batch_time = AverageMeter("Time", ":6.3f") 20 | data_time = AverageMeter("Data", ":6.3f") 21 | losses = AverageMeter("Loss", ":.3f") 22 | top1 = AverageMeter("Acc@1", ":6.2f") 23 | top5 = AverageMeter("Acc@5", ":6.2f") 24 | progress = ProgressMeter( 25 | train_loader.num_batches, 26 | [batch_time, data_time, losses, top1, top5],cfg, 27 | prefix=f"Epoch: [{epoch}]", 28 | ) 29 | 30 | # switch to train mode 31 | model.train() 32 | 33 | batch_size = train_loader.batch_size 34 | num_batches = train_loader.num_batches 35 | end = time.time() 36 | 37 | for i , data in enumerate(train_loader): 38 | # images, target = data[0]['data'],data[0]['label'].long().squeeze() 39 | images, target = data[0].cuda(),data[1].long().squeeze().cuda() 40 | # measure data loading time 41 | data_time.update(time.time() - end) 42 | 43 | if cfg.cs_kd: 44 | 45 | batch_size = images.size(0) 46 | loss_batch_size = batch_size // 2 47 | targets_ = target[:batch_size // 2] 48 | outputs = model(images[:batch_size // 2]) 49 | loss = torch.mean(criterion(outputs, targets_)) 50 | # loss += 
loss.item() 51 | 52 | with torch.no_grad(): 53 | outputs_cls = model(images[batch_size // 2:]) 54 | cls_loss = kdloss(outputs[:batch_size // 2], outputs_cls.detach()) 55 | lamda = 3 56 | loss += lamda * cls_loss 57 | acc1, acc5 = accuracy(outputs, targets_, topk=(1, 5)) 58 | else: 59 | batch_size = images.size(0) 60 | loss_batch_size = batch_size 61 | #compute output 62 | output = model(images) 63 | loss = criterion(output, target) 64 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 65 | 66 | # print(i, batch_size, loss) 67 | 68 | # measure accuracy and record loss 69 | 70 | losses.update(loss.item(), loss_batch_size) 71 | top1.update(acc1.item(), loss_batch_size) 72 | top5.update(acc5.item(), loss_batch_size) 73 | 74 | # compute gradient and do SGD step 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | 79 | # measure elapsed time 80 | batch_time.update(time.time() - end) 81 | end = time.time() 82 | 83 | 84 | if i % cfg.print_freq == 0 or i == num_batches-1: 85 | t = (num_batches * epoch + i) * batch_size 86 | progress.display(i) 87 | progress.write_to_tensorboard(writer, prefix="train", global_step=t) 88 | 89 | # train_loader.reset() 90 | # print(top1.count) 91 | return top1.avg, top5.avg 92 | 93 | 94 | def validate(val_loader, model, criterion, args, writer, epoch): 95 | batch_time = AverageMeter("Time", ":6.3f", write_val=False) 96 | losses = AverageMeter("Loss", ":.3f", write_val=False) 97 | top1 = AverageMeter("Acc@1", ":6.2f", write_val=False) 98 | top5 = AverageMeter("Acc@5", ":6.2f", write_val=False) 99 | progress = ProgressMeter( 100 | val_loader.num_batches, [batch_time, losses, top1, top5],args, prefix="Test: " 101 | ) 102 | 103 | # switch to evaluate mode 104 | model.eval() 105 | 106 | with torch.no_grad(): 107 | end = time.time() 108 | 109 | # confusion_matrix = torch.zeros(args.num_cls,args.num_cls) 110 | for i, data in enumerate(val_loader): 111 | # images, target = data[0]['data'], data[0]['label'].long().squeeze() 112 | 
images, target = data[0].cuda(), data[1].long().squeeze().cuda() 113 | 114 | 115 | output = model(images) 116 | loss = criterion(output, target) 117 | 118 | # measure accuracy and record loss 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | # print(target,torch.mean(images),acc1,acc5,loss,torch.mean(output)) 121 | losses.update(loss.item(), images.size(0)) 122 | top1.update(acc1.item(), images.size(0)) 123 | top5.update(acc5.item(), images.size(0)) 124 | 125 | 126 | # _, preds = torch.max(output, 1) 127 | # for t, p in zip(target.view(-1), preds.view(-1)): 128 | # confusion_matrix[t.long(), p.long()] += 1 129 | 130 | # measure elapsed time 131 | batch_time.update(time.time() - end) 132 | end = time.time() 133 | 134 | 135 | 136 | if i % args.print_freq == 0: 137 | progress.display(i) 138 | 139 | progress.display(val_loader.num_batches) 140 | 141 | if writer is not None: 142 | progress.write_to_tensorboard(writer, prefix="test", global_step=epoch) 143 | 144 | # torch.save(confusion_matrix,'./conf_mat.pt') 145 | # print(top1.count) 146 | return top1.avg, top5.avg 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmdtaha/knowledge_evolution/a3f2eb2eed7accb86ad1af2a15c13e4a9654fe16/utils/__init__.py -------------------------------------------------------------------------------- /utils/csv_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pathlib 4 | from utils import path_utils 5 | 6 | def write_cls_result_to_csv(**kwargs): 7 | name = kwargs.get('name') 8 | if '/' in name: 9 | exp_name = name.split('/')[0] 10 | results = pathlib.Path(os.path.join(path_utils.get_checkpoint_dir(),exp_name, "{}.csv".format(exp_name))) 11 | else: 12 | results = 
pathlib.Path(os.path.join(path_utils.get_checkpoint_dir(), "{}.csv".format(name) )) 13 | 14 | if not results.exists(): 15 | results.write_text( 16 | "Date Finished, " 17 | # "Base Config, " 18 | "Name, " 19 | "Split Rate, " 20 | "Bias Split Rate, " 21 | "Current Val Top 1, " 22 | "Current Val Top 5, " 23 | "Best Val Top 1, " 24 | "Best Val Top 5, " 25 | 26 | "Current Tst Top 1, " 27 | "Current Tst Top 5, " 28 | "Best Tst Top 1, " 29 | "Best Tst Top 5, " 30 | 31 | "Best Trn Top 1, " 32 | "Best Trn Top 5\n" 33 | ) 34 | 35 | now = time.strftime("%m-%d-%y_%H:%M:%S") 36 | 37 | with open(results, "a+") as f: 38 | f.write( 39 | ( 40 | "{now}, " 41 | # "{base_config}, " 42 | "{name}, " 43 | "{split_rate}, " 44 | "{bias_split_rate}, " 45 | "{curr_acc1:.02f}, " 46 | "{curr_acc5:.02f}, " 47 | "{best_acc1:.02f}, " 48 | "{best_acc5:.02f}, " 49 | 50 | "{last_tst_acc1:.02f}, " 51 | "{last_tst_acc5:.02f}, " 52 | "{best_tst_acc1:.02f}, " 53 | "{best_tst_acc5:.02f}, " 54 | 55 | "{best_train_acc1:.02f}, " 56 | "{best_train_acc5:.02f}\n" 57 | ).format(now=now, **kwargs) 58 | ) 59 | 60 | 61 | def write_ret_result_to_csv(**kwargs): 62 | name = kwargs.get('name') 63 | name_prefix = kwargs.get('name_prefix') 64 | 65 | if '/' in name: 66 | exp_name = name.split('/')[0] 67 | if name_prefix is None: 68 | results = pathlib.Path(os.path.join(path_utils.get_checkpoint_dir(), exp_name, "{}.csv".format(exp_name))) 69 | else: 70 | results = pathlib.Path(os.path.join(path_utils.get_checkpoint_dir(), exp_name, "{}_{}.csv".format(name_prefix,exp_name))) 71 | else: 72 | results = pathlib.Path(os.path.join(path_utils.get_checkpoint_dir(), "{}.csv".format(name))) 73 | 74 | if not results.exists(): 75 | results.write_text( 76 | "Date Finished, " 77 | # "Base Config, " 78 | "Name, " 79 | "Split Rate, " 80 | "Bias Split Rate, " 81 | "NMI," 82 | "R@1," 83 | "R@2," 84 | "R@4," 85 | "R@8," 86 | "R@16," 87 | "R@32\n" 88 | ) 89 | 90 | now = time.strftime("%m-%d-%y_%H:%M:%S") 91 | 92 | with open(results, "a+") 
as f: 93 | f.write( 94 | ( 95 | "{now}, " 96 | # "{base_config}, " 97 | "{name}, " 98 | "{split_rate}, " 99 | "{bias_split_rate}, " 100 | "{NMI:.03f}, " 101 | "{R_1:.02f}, " 102 | "{R_2:.02f}, " 103 | "{R_4:.02f}, " 104 | "{R_8:.02f}, " 105 | "{R_16:.02f}, " 106 | "{R_32:.02f}\n" 107 | ).format(now=now, **kwargs) 108 | ) 109 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target, topk=(1,)): 5 | """Computes the accuracy over the k top predictions for the specified values of k""" 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res 19 | -------------------------------------------------------------------------------- /utils/gpu_utils.py: -------------------------------------------------------------------------------- 1 | import nvidia_smi 2 | # pip install nvidia-ml-py3 3 | class GPU_Utils: 4 | def __init__(self,gpu_index=0): 5 | nvidia_smi.nvmlInit() 6 | self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(gpu_index) 7 | 8 | def gpu_mem_usage(self): 9 | mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle) 10 | return mem_res.used / (1024**2) 11 | 12 | def gpu_utilization(self): 13 | gpu_util = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle) 14 | return gpu_util.gpu 15 | -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as osp 3 | # from utils import os_utils 4 | 5 | # 
tensorflow_logger = logging.getLogger('tensorflow') 6 | # tensorflow_logger.setLevel(logging.DEBUG) 7 | # 8 | # # create formatter and add it to the handlers 9 | # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | # # create file handler which logs even debug messages 11 | # fh = logging.FileHandler(os.path.join(config.model_save_path ,'tensorflow.txt')) 12 | # fh.setLevel(logging.DEBUG) 13 | # fh.setFormatter(formatter) 14 | # tensorflow_logger.addHandler(fh) 15 | 16 | 17 | # Copyright (c) 2014 Markus Pointner 18 | # 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | # 26 | # The above copyright notice and this permission notice shall be included in 27 | # all copies or substantial portions of the Software. 28 | # 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 35 | # THE SOFTWARE. 
class _LevelColorMixin:
    """Shared logging-level -> colour lookup for the colourised stream handlers.

    Concrete handlers define the class attributes CRITICAL, ERROR, WARNING,
    INFO, DEBUG and DEFAULT; this mixin only selects one of them for a level.
    """

    @classmethod
    def _get_color(cls, level):
        # Checked from most to least severe, so a custom level between two
        # standard ones gets the colour of the next-lower standard level.
        if level >= logging.CRITICAL:
            return cls.CRITICAL
        if level >= logging.ERROR:
            return cls.ERROR
        if level >= logging.WARNING:
            return cls.WARNING
        if level >= logging.INFO:
            return cls.INFO
        if level >= logging.DEBUG:
            return cls.DEBUG
        return cls.DEFAULT


class _AnsiColorStreamHandler(_LevelColorMixin, logging.StreamHandler):
    """StreamHandler that wraps each formatted record in ANSI colour escapes.

    Escapes are only added when the stream reports it is a TTY, so output
    redirected to files or pipes stays plain text.
    """

    DEFAULT = '\x1b[0m'
    RED = '\x1b[31m'
    GREEN = '\x1b[32m'
    YELLOW = '\x1b[33m'
    CYAN = '\x1b[36m'

    CRITICAL = RED
    ERROR = RED
    WARNING = YELLOW
    INFO = DEFAULT  # GREEN
    DEBUG = CYAN

    def __init__(self, stream=None):
        logging.StreamHandler.__init__(self, stream)

    def format(self, record):
        text = logging.StreamHandler.format(self, record)
        if not self.is_tty():
            return text
        return self._get_color(record.levelno) + text + self.DEFAULT

    def is_tty(self):
        """Return True when the stream exposes ``isatty()`` and it reports a terminal."""
        isatty = getattr(self.stream, 'isatty', None)
        return bool(isatty and isatty())


class _WinColorStreamHandler(_LevelColorMixin, logging.StreamHandler):
    """StreamHandler that colours records through the Win32 console API."""

    # wincon.h
    FOREGROUND_BLACK = 0x0000
    FOREGROUND_BLUE = 0x0001
    FOREGROUND_GREEN = 0x0002
    FOREGROUND_CYAN = 0x0003
    FOREGROUND_RED = 0x0004
    FOREGROUND_MAGENTA = 0x0005
    FOREGROUND_YELLOW = 0x0006
    FOREGROUND_GREY = 0x0007
    FOREGROUND_INTENSITY = 0x0008  # foreground color is intensified.
    FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED

    BACKGROUND_BLACK = 0x0000
    BACKGROUND_BLUE = 0x0010
    BACKGROUND_GREEN = 0x0020
    BACKGROUND_CYAN = 0x0030
    BACKGROUND_RED = 0x0040
    BACKGROUND_MAGENTA = 0x0050
    BACKGROUND_YELLOW = 0x0060
    BACKGROUND_GREY = 0x0070
    BACKGROUND_INTENSITY = 0x0080  # background color is intensified.

    DEFAULT = FOREGROUND_WHITE
    CRITICAL = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
    ERROR = FOREGROUND_RED | FOREGROUND_INTENSITY
    WARNING = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
    INFO = FOREGROUND_GREEN
    DEBUG = FOREGROUND_CYAN

    def _set_color(self, code):
        import ctypes
        ctypes.windll.kernel32.SetConsoleTextAttribute(self._outhdl, code)

    def __init__(self, stream=None):
        logging.StreamHandler.__init__(self, stream)
        # Resolve the OS-level handle behind the stream so the console API can recolour it.
        import ctypes
        import ctypes.util
        # for some reason find_msvcrt() sometimes doesn't find msvcrt.dll on my system?
        crtname = ctypes.util.find_msvcrt()
        if not crtname:
            crtname = ctypes.util.find_library("msvcrt")
        crtlib = ctypes.cdll.LoadLibrary(crtname)
        self._outhdl = crtlib._get_osfhandle(self.stream.fileno())

    def emit(self, record):
        self._set_color(self._get_color(record.levelno))
        try:
            logging.StreamHandler.emit(self, record)
        finally:
            # Always restore the console colour, even if emitting raised.
            self._set_color(self.FOREGROUND_WHITE)


# select ColorStreamHandler based on platform
import platform
if platform.system() == 'Windows':
    ColorStreamHandler = _WinColorStreamHandler
else:
    ColorStreamHandler = _AnsiColorStreamHandler


def get_logging_dict(name, mode='w'):
    """Build a ``logging.config.dictConfig`` schema.

    INFO+ records go to coloured stderr, and DEBUG+ records go to the file
    ``name``, opened with ``mode``.
    """
    return {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
            },
        },
        'handlers': {
            'stderr': {
                'level': 'INFO',
                'formatter': 'standard',
                'class': 'utils.log_utils.ColorStreamHandler',
                'stream': 'ext://sys.stderr',
            },
            'logfile': {
                'level': 'DEBUG',
                'formatter': 'standard',
                'class': 'logging.FileHandler',
                'filename': name,
                'mode': mode,
            }
        },
        'loggers': {
            '': {
                'handlers': ['stderr', 'logfile'],
                'level': 'DEBUG',
                'propagate': True
            },

            # extra ones to shut up.
            'tensorflow': {
                'handlers': ['stderr', 'logfile'],
                'level': 'INFO',
                # This logger has its own handlers; propagating to the root
                # logger (same handlers) would emit every record twice.
                'propagate': False,
            },
        }
    }


def create_logger(log_file):
    """Configure logging to write to ``log_file`` and return a logger named after its basename."""
    import logging.config
    logging.config.dictConfig(get_logging_dict(log_file))
    filename = osp.basename(log_file)
    logger = logging.getLogger(filename)
    return logger


def get_logger_by_tag(tag):
    """Return the logger registered under ``tag``."""
    logger = logging.getLogger(tag)
    return logger


# --------------------------------------------------------------------------------
# /utils/logging.py
# --------------------------------------------------------------------------------
import abc
import tqdm

from torch.utils.tensorboard import SummaryWriter


class ProgressMeter(object):
    """Formats a batch counter plus a list of meters for logging and TensorBoard."""

    def __init__(self, num_batches, meters, cfg, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.cfg = cfg

    def display(self, batch, tqdm_writer=False):
        """Log one tab-separated progress line, via ``cfg.logger`` or ``tqdm.write``."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        if not tqdm_writer:
            self.cfg.logger.info("\t".join(entries))
        else:
            tqdm.tqdm.write("\t".join(entries))

    def write_to_tensorboard(
        self, writer: SummaryWriter, prefix="train", global_step=None
    ):
        """Write each meter's current and/or average value, as its write_* flags allow."""
        for meter in self.meters:
            avg = meter.avg
            val = meter.val
            if meter.write_val:
                writer.add_scalar(
                    f"{prefix}/{meter.name}_val", val, global_step=global_step
                )

            if meter.write_avg:
                writer.add_scalar(
                    f"{prefix}/{meter.name}_avg", avg, global_step=global_step
                )

    def _get_batch_fmtstr(self, num_batches):
        # e.g. num_batches=120 -> "[{:3d}/120]"
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


class Meter(object):
    # NOTE(review): Meter does not use abc.ABCMeta, so these abstractmethod
    # markers are documentation only and are not enforced at instantiation.
    @abc.abstractmethod
    def __init__(self, name, fmt=":f"):
        pass

    @abc.abstractmethod
    def reset(self):
        pass

    @abc.abstractmethod
    def update(self, val, n=1):
        pass

    @abc.abstractmethod
    def __str__(self):
        pass


class AverageMeter(Meter):
    """ Computes and stores the average and current value """

    def __init__(self, name, fmt=":f", write_val=True, write_avg=True):
        self.name = name
        self.fmt = fmt
        self.reset()

        self.write_val = write_val
        self.write_avg = write_avg

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times (n acts as a weight in the running mean)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard n=0 on an empty meter, which would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class VarianceMeter(Meter):
    """Tracks a running variance as E[x^2] - E[x]^2."""

    def __init__(self, name, fmt=":f", write_val=False):
        self.name = name
        self._ex_sq = AverageMeter(name="_subvariance_1", fmt=":.02f")  # running E[x^2]
        self._sq_ex = AverageMeter(name="_subvariance_2", fmt=":.02f")  # running E[x]
        self.fmt = fmt
        self.reset()
        # Honour the argument (it was previously hard-coded to False and ignored).
        self.write_val = write_val
        self.write_avg = True

    @property
    def val(self):
        return self._ex_sq.val - self._sq_ex.val ** 2

    @property
    def avg(self):
        return self._ex_sq.avg - self._sq_ex.avg ** 2

    def reset(self):
        self._ex_sq.reset()
        self._sq_ex.reset()

    def update(self, val, n=1):
        self._ex_sq.update(val ** 2, n=n)
        self._sq_ex.update(val, n=n)

    def __str__(self):
        return ("{name} (var {avg" + self.fmt + "})").format(
            name=self.name, avg=self.avg
        )


# --------------------------------------------------------------------------------
# /utils/model_profile.py
# --------------------------------------------------------------------------------
import argparse

import torch
import torch.nn as nn
import torchvision.models as models


def _accumulate_ops(m, total_ops):
    """Add ``total_ops`` to the module's ``total_ops`` buffer (summed across repeated calls)."""
    m.total_ops += torch.Tensor([int(total_ops)])


def count_conv2d(m, x, y):
    """Forward hook: count multiply/add/bias ops of one Conv2d call."""
    cin = m.in_channels // m.groups
    kh, kw = m.kernel_size

    # ops per output element
    kernel_mul = kh * kw * cin
    kernel_add = kh * kw * cin - 1
    bias_ops = 1 if m.bias is not None else 0
    ops = kernel_mul + kernel_add + bias_ops

    # total ops; accumulated in case the same conv is used multiple times
    _accumulate_ops(m, y.numel() * ops)


def count_bn2d(m, x, y):
    """Forward hook: one subtraction and one division per input element."""
    nelements = x[0].numel()
    total_sub = nelements
    total_div = nelements
    _accumulate_ops(m, total_sub + total_div)


def count_relu(m, x, y):
    """Forward hook: one op per input element."""
    _accumulate_ops(m, x[0].numel())


def count_softmax(m, x, y):
    """Forward hook for softmax over a (batch, features) input."""
    batch_size, nfeatures = x[0].size()

    total_exp = nfeatures
    total_add = nfeatures - 1
    total_div = nfeatures
    _accumulate_ops(m, batch_size * (total_exp + total_add + total_div))


def count_maxpool(m, x, y):
    """Forward hook: (kernel elements - 1) comparisons per output element."""
    kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) - 1
    _accumulate_ops(m, kernel_ops * y.numel())


def count_avgpool(m, x, y):
    """Forward hook: (kernel elements - 1) adds plus one division per output element."""
    total_add = torch.prod(torch.Tensor([m.kernel_size])) - 1
    total_div = 1
    kernel_ops = total_add + total_div
    _accumulate_ops(m, kernel_ops * y.numel())


def count_linear(m, x, y):
    """Forward hook: in_features muls and (in_features - 1) adds per output element."""
    total_mul = m.in_features
    total_add = m.in_features - 1
    _accumulate_ops(m, (total_mul + total_add) * y.numel())


def _select_hook(m, custom_ops):
    """Return the op-counting hook for leaf module ``m``, or None if it is not counted."""
    if type(m) in custom_ops:
        return custom_ops[type(m)]
    if isinstance(m, nn.Conv2d):
        return count_conv2d
    if isinstance(m, nn.BatchNorm2d):
        return count_bn2d
    if isinstance(m, nn.ReLU):
        return count_relu
    if isinstance(m, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)):
        return count_maxpool
    if isinstance(m, (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)):
        return count_avgpool
    if isinstance(m, nn.Linear):
        return count_linear
    if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
        return None
    print("Not implemented for ", m)
    return None


def profile(model, input_zero_tensor, custom_ops=None):
    """Count forward-pass ops and parameters of ``model`` for one input.

    :param model: module to profile; it is switched to eval mode.
    :param input_zero_tensor: example input passed to one forward call.
    :param custom_ops: optional mapping {module type: forward hook}; it takes
        precedence over the built-in counters for leaf modules of that exact type.
    :return: (total_ops, total_params), each a 1-element tensor summed over leaf modules.

    Leaf modules keep the ``total_ops``/``total_params`` buffers afterwards.
    The counting hooks are removed once the forward pass finishes.
    """
    # None instead of a shared mutable {} default.
    custom_ops = {} if custom_ops is None else custom_ops
    handles = []

    model.eval()

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return
        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))

        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        hook = _select_hook(m, custom_ops)
        if hook is not None:
            handles.append(m.register_forward_hook(hook))

    model.apply(add_hooks)

    try:
        with torch.no_grad():
            model(input_zero_tensor)
    finally:
        # Without this, every later forward pass (and a second profile() call)
        # would keep accumulating into total_ops.
        for handle in handles:
            handle.remove()

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    return total_ops, total_params


def main(args):
    # model = torch.load(args.model)
    # model = models.resnet18(pretrained=False)
    # model = models.vgg19(pretrained=False)
    model = models.alexnet(pretrained=False)
    dummy_input_tensor = torch.zeros((1, 3, 224, 224))
    total_ops, total_params = profile(model, dummy_input_tensor)
    print("#Ops: %f GOps" % (total_ops / 1e9))
    print("#Parameters: %f M" % (total_params / 1e6))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="pytorch model profiler")
    parser.add_argument("--model", type=str, help="model to profile")
    # parser.add_argument("input_size", nargs='+', type=int,
    #                     help="input size to the network")
    argv = [
        '--model', 'resnet18',
        # 'input_size','244','244','3',
    ]
    args = parser.parse_args(argv)
    main(args)


# --------------------------------------------------------------------------------
# /utils/net_utils.py
# --------------------------------------------------------------------------------
import os
import math
import torch
import shutil
import models
import pathlib
import numpy as np
import torch.nn as nn
from layers import bn_type
from layers import conv_type
from layers import linear_type
import torch.backends.cudnn as cudnn
# from layers.conv_type import FixedSubnetConv


def get_model(args):
    """Instantiate the architecture registered in ``models`` under ``args.arch``."""
    args.logger.info("=> Creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](args)

    return model


def move_model_to_gpu(args, model):
    """Place ``model`` on ``args.gpu``, or wrap it in DataParallel over ``args.multigpu``.

    When neither is set, the model is returned unmoved. In the multi-GPU
    case, ``args.gpu`` is set to the first device id as a side effect.
    """
    assert torch.cuda.is_available(), "CPU-only experiments currently unsupported"
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    elif args.multigpu is not None:
        # DataParallel will divide and allocate batch_size to all available GPUs
        args.logger.info(f"=> Parallelizing on {args.multigpu} gpus")
        torch.cuda.set_device(args.multigpu[0])
        args.gpu = args.multigpu[0]
        model = torch.nn.DataParallel(model, device_ids=args.multigpu).cuda(
            args.multigpu[0]
        )

    cudnn.benchmark = True

    return model


def save_checkpoint(state, is_best, filename="checkpoint.pth", save=False):
    """Save ``state`` to ``filename`` and handle the best-model copy.

    If ``is_best``, the file is copied to ``model_best.pth`` in the same
    directory. The file at ``filename`` is then deleted unless ``save``.
    """
    filename = pathlib.Path(filename)
    filename.parent.mkdir(parents=True, exist_ok=True)

    torch.save(state, filename)

    if is_best:
        shutil.copyfile(filename, str(filename.parent / "model_best.pth"))

    if not save:
        os.remove(filename)


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group."""
    return optimizer.param_groups[0]["lr"]


def extract_slim(split_model, model):
    """Copy the masked (slim) part of each layer of ``model`` into ``split_model``.

    Modules are paired by position in ``named_modules()``; this assumes both
    networks enumerate modules in the same order (TODO confirm).
    """
    for (dst_n, dst_m), (src_n, src_m) in zip(split_model.named_modules(), model.named_modules()):
        if hasattr(src_m, "weight") and src_m.weight is not None:
            if hasattr(src_m, "mask"):
                src_m.extract_slim(dst_m, src_n, dst_n)
            elif src_m.__class__ == bn_type.SplitBatchNorm:  # BatchNorm has bn_mask, not mask
                src_m.extract_slim(dst_m)


def split_reinitialize(cfg, model, reset_hypothesis=False):
    """Re-initialize every masked SplitConv/SplitLinear layer of ``model``.

    When ``reset_hypothesis`` is True, each layer's mask is reset first.

    :raises NotImplementedError: if a masked layer is neither SplitConv nor SplitLinear.
    """
    cfg.logger.info('split_reinitialize')
    # zero_reset = True
    if cfg.evolve_mode == 'zero':
        cfg.logger.info('WARNING: ZERO RESET is not optimal')
    split_classes = (conv_type.SplitConv, linear_type.SplitLinear)
    for n, m in model.named_modules():
        if not (hasattr(m, "weight") and m.weight is not None):
            continue
        if not hasattr(m, "mask"):  # Conv and Linear but not BN
            continue
        assert m.split_rate < 1.0
        is_split_layer = m.__class__ in split_classes

        if reset_hypothesis and is_split_layer:
            before_sum = torch.sum(m.mask)
            m.reset_mask()
            cfg.logger.info('reset_hypothesis : True {} : {} -> {}'.format(n, before_sum, torch.sum(m.mask)))
        else:
            cfg.logger.info('reset_hypothesis : False {} : {}'.format(n, torch.sum(m.mask)))

        if not is_split_layer:
            # `NotImplemented` is a constant, not an exception class: calling it raised TypeError.
            raise NotImplementedError('Invalid layer {}'.format(m.__class__))
        m.split_reinitialize(cfg)


class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.
    """

    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.

        :param smoothing: label smoothing factor
        """
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Return the batch mean of (confidence * NLL + smoothing * mean(-log p)) over the last dim."""
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)

        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


def reset_mask(cfg, model):
    """Call ``reset_mask``/``reset_bias_mask`` on every module exposing ``mask``/``bias_mask``."""
    cfg.logger.info("=> reseting model mask")

    for n, m in model.named_modules():
        if hasattr(m, "mask"):
            cfg.logger.info(f"==> reset {n}.mask")
            m.reset_mask()

        if hasattr(m, "bias_mask"):
            cfg.logger.info(f"==> reset {n}.bias_mask")
            m.reset_bias_mask()


def load_pretrained(pretrained_path, gpus, model, cfg):
    """Load the entries of a checkpoint's ``state_dict`` that match ``model``.

    An entry is loaded only if its key exists in the model with the same
    tensor size. Every other key is logged as IGNORE.
    """
    if os.path.isfile(pretrained_path):
        cfg.logger.info("=> loading pretrained weights from '{}'".format(pretrained_path))
        pretrained = torch.load(
            pretrained_path,
            map_location=torch.device("cuda:{}".format(gpus)),
        )["state_dict"]
        skip = ' '
        # skip = 'mask'
        model_state_dict = model.state_dict()

        compatible = {}
        for k, v in pretrained.items():
            if k in model_state_dict and v.size() == model_state_dict[k].size() and skip not in k:
                compatible[k] = v
            else:
                cfg.logger.info("IGNORE: {}".format(k))
        model_state_dict.update(compatible)
        model.load_state_dict(model_state_dict)

    else:
        cfg.logger.info("=> no pretrained weights found at '{}'".format(pretrained_path))


# --------------------------------------------------------------------------------
# /utils/os_utils.py
# --------------------------------------------------------------------------------
import os
import csv
import errno
import pickle
import numpy as np
from shutil import copyfile


def get_last_part(path):
    """Return the last path component, ignoring a trailing separator."""
    return os.path.basename(os.path.normpath(path))


def copy_file(f, dst, rename=None):
    """Copy file ``f`` into directory ``dst``, creating ``dst`` if needed.

    :param rename: optional new base name; ``f``'s extension is appended.
    :raises FileNotFoundError: if ``f`` does not exist (an ``Exception`` subclass).
    """
    touch_dir(dst)
    if not os.path.exists(f):
        raise FileNotFoundError('File not found')
    if rename is None:
        copyfile(f, os.path.join(dst, get_last_part(f)))
    else:
        _, ext = get_file_name_ext(f)
        copyfile(f, os.path.join(dst, rename + ext))


def copy_files(src_file_lst, dst, rename=None):
    """Copy each file into ``dst``; ``rename[i]`` (if given) renames the i-th file."""
    touch_dir(dst)
    for f_idx, f in enumerate(src_file_lst):
        copy_file(f, dst, None if rename is None else rename[f_idx])


def dataset_tuples(dataset_path):
    return dataset_path + '_tuples_class'


def get_dirs(base_path):
    """Return the sorted names of sub-directories of ``base_path``."""
    return sorted([f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))])


def get_files(base_path, extension, append_base=False):
    """Return sorted non-hidden entries of ``base_path`` ending with ``extension``."""
    files = [f for f in os.listdir(base_path) if f.endswith(extension) and not f.startswith('.')]
    if append_base:
        files = [os.path.join(base_path, f) for f in files]
    return sorted(files)


def csv_read(csv_file, has_header=False):
    """Return all rows of ``csv_file`` as lists of strings, optionally skipping a header."""
    # newline='' as required by the csv module so quoted newlines round-trip.
    with open(csv_file, 'r', newline='') as csvfile:
        file_content = csv.reader(csvfile)
        if has_header:
            next(file_content, None)  # skip the headers
        return list(file_content)


def csv_write(csv_file, rows):
    """Write ``rows`` to ``csv_file``, overwriting it."""
    with open(csv_file, mode='w', newline='') as file:
        rows_writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        rows_writer.writerows(rows)


def txt_read(path):
    """Return the lines of ``path`` with surrounding whitespace stripped."""
    with open(path) as f:
        content = f.readlines()
    return [x.strip() for x in content]


def txt_write(path, lines, mode='w'):
    """Write each item of ``lines`` followed by a newline."""
    with open(path, mode) as out_file:
        for line in lines:
            out_file.write(line)
            out_file.write('\n')


def pkl_write(path, data):
    with open(path, "wb") as f:
        pickle.dump(data, f)


def hot_one_vector(y, max):
    """One-hot encode integer labels ``y`` into an (len(y), max) int32 array."""
    labels_hot_vector = np.zeros((y.shape[0], max), dtype=np.int32)
    labels_hot_vector[np.arange(y.shape[0]), y] = 1
    return labels_hot_vector


def pkl_read(path):
    """Unpickle ``path``; return None if it does not exist.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    if not os.path.exists(path):
        return None

    with open(path, 'rb') as f:
        return pickle.load(f)


def touch_dir(path):
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def touch_file_dir(file_path):
    """Create the parent directory of ``file_path`` if it does not exist."""
    dir_name = os.path.dirname(file_path)
    # A bare file name has no directory part; os.makedirs('') would raise.
    if dir_name and not os.path.exists(dir_name):
        try:
            os.makedirs(dir_name, exist_ok=True)
        except OSError as exc:  # Guard against race condition
            if exc.errno != errno.EEXIST:
                raise


def last_tuple_idx(path):
    """Return the number of non-hidden ``.jpg`` files in ``path``."""
    files = [f for f in os.listdir(path) if (f.endswith('.jpg') and not f.startswith('.'))]
    return len(files)


def get_file_name_ext(inputFilepath):
    """Split the basename of ``inputFilepath`` into (name, extension)."""
    filename_w_ext = os.path.basename(inputFilepath)
    filename, file_extension = os.path.splitext(filename_w_ext)
    return filename, file_extension


def get_latest_file(path, extension=''):
    """Return the full path of the matching file with the newest ctime."""
    files = get_files(path, extension=extension, append_base=True)
    return max(files, key=os.path.getctime)


def dir_empty(path):
    return not os.listdir(path)


def chkpt_exists(path):
    """True if ``path`` holds a non-hidden entry containing '.ckpt' after position 0."""
    return any(f.find('.ckpt') > 0 and not f.startswith('.') for f in os.listdir(path))


def ask_yes_no_question(question):
    """Prompt until the user answers y/yes or n/no (case-insensitive)."""
    print(question + ' [y/n] ')
    while True:
        answer = input().lower()
        if answer in ['y', 'yes']:
            return True
        elif answer in ['n', 'no']:
            return False
        print('Please Enter a valid answer')


def file_size(file):
    return os.path.getsize(file)


# --------------------------------------------------------------------------------
# /utils/path_utils.py
# --------------------------------------------------------------------------------
import os
import pathlib
import constants
import os.path as osp
from datetime import datetime


def get_checkpoint_dir():
    """Return ``<constants.checkpoints_dir>/<name of the current working directory>``."""
    project_name = osp.basename(osp.abspath('./'))
    ckpt_dir = constants.checkpoints_dir
    assert osp.exists(ckpt_dir), ('{} does not exists'.format(ckpt_dir))

    ckpt_dir = f'{ckpt_dir}/{project_name}'
    return ckpt_dir


# Dataset name -> sub-directory of constants.datasets_dir.
_DATASET_DIRS = {
    'CUB200': 'CUB_200_2011',
    'CUB200_RET': 'CUB_200_2011',
    'CARS_RET': 'stanford_cars',
    'stanford': 'Stanford_Online_Products',
    'imagenet': 'imagenet/ILSVRC/Data/CLS-LOC',
    'market': 'Market-1501-v15.09.15',
    'Flower102': 'flower102',
    'Flower102Pytorch': 'flower102',
    'HAM': 'HAM',
    'FCAM': 'FCAM',
    'FCAMD': 'FCAMD',
    'Dog120': 'stanford_dogs',
    'MIT67': 'mit67',
    'MINI_MIT67': 'mit67',
    'Aircraft100': 'aircrafts',
    'Aircraft100Pytorch': 'aircrafts',
    'ImageNet': 'imagenet/ILSVRC/Data/CLS-LOC',
}


def get_datasets_dir(dataset_name):
    """Return the on-disk directory of ``dataset_name``.

    :raises NotImplementedError: for an unknown dataset name.
    """
    datasets_dir = constants.datasets_dir

    assert osp.exists(datasets_dir), ('{} does not exists'.format(datasets_dir))
    try:
        dataset_dir = _DATASET_DIRS[dataset_name]
    except KeyError:
        raise NotImplementedError('Invalid dataset name {}'.format(dataset_name)) from None

    return '{}/{}'.format(datasets_dir, dataset_dir)


def get_directories(args, generation):
    """Create a fresh ``.../gen_<generation>/split_rate=<r>/<NNNN>_g<gpu>`` run directory.

    The first run index whose logs/ and checkpoints/ sub-directories do not
    exist yet is used. ``str(args)`` is written to settings.txt.

    :return: (run_base_dir, ckpt_base_dir, log_base_dir); only run_base_dir is created here.
    """
    # NOTE(review): only raises when BOTH are missing, and only args.name is
    # used below; the commented-out `or` check may have been the intent.
    if args.config_file is None and args.name is None:
        raise ValueError("Must have name and config")

    root = get_checkpoint_dir() if args.log_dir is None else args.log_dir
    run_base_dir = pathlib.Path(
        f"{root}/{args.name}/gen_{generation}/split_rate={args.split_rate}"
    )

    def _run_name(rep):
        return '{:04d}_g{:01d}'.format(rep, args.gpu)

    def _run_dir_exists(candidate):
        return (candidate / "logs").exists() or (candidate / "checkpoints").exists()

    rep_count = 0
    while _run_dir_exists(run_base_dir / _run_name(rep_count)):
        rep_count += 1

    run_base_dir = run_base_dir / _run_name(rep_count)
    log_base_dir = run_base_dir / "logs"
    ckpt_base_dir = run_base_dir / "checkpoints"

    run_base_dir.mkdir(parents=True, exist_ok=True)

    (run_base_dir / "settings.txt").write_text(str(args))

    return run_base_dir, ckpt_base_dir, log_base_dir


if __name__ == '__main__':
    # get_checkpoint_dir takes no arguments, and 'cub' is not a known dataset name.
    print(get_checkpoint_dir())
    print(get_datasets_dir('CUB200'))


# --------------------------------------------------------------------------------
# /utils/schedulers.py
# --------------------------------------------------------------------------------
import numpy as np

__all__ = ["multistep_lr", "cosine_lr", "constant_lr", "get_policy", 'step_lr']


def get_policy(name):
    """Return the LR-schedule factory registered under ``name`` (``constant_lr`` for None).

    :raises KeyError: for an unknown policy name.
    """
    if name is None:
        return constant_lr

    out_dict = {
        "constant_lr": constant_lr,
        "cosine_lr": cosine_lr,
        "multistep_lr": multistep_lr,
        "step_lr": step_lr,
    }

    try:
        return out_dict[name]
    except KeyError:
        raise KeyError(f"Unknown lr policy {name!r}; expected one of {sorted(out_dict)}") from None


def assign_learning_rate(optimizer, new_lr):
    """Set ``lr`` on every param group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr


def constant_lr(optimizer, args, **kwargs):
    """Linear warmup over ``args.warmup_length`` epochs, then a constant ``args.lr``."""
    def _lr_adjuster(epoch, iteration):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.lr, args.warmup_length, epoch)
        else:
            lr = args.lr

        assign_learning_rate(optimizer, lr)

        return lr

    return _lr_adjuster


def cosine_lr(optimizer, args, **kwargs):
    """Linear warmup, then cosine decay from ``args.lr`` toward 0 at ``args.epochs``."""
    def _lr_adjuster(epoch, iteration):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.lr, args.warmup_length, epoch)
        else:
            e = epoch - args.warmup_length
            es = args.epochs - args.warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * args.lr

        assign_learning_rate(optimizer, lr)

        return lr

    return _lr_adjuster


def step_lr(optimizer, cfg, **kwargs):
    """``cfg.lr``, divided by 10 at 50% and again at 75% of ``cfg.epochs``."""
    def _lr_adjuster(epoch, iteration):
        lr = cfg.lr
        if epoch >= 0.5 * cfg.epochs:
            lr /= 10
        if epoch >= 0.75 * cfg.epochs:
            lr /= 10

        assign_learning_rate(optimizer, lr)

        return lr

    return _lr_adjuster


def multistep_lr(optimizer, args, **kwargs):
    """Multiply the initial LR by ``args.lr_gamma`` once every ``args.lr_adjust`` epochs."""

    def _lr_adjuster(epoch, iteration):
        lr = args.lr * (args.lr_gamma ** (epoch // args.lr_adjust))

        assign_learning_rate(optimizer, lr)

        return lr

    return _lr_adjuster


def _warmup_lr(base_lr, warmup_length, epoch):
    """Linear warmup: base_lr * (epoch + 1) / warmup_length."""
    return base_lr * (epoch + 1) / warmup_length

# --------------------------------------------------------------------------------