├── LICENSE ├── README.md ├── commands ├── __init__.py ├── parallel.py ├── test.py └── train.py ├── config.yaml ├── images └── iterand.png ├── main.py ├── models ├── __init__.py ├── networks │ ├── __init__.py │ ├── convs.py │ ├── resnet.py │ └── sparse_modules.py └── supervised_learning.py ├── requirements.txt └── utils ├── __init__.py ├── datasets.py ├── filelock.py ├── info.py ├── output_manager.py ├── pd_logger.py ├── schedulers.py ├── seed.py ├── subset_dataset.py ├── sync_jobs.py └── test_info.py /LICENSE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal 4 | contract between a person who uses or otherwise accesses or installs the 5 | Software (“User(s)”), and Nippon Telegraph and Telephone corporation 6 | ("NTT"). 7 | 8 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE 9 | INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE 10 | ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS 11 | COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO 12 | USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER 13 | ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS 14 | IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY 15 | TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, 16 | USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND 17 | REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER 18 | MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER 19 | AND NTT CONCERNING THE SOFTWARE. 20 | 21 | BACKGROUND 22 | 23 | A. NTT is the owner of all rights, including all patent rights, 24 | copyrights and trade secret rights, in and to the Software and 25 | related documentation listed in Exhibit A to this Agreement. 26 | 27 | B. User wishes to obtain a royalty free license to use the Software to 28 | enable User to evaluate, and NTT wishes to grant such a license to 29 | User, pursuant and subject to the terms and conditions of this 30 | Agreement. 31 | 32 | C. As a condition to NTT's provision of the Software to User, NTT has 33 | required User to execute this Agreement. 34 | 35 | In consideration of these premises, and the mutual promises and 36 | conditions in this Agreement, the parties hereby agree as follows: 37 | 38 | 39 | 40 | 1. GRANT OF EVALUATION LICENSE. NTT HEREBY GRANTS TO USER, AND USER HEREBY ACCEPTS, UNDER THE TERMS AND CONDITIONS OF THIS AGREEMENT, A ROYALTY FREE, NONTRANSFERABLE AND NONEXCLUSIVE LICENSE TO USE THE SOFTWARE INTERNALLY FOR THE PURPOSES OF TESTING, ANALYZING, AND EVALUATING THE METHODS OR MECHANISMS AS SHOWN IN THE RESEARCH PAPER SUBMITTED BY NTT TO A CERTAIN ACADEMY. USER MAY MAKE A REASONABLE NUMBER OF BACKUP COPIES OF THE SOFTWARE SOLELY FOR USER'S INTERNAL USE PURSUANT TO THE LICENSE GRANTED IN THIS SECTION 1. 41 | 42 | 43 | 2. Shipment and Installation. NTT will ship or deliver the Software by 44 | any method that NTT deems appropriate. User shall be solely responsible 45 | for proper installation of the Software. 46 | 47 | 48 | 49 | 3. TERM. THIS AGREEMENT IS EFFECTIVE WHICHEVER IS EARLIER (I) UPON USER’S ACCEPTANCE OF THE AGREEMENT, OR (II) UPON USER’S INSTALLING, ACCESSING, AND USING THE SOFTWARE, EVEN IF USER HAS NOT EXPRESSLY ACCEPTED THIS AGREEMENT. 
WITHOUT PREJUDICE TO ANY OTHER RIGHTS, NTT MAY TERMINATE THIS AGREEMENT WITHOUT NOTICE TO USER (I) IF USER BREACHES OR FAILS TO COMPLY WITH ANY OF THE LIMITATIONS OR OTHER REQUIREMENTS DESCRIBED HEREIN, AND (II) IF NTT RECEIVES A NOTICE FROM THE ACADEMY STATING THAT THE RESEARCH PAPER WOULD NOT BE PUBLISHED, AND IN ANY SUCH CASE USER AGREES THAT NTT MAY, IN ADDITION TO ANY OTHER REMEDIES IT MAY HAVE AT LAW OR IN EQUITY, REMOTELY DISABLE THE SOFTWARE. USER MAY TERMINATE THIS AGREEMENT AT ANY TIME BY USER’S DECISION TO TERMINATE THE AGREEMENT TO NTT AND CEASING USE OF THE SOFTWARE. UPON ANY TERMINATION OR EXPIRATION OF THIS AGREEMENT FOR ANY REASON, USER AGREES TO UNINSTALL THE SOFTWARE AND EITHER RETURN TO NTT THE SOFTWARE AND ALL COPIES THEREOF, OR TO DESTROY ALL SUCH MATERIALS AND PROVIDE WRITTEN VERIFICATION OF SUCH DESTRUCTION TO NTT. 50 | 51 | 52 | 53 | 4. PROPRIETARY RIGHTS 54 | 55 | 56 | (a) The Software is the valuable, confidential, and proprietary property 57 | of NTT, and NTT shall retain exclusive title to this property both 58 | during the term and after the termination of this Agreement. Without 59 | limitation, User acknowledges that all patent rights, copyrights and 60 | trade secret rights in the Software shall remain the exclusive property 61 | of NTT at all times. User shall use not less than reasonable care in 62 | safeguarding the confidentiality of the Software. 63 | 64 | 65 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE. 66 | 67 | 68 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 69 | 70 | 71 | 72 | 5.  INDEMNITY. USER SHALL DEFEND, INDEMNIFY AND HOLD HARMLESS NTT, ITS AGENTS AND EMPLOYEES, FROM ANY LOSS, DAMAGE, OR LIABILITY ARISING IN CONNECTION WITH USER'S IMPROPER OR UNAUTHORIZED USE OF THE SOFTWARE. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 73 | 74 | 75 | 76 | 6. DISCLAIMER. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 77 | 78 | 79 | 80 | 7. LIMITATION OF LIABILITY. 
IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARD­LESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 81 | 82 | 83 | 84 | 8. NO ASSIGNMENT OR SUBLICENSE. NEITHER THIS AGREEMENT NOR ANY RIGHT OR LICENSE UNDER THIS AGREEMENT, NOR THE SOFTWARE, MAY BE SUBLICENSED, ASSIGNED, OR OTHERWISE TRANSFERRED BY USER WITHOUT NTT'S PRIOR WRITTEN CONSENT. 85 | 86 | 87 | 88 | 9. GENERAL 89 | 90 | 91 | 92 | (A) IF ANY PROVISION, OR PART OF A PROVISION, OF THIS AGREEMENT IS OR BECOMES ILLEGAL, UNENFORCEABLE, OR INVALIDATED, BY OPERATION OF LAW OR OTHERWISE, THAT PROVISION OR PART SHALL TO THAT EXTENT BE DEEMED OMITTED, AND THE REMAINDER OF THIS AGREEMENT SHALL REMAIN IN FULL FORCE AND EFFECT. 93 | 94 | 95 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 96 | 97 | 98 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 99 | 100 | 101 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 102 | 103 | 104 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 105 | 106 | (f)   NTT shall not be liable to the User or to any third party for 107 | any delay or failure to perform NTT’s obligation set forth under this 108 | Agreement due to any cause beyond NTT’s reasonable control. 
109 | 110 | EXHIBIT A 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pruning Randomly Initialized Neural Networks with Iterative Randomization 2 | 3 | by Daiki Chijiwa\*, Shin’ya Yamaguchi, Yasutoshi Ida, Kenji Umakoshi, Tomohiro Inoue 4 | 5 | ArXiv: https://arxiv.org/abs/2106.09269 6 | 7 | ![iterand](images/iterand.png) 8 | 9 | ## Requirements 10 | 11 | To install requirements (for Python 3.7 & NVIDIA CUDA 10.2): 12 | 13 | ```setup 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Usage 18 | ``` 19 | python main.py <command> <config> <exp_name> 20 | ``` 21 | 22 | - `<command>` is one of `train`, `test`, and `parallel`. `train` and `test` can be used to train/test a single model, and `parallel` can be used to reproduce our experiments or to search for hyperparameters. 23 | - `<config>` is the filename of a YAML configuration file. In this repository, we provide `config.yaml`. 24 | - `<exp_name>` is one of the experiment keys defined in the `<config>` file. 25 | 26 | 27 | ## Train Single Model 28 | 29 | To train a network, we can simply execute the `train` command with an `<exp_name>` defined in the `<config>` file. 30 | For example, to train ResNet18 on CIFAR-10 with SGD/edge-popup/IteRand, run the following commands: 31 | ``` 32 | python main.py train config.yaml cifar10_resnet18_ku_sgd 33 | ``` 34 | ``` 35 | python main.py train config.yaml cifar10_resnet18_sc_edgepopup 36 | ``` 37 | ``` 38 | python main.py train config.yaml cifar10_resnet18_sc_iterand 39 | ``` 40 | 41 | At the end of training, the program automatically evaluates the model on the test dataset. 42 | 43 | NOTE: This final result should not be used or inspected while searching for hyperparameters. During our research, the final evaluation on the test set was conducted only after the hyperparameters were fixed. 44 | 45 | 46 | 47 | ## Reproduce Experimental Results 48 | 49 | For each method appearing in the figures of our paper, we provide the corresponding experimental setting as a `figure<n>_<setting>` key in `config.yaml`. 50 | We can run the experiments with the `parallel` command: 51 | ``` 52 | python main.py parallel config.yaml figure<n>_<setting> 53 | ``` 54 | 55 | For example, the results for Conv6 w/ SGD on CIFAR-10 (in Figure 2) are obtained by: 56 | ``` 57 | python main.py parallel config.yaml figure2_conv6_ku_sgd 58 | ``` 59 | 60 | After the experiments finish, we can check the final results by: 61 | ``` 62 | python utils/test_info.py __outputs__/figure2_conv6_ku_sgd/ --epoch=99 63 | ``` 64 | 65 | To run a single setting with specific hyperparameters from the `parallel_grid` option of a `figure<n>_<setting>` experiment, 66 | we can use the `train` command with command-line options such as: 67 | ``` 68 | python main.py train config.yaml figure2_resnet18_sc_iterand --model.config_name=resnet18x0.5 --conv_sparsity=0.6 --rerand_freq=300 --rerand_lambda=0.1 --weight_decay=0.0005 --seed=1 69 | ``` 70 | 71 | -------------------------------------------------------------------------------- /commands/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | for module in os.listdir(os.path.dirname(__file__)): 5 | if module == '__init__.py' or module[-3:] != '.py': 6 | continue 7 | importlib.import_module('.'
+ module[:-3], package='commands') 8 | -------------------------------------------------------------------------------- /commands/parallel.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import time 4 | import json 5 | from os.path import join, abspath, exists 6 | from os import makedirs 7 | import pprint 8 | 9 | from commands.train import train 10 | from utils.sync_jobs import JobManager 11 | from utils.filelock import FileLock, Timeout 12 | 13 | pp = pprint.PrettyPrinter(indent=1) 14 | 15 | def _check_job_running(jobman, sync_dir, name): 16 | coordinate_path = abspath(join(sync_dir, 'coordinate.json')) 17 | with open(coordinate_path, 'r+') as f: 18 | body = f.read() 19 | lis = json.loads(body) 20 | for i, (job_name, status, job_id, dic) in enumerate(lis): 21 | if job_name == name: 22 | return jobman.check_alive(job_id) 23 | 24 | def _next_job(jobman, sync_dir, my_id): 25 | coordinate_path = abspath(join(sync_dir, 'coordinate.json')) 26 | lock_path = abspath(join(sync_dir, 'coordinate.json.lock')) 27 | lock = FileLock(lock_path, timeout=1) 28 | 29 | with lock: 30 | if not exists(coordinate_path): 31 | with open(coordinate_path, "w") as f: 32 | f.write('[]') 33 | return None, None 34 | 35 | with lock: 36 | with open(coordinate_path, 'r+') as f: 37 | body = f.read() 38 | lis = json.loads(body) 39 | print('\n[Parallel] All hyperparams:') 40 | pp.pprint(lis) 41 | print('') 42 | for i, (job_name, status, job_id, dic) in enumerate(lis): 43 | if status == "not completed": 44 | if not jobman.check_alive(job_id): 45 | lis[i] = (job_name, status, my_id, dic) 46 | f.seek(0) 47 | f.write(json.dumps(lis)) 48 | f.truncate() 49 | return job_name, dic 50 | elif status == "completed": 51 | pass 52 | else: 53 | raise NotImplementedError 54 | return None, None 55 | 56 | def _update_job(sync_dir, job_name, status, job_id, dic): 57 | coordinate_path = abspath(join(sync_dir, 'coordinate.json')) 58 | assert exists(coordinate_path) 59 | 60 | lock_path = abspath(join(sync_dir, 'coordinate.json.lock')) 61 | lock = FileLock(lock_path, timeout=1) 62 | with lock: 63 | with open(coordinate_path, 'r+') as f: 64 | lis = json.loads(f.read()) 65 | for i, (name, _, _, _) in enumerate(lis): 66 | if name == job_name: 67 | lis[i] = (job_name, status, job_id, dic) 68 | f.seek(0) 69 | f.write(json.dumps(lis)) 70 | f.truncate() 71 | return 72 | lis.append((job_name, status, job_id, dic)) 73 | f.seek(0) 74 | f.write(json.dumps(lis)) 75 | f.truncate() 76 | 77 | def _count_jobs(sync_dir): 78 | coordinate_path = abspath(join(sync_dir, 'coordinate.json')) 79 | assert exists(coordinate_path) 80 | 81 | with open(coordinate_path, 'r') as f: 82 | lis = json.loads(f.read()) 83 | return len(lis) 84 | 85 | def parallel(exp_name, cfg): 86 | sync_dir = cfg['sync_dir'] 87 | sync_dir = abspath(join(sync_dir, exp_name)) 88 | if not exists(sync_dir): 89 | try: 90 | makedirs(sync_dir) 91 | except Exception as e: 92 | print('[Parallel] Caught Exception:', e.args) 93 | 94 | now = time.time() 95 | job_id = str(now) 96 | search_space = cfg['parallel_grid'] 97 | assert search_space is not None 98 | 99 | jobman = JobManager(sync_dir, job_id, sync_script_path="utils/sync_jobs.py") 100 | jobman.clear() 101 | jobman.start() 102 | 103 | completed_dics = [] 104 | job_counter = 0 105 | no_job_counter = 0 106 | parallel_counter = 0 107 | max_jobs = 1 108 | max_no_job = 100 109 | max_parallel = 1 110 | for k, cands in search_space.items(): 111 | max_jobs *= len(cands) 112 | 113 | while True: 114 | 
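# Coordination loop: each worker keeps going until it has completed as many runs as
# there are combinations in `parallel_grid` (job_counter >= max_jobs), or until it
# repeatedly fails to find new work. On every pass it first resumes any "not completed"
# job whose owner process is no longer alive, then randomly samples a new hyperparameter
# combination and trains it unless that combination is already running or already completed.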
if job_counter >= max_jobs: 115 | print("[Parallel] Stop searching because we execute all patterns") 116 | break 117 | if no_job_counter >= max_no_job: 118 | if parallel_counter >= max_parallel: 119 | print("[Parallel] Stop searching because searched for enough time") 120 | break 121 | else: 122 | print("[Parallel] Wait for 6 minutes...") 123 | time.sleep(60 * 6) 124 | 125 | no_job_counter = 0 126 | parallel_counter += 1 127 | 128 | # resume to train existing hyperparameters 129 | while True: 130 | resume_job_name, dic = _next_job(jobman, sync_dir, job_id) 131 | if resume_job_name is not None: 132 | for k in dic: 133 | cfg[k] = dic[k] 134 | print('[Parallel] Resume job:', resume_job_name) 135 | try: 136 | train(exp_name, cfg, prefix=resume_job_name) 137 | _update_job(sync_dir, resume_job_name, 'completed', job_id, dic) 138 | 139 | completed_dics.append(dic.copy()) 140 | job_counter += 1 141 | no_job_counter = 0 142 | parallel_counter = 0 143 | except: 144 | time.sleep(10) 145 | else: 146 | break 147 | 148 | # search for new hyperparameters 149 | print("[Parallel] Search new hyperparamter...") 150 | random.seed() 151 | sampled_dic = {} 152 | for k, cands in search_space.items(): 153 | sampled_dic[k] = random.sample(cands,k=1)[0] 154 | 155 | # train with new hyperparameters 156 | job_name = "" 157 | for k, v in sampled_dic.items(): 158 | job_name += f"{k}_{v}--" 159 | cfg[k] = v 160 | 161 | if _check_job_running(jobman, sync_dir, job_name): 162 | no_job_counter += 1 163 | elif sampled_dic not in completed_dics: 164 | _update_job(sync_dir, job_name, 'not completed', job_id, sampled_dic) 165 | print('[Parallel] Start job:', job_name) 166 | train(exp_name, cfg, prefix=job_name) 167 | _update_job(sync_dir, job_name, 'completed', job_id, sampled_dic) 168 | 169 | completed_dics.append(sampled_dic.copy()) 170 | job_counter += 1 171 | no_job_counter = 0 172 | parallel_counter = 0 173 | else: 174 | no_job_counter += 1 175 | 176 | jobman.stop() 177 | 178 | -------------------------------------------------------------------------------- /commands/test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import datetime 5 | import json 6 | 7 | import torch 8 | from utils.seed import set_random_seed 9 | from utils.output_manager import OutputManager 10 | from utils.pd_logger import PDLogger 11 | from torch.nn import DataParallel 12 | 13 | from models.supervised_learning import SupervisedLearning 14 | 15 | def test(exp_name, cfg, prefix="", epoch=None, use_best=False): 16 | set_random_seed(cfg['seed']) 17 | device = torch.device('cuda:0' if cfg['use_cuda'] and torch.cuda.is_available() else 'cpu') 18 | outman = OutputManager(cfg['output_dir'], exp_name) 19 | 20 | outman.print('Number of available gpus: ', torch.cuda.device_count(), prefix=prefix) 21 | 22 | if cfg['learning_framework'] == 'SupervisedLearning': 23 | learner = SupervisedLearning(outman, cfg, device, cfg['data_parallel']) 24 | else: 25 | raise NotImplementedError 26 | 27 | 28 | if use_best: 29 | dump_path = outman.get_abspath(prefix=f"best.{prefix}", ext="pth") 30 | elif epoch is not None: 31 | dump_path = outman.get_abspath(prefix=f'epoch{epoch}.{prefix}', ext="pth") 32 | else: 33 | dump_path = outman.get_abspath(prefix=f"dump.{prefix}", ext="pth") 34 | 35 | outman.print(dump_path, prefix=prefix) 36 | if os.path.exists(dump_path): 37 | try: 38 | dump_dict = torch.load(dump_path) 39 | epoch = dump_dict['epoch'] 40 | if isinstance(learner.model, DataParallel): 41 | 
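# The checkpoint selected above (best.*, epoch<k>.*, or dump.*) is loaded here; when the
# model is wrapped in DataParallel, its parameters live in .module, so the state dict is
# loaded into the wrapped module, otherwise into the model directly.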
learner.model.module.load_state_dict(dump_dict['model_state_dict']) 42 | else: 43 | learner.model.load_state_dict(dump_dict['model_state_dict']) 44 | except Exception as e: 45 | print("[train.py] catched unexpected error in loading checkpoint:", str(e)) 46 | print("[train.py] start training from scratch") 47 | else: 48 | raise Exception 49 | 50 | outman.print('[', str(datetime.datetime.now()) , '] Evaluate on Test Dataset...' , prefix=prefix) 51 | 52 | # Test 53 | result = learner.evaluate(dataset_type='test') 54 | if use_best: 55 | outman.print('Test Accuracy (Best):', str(result['accuracy']), prefix=prefix) 56 | else: 57 | outman.print('Test Accuracy:', str(result['accuracy']), prefix=prefix) 58 | 59 | test_info_dict = { 60 | 'accuracy': result['accuracy'], 61 | 'epoch': epoch, 62 | 'loss': result['loss'], 63 | 'prefix': prefix, 64 | } 65 | 66 | if use_best: 67 | output_path = outman.get_abspath(prefix=f"test_best.{prefix}", ext="json") 68 | else: 69 | output_path = outman.get_abspath(prefix=f"test_epoch{epoch}.{prefix}", ext="json") 70 | 71 | with open(output_path, 'w') as f: 72 | json.dump(test_info_dict, f, indent=2) 73 | 74 | -------------------------------------------------------------------------------- /commands/train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import datetime 5 | import json 6 | 7 | import torch 8 | from utils.seed import set_random_seed 9 | from utils.output_manager import OutputManager 10 | from utils.pd_logger import PDLogger 11 | from torch.nn import DataParallel 12 | from commands.test import test 13 | 14 | from models.supervised_learning import SupervisedLearning 15 | 16 | def count_params(model): 17 | count = 0 18 | count_not_score = 0 19 | count_reduced = 0 20 | for (n,p) in model.named_parameters(): 21 | count += p.flatten().size(0) 22 | if hasattr(p, 'is_score') and p.is_score: 23 | print(n+':', int(p.flatten().size(0) * (1.0 - p.sparsity)), '/', p.flatten().size(0), '(sparsity =', p.sparsity,')') 24 | count_reduced += int(p.flatten().size(0) * p.sparsity) 25 | else: 26 | print(n+':',p.flatten().size(0)) 27 | count_not_score += p.flatten().size(0) 28 | count_after_pruning = count_not_score - count_reduced 29 | total_sparsity = 1 - (count_after_pruning / count_not_score) 30 | print('Params after/before pruned:\t', count_after_pruning, '/', count_not_score, '(sparsity: ' + str(total_sparsity) +')') 31 | print('Total Params:\t', count) 32 | return { 33 | 'params_after_pruned': count_after_pruning, 34 | 'params_before_pruned': count_not_score, 35 | 'total_params': count, 36 | 'sparsity': total_sparsity, 37 | } 38 | 39 | def train(exp_name, cfg, prefix=""): 40 | if cfg['seed'] is not None: 41 | set_random_seed(cfg['seed']) 42 | elif cfg['seed_by_time']: 43 | set_random_seed(int(time.time() * 1000) % 1000000) 44 | else: 45 | raise Exception("Set seed value.") 46 | device = torch.device('cuda:0' if cfg['use_cuda'] and torch.cuda.is_available() else 'cpu') 47 | outman = OutputManager(cfg['output_dir'], exp_name) 48 | 49 | dump_path = outman.get_abspath(prefix=f"dump.{prefix}", ext="pth") 50 | 51 | outman.print('Number of available gpus: ', torch.cuda.device_count(), prefix=prefix) 52 | 53 | pd_logger = PDLogger() 54 | pd_logger.set_filename(outman.get_abspath(prefix=f"pd_log.{prefix}", ext="pickle")) 55 | if os.path.exists(pd_logger.filename) and not cfg['force_restart']: 56 | pd_logger.load() 57 | 58 | if cfg['learning_framework'] == 'SupervisedLearning': 59 | learner = 
SupervisedLearning(outman, cfg, device, cfg['data_parallel']) 60 | else: 61 | raise NotImplementedError 62 | 63 | params_info = count_params(learner.model) 64 | 65 | best_value = None 66 | best_epoch = 0 67 | start_epoch = 0 68 | total_iters = 0 69 | total_seconds = 0. 70 | 71 | outman.print(dump_path, prefix=prefix) 72 | if os.path.exists(dump_path) and not cfg['force_restart']: 73 | try: 74 | dump_dict = torch.load(dump_path) 75 | start_epoch = dump_dict['epoch'] + 1 76 | total_iters = dump_dict['total_iters'] 77 | best_value = dump_dict['best_val'] 78 | best_epoch = dump_dict['best_epoch'] if 'best_epoch' in dump_dict else 0 79 | total_seconds = dump_dict['total_seconds'] if 'total_seconds' in dump_dict else 0. 80 | if isinstance(learner.model, DataParallel): 81 | learner.model.module.load_state_dict(dump_dict['model_state_dict']) 82 | else: 83 | learner.model.load_state_dict(dump_dict['model_state_dict']) 84 | learner.optimizer.load_state_dict(dump_dict['optim_state_dict']) 85 | if 'sched_state_dict' in dump_dict: 86 | learner.scheduler.load_state_dict(dump_dict['sched_state_dict']) 87 | except Exception as e: 88 | print("[train.py] catched unexpected error in loading checkpoint:", str(e)) 89 | print("[train.py] start training from scratch") 90 | elif cfg['load_checkpoint_path'] is not None: 91 | assert not os.path.exists(dump_path) 92 | assert os.path.exists(cfg['load_checkpoint_path']) 93 | try: 94 | checkpoint_dict = torch.load(cfg['load_checkpoint_path']) 95 | if isinstance(learner.model, DataParallel): 96 | learner.model.module.load_state_dict(checkpoint_dict['model_state_dict']) 97 | else: 98 | learner.model.load_state_dict(checkpoint_dict['model_state_dict']) 99 | #learner.optimizer.load_state_dict(checkpoint_dict['optim_state_dict']) 100 | #if 'sched_state_dict' in checkpoint_dict: 101 | # learner.scheduler.load_state_dict(checkpoint_dict['sched_state_dict']) 102 | except Exception as e: 103 | print("[train.py] catched unexpected error in loading checkpoint:", str(e)) 104 | print("[train.py] start training from scratch") 105 | 106 | # Define re-randomize callback 107 | if cfg['rerand_mode'] is not None and cfg['rerand_freq'] > 0: 108 | if cfg['rerand_freq_unit'] == 'epoch': 109 | def rerand_callback(model, epoch, it, iters_per_epoch): 110 | real_model = model.module if isinstance(model, DataParallel) else model 111 | if (it + 1) % int(iters_per_epoch / cfg['rerand_freq']) == 0: 112 | outman.print(f'[Train] rerandomized@{it}', prefix=prefix) 113 | real_model.rerandomize(cfg['rerand_mode'], cfg['rerand_lambda'], cfg['rerand_mu']) 114 | else: 115 | pass 116 | elif cfg['rerand_freq_unit'] == 'iteration': 117 | def rerand_callback(model, epoch, it, iters_per_epoch): 118 | real_model = model.module if isinstance(model, DataParallel) else model 119 | if (it + 1) % int(cfg['rerand_freq']) == 0: 120 | real_model.rerandomize(cfg['rerand_mode'], cfg['rerand_lambda'], cfg['rerand_mu']) 121 | else: 122 | pass 123 | if (it + 1) % iters_per_epoch == 0: 124 | outman.print(f'[Train] rerandomized per', cfg['rerand_freq'], 'iterations') 125 | else: 126 | raise NotImplementedError 127 | else: 128 | def rerand_callback(model, epoch, it, iters_per_epoch): 129 | pass 130 | 131 | # Training loop 132 | for epoch in range(start_epoch, cfg['epoch']): 133 | start_sec = time.time() 134 | 135 | outman.print('[', str(datetime.datetime.now()) , '] Epoch: ', str(epoch), prefix=prefix) 136 | 137 | # Train 138 | results_train = learner.train(epoch, total_iters, before_callback=rerand_callback) 139 | 
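# learner.train() runs one epoch and invokes rerand_callback around each iteration; with
# rerand_mode="bernoulli", every rerand_freq iterations the weights currently pruned by the
# score mask are re-drawn from the initialization distribution with probability
# rerand_lambda (see SparseModule.rerandomize_ in models/networks/sparse_modules.py).
# The returned dict carries the per-epoch metrics unpacked below.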
train_accuracy = results_train['moving_accuracy'] 140 | results_per_iter = results_train['per_iteration'] 141 | new_total_iters = results_train['iterations'] 142 | total_loss_train = results_train['loss'] 143 | 144 | pd_logger.add('train_accs', [train_accuracy], index=[epoch], columns=['train-acc']) 145 | outman.print('Train Accuracy:', str(train_accuracy), prefix=prefix) 146 | if cfg['print_train_loss']: 147 | outman.print('Train Loss:', str(total_loss_train), prefix=prefix) 148 | 149 | # Evaluate 150 | results_eval = learner.evaluate() 151 | val_accuracy = results_eval['accuracy'] 152 | pd_logger.add('val_accs', [val_accuracy], index=[epoch], columns=['val-acc']) 153 | outman.print('Val Accuracy:', str(val_accuracy), prefix=prefix) 154 | 155 | # Save train losses per iteration 156 | losses = [res['mean_loss'] for res in results_per_iter] 157 | index = list(range(total_iters, new_total_iters)) 158 | pd_logger.add('train_losses', losses, index=index) 159 | # Update total_iters 160 | total_iters = new_total_iters 161 | 162 | # Flag if save best model 163 | if (best_value is None) or (best_value < val_accuracy): 164 | best_value = val_accuracy 165 | best_epoch = epoch 166 | save_best_model = True 167 | else: 168 | save_best_model = False 169 | 170 | end_sec = time.time() 171 | total_seconds += end_sec - start_sec 172 | if isinstance(learner.model, DataParallel): 173 | model_state_dict = learner.model.module.state_dict() 174 | else: 175 | model_state_dict = learner.model.state_dict() 176 | dump_dict = { 177 | 'epoch': epoch, 178 | 'model_state_dict': model_state_dict, 179 | 'optim_state_dict': learner.optimizer.state_dict(), 180 | 'sched_state_dict': learner.scheduler.state_dict(), 181 | 'best_val': best_value, 182 | 'best_epoch': best_epoch, 183 | 'total_iters': total_iters, 184 | 'total_seconds': total_seconds, 185 | } 186 | info_dict = { 187 | 'last_val': val_accuracy, 188 | 'epoch': epoch, 189 | 'best_val': best_value, 190 | 'best_epoch': best_epoch, 191 | 'loss_train': total_loss_train, 192 | 'acc_train': train_accuracy, 193 | 'total_time': str(datetime.timedelta(seconds=int(total_seconds))), 194 | 'total_seconds': total_seconds, 195 | 'prefix': prefix, 196 | 'params_info': params_info, 197 | } 198 | outman.save_dict(dump_dict, prefix=f"dump.{prefix}", ext="pth") 199 | with open(outman.get_abspath(prefix=f"info.{prefix}", ext="json"), 'w') as f: 200 | json.dump(info_dict, f, indent=2) 201 | if save_best_model and cfg['save_best_model']: 202 | outman.save_dict(dump_dict, prefix=f"best.{prefix}", ext="pth") 203 | if epoch in cfg['checkpoint_epochs']: 204 | outman.save_dict(dump_dict, prefix=f'epoch{epoch}.{prefix}', ext='pth') 205 | 206 | pd_logger.save() 207 | 208 | if start_epoch + 1 <= cfg['epoch']: 209 | test(exp_name, cfg, prefix=prefix) 210 | 211 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | 2 | ######################################################## 3 | # General settings 4 | ######################################################## 5 | 6 | # -- Definitions of datasets -- 7 | 8 | cifar10: 9 | class: 'CIFAR10' 10 | data_type: 'image' 11 | num_channels: 3 12 | image_size: 32 13 | num_classes: 10 14 | train_val_split: 0.1 15 | 16 | imagenet: 17 | class: 'ImageNet' 18 | data_type: 'image' 19 | num_channels: 3 20 | image_size: 224 21 | num_classes: 1000 22 | train_val_split: 0.01 23 | 24 | # -- Definitions of networks -- 25 | 26 | conv6: 27 | class: 
'Conv6' 28 | factor: 1.0 29 | conv6x0.25: 30 | class: 'Conv6' 31 | factor: 0.25 32 | conv6x0.5: 33 | class: 'Conv6' 34 | factor: 0.5 35 | conv6x1.0: 36 | class: 'Conv6' 37 | factor: 1.0 38 | conv6x2.0: 39 | class: 'Conv6' 40 | factor: 2.0 41 | 42 | resnet18: &__resnet18__ 43 | class: 'ResNet' 44 | block_class: 'BasicBlock' 45 | num_blocks: [2, 2, 2, 2] 46 | factor: 1.0 47 | resnet18x0.25: 48 | <<: *__resnet18__ 49 | factor: 0.25 50 | resnet18x0.5: 51 | <<: *__resnet18__ 52 | factor: 0.5 53 | resnet18x1.0: 54 | <<: *__resnet18__ 55 | factor: 1.0 56 | resnet18x2.0: 57 | <<: *__resnet18__ 58 | factor: 2.0 59 | 60 | resnet34: &__resnet34__ 61 | class: 'ResNet' 62 | block_class: 'BasicBlock' 63 | num_blocks: [3, 4, 6, 3] 64 | factor: 1.0 65 | resnet34x0.25: 66 | <<: *__resnet34__ 67 | factor: 0.25 68 | resnet34x0.5: 69 | <<: *__resnet34__ 70 | factor: 0.5 71 | resnet34x1.0: 72 | <<: *__resnet34__ 73 | factor: 1.0 74 | resnet34x2.0: 75 | <<: *__resnet34__ 76 | factor: 2.0 77 | 78 | resnet50: 79 | class: 'ResNet' 80 | block_class: 'Bottleneck' 81 | num_blocks: [3, 4, 6, 3] 82 | factor: 1.0 83 | 84 | resnet101: 85 | class: 'ResNet' 86 | block_class: 'Bottleneck' 87 | num_blocks: [3, 4, 23, 3] 88 | factor: 1.0 89 | 90 | 91 | # -- all options -- 92 | 93 | __default__: &__default__ 94 | 95 | # General Setting 96 | num_workers: 4 97 | use_cuda: True 98 | output_dir: '__outputs__' 99 | dataset_dir: '__data__' 100 | sync_dir: '__sync__' 101 | checkpoint_epochs: [] 102 | seed: null 103 | seed_by_time: false 104 | dataset_download: true 105 | num_gpus: 1 106 | debug_max_iters: null 107 | load_checkpoint_path: null 108 | 109 | dataset.config_name: null 110 | model.config_name: null 111 | save_best_model: true 112 | print_train_loss: false 113 | 114 | # Hyperparameters for Training 115 | epoch: null 116 | optimizer: "SGD" 117 | lr: null 118 | weight_decay: null 119 | lr_scheduler: null 120 | warmup_epochs: 0 121 | finetuning_epochs: 0 122 | finetuning_lr: null 123 | sgd_momentum: 0.9 124 | lr_milestones: null 125 | multisteplr_gamma: 0.1 126 | padding_before_crop: False 127 | 128 | learning_framework: "SupervisedLearning" 129 | train_mode: 'normal' 130 | 131 | batch_size: 128 132 | batch_size_eval: 512 133 | max_train_dataset_size: null 134 | bn_track_running_stats: True 135 | bn_affine: True 136 | bn_momentum: 0.1 137 | 138 | # Hyperparameters for edge-popup 139 | conv_sparsity: 0.0 140 | linear_sparsity: null 141 | init_mode: 'kaiming_uniform' 142 | init_mode_linear: null 143 | init_mode_mask: 'kaiming_uniform' 144 | init_scale: 1.0 145 | init_scale_score: 1.0 146 | 147 | # Hyperparameters for IteRand 148 | rerand_mode: null 149 | rerand_freq: 1 150 | rerand_freq_unit: "iteration" 151 | rerand_lambda: null 152 | rerand_mu: null 153 | rerand_rate: 1.0 154 | 155 | # Hyperparameter Search Setting 156 | parallel_grid: null 157 | 158 | train_augmentation: True 159 | 160 | 161 | ######################################################## 162 | # Default settings for training on CIFAR-10 and ImageNet 163 | ######################################################## 164 | 165 | cifar10_sgd: &cifar10_sgd 166 | <<: *__default__ 167 | dataset.config_name: 'cifar10' 168 | padding_before_crop: True 169 | 170 | epoch: 100 171 | batch_size: 128 172 | optimizer: "SGD" 173 | lr_scheduler: "CustomCosineLR" 174 | warmup_epochs: 0 175 | 176 | # Override these options 177 | lr: 0.1 178 | model.config_name: null 179 | weight_decay: null # 0.0001 for convs, 0.0005 for resnets 180 | 181 | imagenet_sgd: &imagenet_sgd 182 | <<: 
*__default__ 183 | dataset.config_name: 'imagenet' 184 | model.config_name: null 185 | num_workers: 8 186 | 187 | epoch: 105 188 | batch_size: 128 189 | sgd_momentum: 0.9 190 | weight_decay: 0.0001 191 | optimizer: "SGD" 192 | 193 | lr: 0.1 194 | lr_scheduler: "CustomCosineLR" 195 | warmup_epochs: 5 196 | finetuning_epochs: 5 197 | finetuning_lr: 0.00001 198 | 199 | 200 | ######################################################## 201 | # Training Examples 202 | ######################################################## 203 | 204 | # standard training of ResNet18 on CIFAR-10 205 | cifar10_resnet18_ku_sgd: 206 | <<: *cifar10_sgd 207 | init_mode: 'kaiming_uniform' 208 | model.config_name: "resnet18" 209 | weight_decay: 0.0005 210 | seed: 1 211 | 212 | # edge-popup (SC dist, p=0.6) for ResNet18 on CIFAR-10 213 | cifar10_resnet18_sc_edgepopup: 214 | <<: *cifar10_sgd 215 | init_mode: 'signed_constant' 216 | model.config_name: "resnet18" 217 | weight_decay: 0.0005 218 | seed: 1 219 | 220 | train_mode: "score_only" 221 | conv_sparsity: 0.6 222 | bn_affine: False 223 | 224 | rerand_mode: null 225 | 226 | # IteRand (SC dist, p=0.6, K_per=300, r=0.1) for ResNet18 on CIFAR-10 227 | cifar10_resnet18_sc_iterand: 228 | <<: *cifar10_sgd 229 | init_mode: 'signed_constant' 230 | model.config_name: "resnet18" 231 | weight_decay: 0.0005 232 | seed: 1 233 | 234 | train_mode: "score_only" 235 | conv_sparsity: 0.6 236 | bn_affine: False 237 | 238 | rerand_mode: "bernoulli" 239 | rerand_freq: 300 240 | rerand_lambda: 0.1 241 | 242 | 243 | ######################################################## 244 | # Paper settings for CIFAR-10 experiments 245 | ######################################################## 246 | 247 | # -------------------------- 248 | # Figure 1 experiments 249 | # -------------------------- 250 | 251 | figure1_conv6_ku_iterand: 252 | <<: *cifar10_sgd 253 | init_mode: 'kaiming_uniform' 254 | train_mode: "score_only" 255 | rerand_mode: "bernoulli" 256 | rerand_freq: null 257 | 258 | bn_affine: False 259 | parallel_grid: 260 | model.config_name: ["conv6x1.0"] 261 | conv_sparsity: [0.5] 262 | rerand_freq: [300] 263 | rerand_lambda: [0.0, 0.1, 0.01, 1.0] 264 | lr: [0.3] 265 | weight_decay: [0.0001] 266 | seed: [1,2,3] 267 | figure1_resnet18_ku_iterand: 268 | <<: *cifar10_sgd 269 | init_mode: 'kaiming_uniform' 270 | train_mode: "score_only" 271 | rerand_mode: "bernoulli" 272 | rerand_freq: null 273 | 274 | bn_affine: False 275 | parallel_grid: 276 | model.config_name: ["resnet18x1.0"] 277 | conv_sparsity: [0.6] 278 | rerand_freq: [300] 279 | rerand_lambda: [0.0, 0.1, 0.01, 1.0] 280 | weight_decay: [0.0005] 281 | seed: [1,2,3] 282 | figure1_resnet34_ku_iterand: 283 | <<: *cifar10_sgd 284 | init_mode: 'kaiming_uniform' 285 | train_mode: "score_only" 286 | rerand_mode: "bernoulli" 287 | rerand_freq: null 288 | 289 | bn_affine: False 290 | parallel_grid: 291 | model.config_name: ["resnet34x1.0"] 292 | conv_sparsity: [0.6] 293 | rerand_freq: [300] 294 | rerand_lambda: [0.0, 0.1, 0.01, 1.0] 295 | weight_decay: [0.0005] 296 | seed: [1,2,3] 297 | 298 | # -------------------------- 299 | # Figure 2 experiments 300 | # -------------------------- 301 | 302 | # -- Conv6 -- 303 | 304 | figure2_conv6_ku_sgd: 305 | <<: *cifar10_sgd 306 | init_mode: 'kaiming_uniform' 307 | parallel_grid: 308 | model.config_name: ["conv6x0.25", "conv6x0.5", "conv6x1.0", "conv6x2.0"] 309 | lr: [0.01] # This lr is only for sgd (not for edge-popup) 310 | weight_decay: [0.0001] 311 | seed: [1,2,3] 312 | figure2_conv6_ku_edgepopup: 313 | <<: 
*cifar10_sgd 314 | init_mode: 'kaiming_uniform' 315 | train_mode: "score_only" 316 | rerand_mode: "bernoulli" 317 | rerand_freq: null 318 | 319 | bn_affine: False 320 | parallel_grid: 321 | model.config_name: ["conv6x0.25", "conv6x0.5", "conv6x1.0", "conv6x2.0"] 322 | conv_sparsity: [0.5] 323 | rerand_freq: [300] 324 | rerand_lambda: [0.0] 325 | lr: [0.3] 326 | weight_decay: [0.0001] 327 | seed: [1,2,3] 328 | figure2_conv6_sc_edgepopup: 329 | <<: *cifar10_sgd 330 | init_mode: 'signed_constant' 331 | train_mode: "score_only" 332 | rerand_mode: "bernoulli" 333 | rerand_freq: null 334 | 335 | bn_affine: False 336 | parallel_grid: 337 | model.config_name: ["conv6x0.25", "conv6x0.5", "conv6x1.0", "conv6x2.0"] 338 | conv_sparsity: [0.5] 339 | rerand_freq: [300] 340 | rerand_lambda: [0.0] 341 | lr: [0.3] 342 | weight_decay: [0.0001] 343 | seed: [1,2,3] 344 | figure2_conv6_ku_iterand: 345 | <<: *cifar10_sgd 346 | init_mode: 'kaiming_uniform' 347 | train_mode: "score_only" 348 | rerand_mode: "bernoulli" 349 | rerand_freq: null 350 | 351 | bn_affine: False 352 | parallel_grid: 353 | model.config_name: ["conv6x0.25", "conv6x0.5", "conv6x1.0", "conv6x2.0"] 354 | conv_sparsity: [0.5] 355 | rerand_freq: [300] 356 | rerand_lambda: [0.1] 357 | lr: [0.3] 358 | weight_decay: [0.0001] 359 | seed: [1,2,3] 360 | figure2_conv6_sc_iterand: 361 | <<: *cifar10_sgd 362 | init_mode: 'signed_constant' 363 | train_mode: "score_only" 364 | rerand_mode: "bernoulli" 365 | rerand_freq: null 366 | 367 | bn_affine: False 368 | parallel_grid: 369 | model.config_name: ["conv6x0.25", "conv6x0.5", "conv6x1.0", "conv6x2.0"] 370 | conv_sparsity: [0.5] 371 | rerand_freq: [300] 372 | rerand_lambda: [0.1] 373 | lr: [0.3] 374 | weight_decay: [0.0001] 375 | seed: [1,2,3] 376 | 377 | # -- ResNet18 -- 378 | 379 | figure2_resnet18_ku_sgd: 380 | <<: *cifar10_sgd 381 | init_mode: 'kaiming_uniform' 382 | parallel_grid: 383 | model.config_name: ["resnet18x0.25", "resnet18x0.5", "resnet18x1.0", "resnet18x2.0"] 384 | weight_decay: [0.0005] 385 | seed: [1,2,3] 386 | figure2_resnet18_ku_edgepopup: 387 | <<: *cifar10_sgd 388 | init_mode: 'kaiming_uniform' 389 | train_mode: "score_only" 390 | rerand_mode: "bernoulli" 391 | rerand_freq: null 392 | 393 | bn_affine: False 394 | parallel_grid: 395 | model.config_name: ["resnet18x0.25", "resnet18x0.5", "resnet18x1.0", "resnet18x2.0"] 396 | conv_sparsity: [0.6] 397 | rerand_freq: [300] 398 | rerand_lambda: [0.0] 399 | weight_decay: [0.0005] 400 | seed: [1,2,3] 401 | figure2_resnet18_sc_edgepopup: 402 | <<: *cifar10_sgd 403 | init_mode: 'signed_constant' 404 | train_mode: "score_only" 405 | rerand_mode: "bernoulli" 406 | rerand_freq: null 407 | 408 | bn_affine: False 409 | parallel_grid: 410 | model.config_name: ["resnet18x0.25", "resnet18x0.5", "resnet18x1.0", "resnet18x2.0"] 411 | conv_sparsity: [0.6] 412 | rerand_freq: [300] 413 | rerand_lambda: [0.0] 414 | weight_decay: [0.0005] 415 | seed: [1,2,3] 416 | figure2_resnet18_ku_iterand: 417 | <<: *cifar10_sgd 418 | init_mode: 'kaiming_uniform' 419 | train_mode: "score_only" 420 | rerand_mode: "bernoulli" 421 | rerand_freq: null 422 | 423 | bn_affine: False 424 | parallel_grid: 425 | model.config_name: ["resnet18x0.25", "resnet18x0.5", "resnet18x1.0", "resnet18x2.0"] 426 | conv_sparsity: [0.6] 427 | rerand_freq: [300] 428 | rerand_lambda: [0.1] 429 | weight_decay: [0.0005] 430 | seed: [1,2,3] 431 | figure2_resnet18_sc_iterand: 432 | <<: *cifar10_sgd 433 | init_mode: 'signed_constant' 434 | train_mode: "score_only" 435 | rerand_mode: "bernoulli" 436 | 
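# rerand_freq is left null here and, together with rerand_lambda, is supplied per run
# from the parallel_grid below (rerand_freq: 300, rerand_lambda: 0.1).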
rerand_freq: null 437 | 438 | bn_affine: False 439 | parallel_grid: 440 | model.config_name: ["resnet18x0.25", "resnet18x0.5", "resnet18x1.0", "resnet18x2.0"] 441 | conv_sparsity: [0.6] 442 | rerand_freq: [300] 443 | rerand_lambda: [0.1] 444 | weight_decay: [0.0005] 445 | seed: [1,2,3] 446 | 447 | # -- ResNet34 -- 448 | 449 | figure2_resnet34_ku_sgd: 450 | <<: *cifar10_sgd 451 | init_mode: 'kaiming_uniform' 452 | parallel_grid: 453 | model.config_name: ["resnet34x0.25", "resnet34x0.5", "resnet34x1.0", "resnet34x2.0"] 454 | weight_decay: [0.0005] 455 | seed: [1,2,3] 456 | figure2_resnet34_ku_edgepopup: 457 | <<: *cifar10_sgd 458 | init_mode: 'kaiming_uniform' 459 | train_mode: "score_only" 460 | rerand_mode: "bernoulli" 461 | rerand_freq: null 462 | 463 | bn_affine: False 464 | parallel_grid: 465 | model.config_name: ["resnet34x0.25", "resnet34x0.5", "resnet34x1.0", "resnet34x2.0"] 466 | conv_sparsity: [0.6] 467 | rerand_freq: [300] 468 | rerand_lambda: [0.0] 469 | weight_decay: [0.0005] 470 | seed: [1,2,3] 471 | figure2_resnet34_sc_edgepopup: 472 | <<: *cifar10_sgd 473 | init_mode: 'signed_constant' 474 | train_mode: "score_only" 475 | rerand_mode: "bernoulli" 476 | rerand_freq: null 477 | 478 | bn_affine: False 479 | parallel_grid: 480 | model.config_name: ["resnet34x0.25", "resnet34x0.5", "resnet34x1.0", "resnet34x2.0"] 481 | conv_sparsity: [0.6] 482 | rerand_freq: [300] 483 | rerand_lambda: [0.0] 484 | weight_decay: [0.0005] 485 | seed: [1,2,3] 486 | figure2_resnet34_ku_iterand: 487 | <<: *cifar10_sgd 488 | init_mode: 'kaiming_uniform' 489 | train_mode: "score_only" 490 | rerand_mode: "bernoulli" 491 | rerand_freq: null 492 | 493 | bn_affine: False 494 | parallel_grid: 495 | model.config_name: ["resnet34x0.25", "resnet34x0.5", "resnet34x1.0", "resnet34x2.0"] 496 | conv_sparsity: [0.6] 497 | rerand_freq: [300] 498 | rerand_lambda: [0.1] 499 | weight_decay: [0.0005] 500 | seed: [1,2,3] 501 | figure2_resnet34_sc_iterand: 502 | <<: *cifar10_sgd 503 | init_mode: 'signed_constant' 504 | train_mode: "score_only" 505 | rerand_mode: "bernoulli" 506 | rerand_freq: null 507 | 508 | bn_affine: False 509 | parallel_grid: 510 | model.config_name: ["resnet34x0.25", "resnet34x0.5", "resnet34x1.0", "resnet34x2.0"] 511 | conv_sparsity: [0.6] 512 | rerand_freq: [300] 513 | rerand_lambda: [0.1] 514 | weight_decay: [0.0005] 515 | seed: [1,2,3] 516 | 517 | 518 | ######################################################## 519 | # Paper settings for ImageNet experiments 520 | ######################################################## 521 | 522 | # -------------------------- 523 | # Figure 3 experiments 524 | # -------------------------- 525 | 526 | figure3_resnet18_ku_sgd: 527 | <<: *imagenet_sgd 528 | num_gpus: 1 529 | model.config_name: "resnet18" 530 | 531 | parallel_grid: 532 | seed: [1] 533 | figure3_resnet34_ku_sgd: 534 | <<: *imagenet_sgd 535 | num_gpus: 1 536 | model.config_name: "resnet34" 537 | 538 | parallel_grid: 539 | seed: [1] 540 | figure3_resnet34_sc_edgepopup: 541 | <<: *imagenet_sgd 542 | num_gpus: 1 543 | 544 | init_mode: 'signed_constant' 545 | train_mode: "score_only" 546 | bn_affine: False 547 | 548 | model.config_name: "resnet34" 549 | 550 | parallel_grid: 551 | conv_sparsity: [0.7] 552 | seed: [1] 553 | figure3_resnet50_sc_edgepopup: 554 | <<: *imagenet_sgd 555 | num_gpus: 2 556 | 557 | init_mode: 'signed_constant' 558 | train_mode: "score_only" 559 | bn_affine: False 560 | 561 | model.config_name: "resnet50" 562 | 563 | parallel_grid: 564 | conv_sparsity: [0.7] 565 | seed: [1] 566 | 
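# edge-popup (SC dist, p=0.7) for ResNet101 on ImageNet; run via:
#   python main.py parallel config.yaml figure3_resnet101_sc_edgepopup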
figure3_resnet101_sc_edgepopup: 567 | <<: *imagenet_sgd 568 | num_gpus: 2 569 | 570 | init_mode: 'signed_constant' 571 | train_mode: "score_only" 572 | bn_affine: False 573 | 574 | model.config_name: "resnet101" 575 | 576 | parallel_grid: 577 | conv_sparsity: [0.7] 578 | seed: [1] 579 | figure3_resnet34_sc_iterand: 580 | <<: *imagenet_sgd 581 | num_gpus: 1 582 | 583 | init_mode: 'signed_constant' 584 | train_mode: "score_only" 585 | bn_affine: False 586 | 587 | rerand_mode: "bernoulli" 588 | rerand_freq: null 589 | 590 | model.config_name: "resnet34" 591 | 592 | parallel_grid: 593 | conv_sparsity: [0.7] 594 | rerand_freq: [1000] 595 | rerand_lambda: [0.1] 596 | seed: [1] 597 | figure3_resnet50_sc_iterand: 598 | <<: *imagenet_sgd 599 | num_gpus: 2 600 | 601 | init_mode: 'signed_constant' 602 | train_mode: "score_only" 603 | bn_affine: False 604 | 605 | rerand_mode: "bernoulli" 606 | rerand_freq: null 607 | 608 | model.config_name: "resnet50" 609 | 610 | parallel_grid: 611 | conv_sparsity: [0.7] 612 | rerand_freq: [1000] 613 | rerand_lambda: [0.1] 614 | seed: [1] 615 | figure3_resnet101_sc_iterand: 616 | <<: *imagenet_sgd 617 | num_gpus: 2 618 | 619 | init_mode: 'signed_constant' 620 | train_mode: "score_only" 621 | bn_affine: False 622 | 623 | rerand_mode: "bernoulli" 624 | rerand_freq: null 625 | 626 | model.config_name: "resnet101" 627 | 628 | parallel_grid: 629 | conv_sparsity: [0.7] 630 | rerand_freq: [1000] 631 | rerand_lambda: [0.1] 632 | seed: [1] 633 | 634 | -------------------------------------------------------------------------------- /images/iterand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dchiji-ntt/iterand/90e0862269ba12df2c824d3f3b694b2b1ae02838/images/iterand.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import argparse 4 | import commands 5 | import yaml 6 | from pprint import PrettyPrinter 7 | 8 | hyperparam_names = { 9 | 'lr': float, 10 | 'weight_decay': float, 11 | 'model.config_name': str, 12 | 'seed': int, 13 | 'conv_sparsity': float, 14 | 'rerand_freq': int, 15 | 'rerand_lambda': float, 16 | } 17 | 18 | def load_configs(config): 19 | with open(config, 'r') as f: 20 | yml = f.read() 21 | dic = yaml.load(yml, Loader=yaml.FullLoader) 22 | return dic 23 | 24 | def main(args): 25 | pp = PrettyPrinter(indent=1) 26 | print('Experiment: ', args.exp_name) 27 | 28 | # Load config from YAML file 29 | command = args.command 30 | cfgs = load_configs(args.config) 31 | cfg = cfgs[args.exp_name] 32 | cfg['data_parallel'] = (cfg['num_gpus'] > 1) 33 | cfg['accum_grad'] = args.accum_grad 34 | cfg['force_restart'] = args.force_restart 35 | 36 | if args.output_dir is not None: 37 | cfg['output_dir'] = args.output_dir 38 | if args.save_best_model is not None: 39 | if args.save_best_model == 'True' or args.save_best_model == 'true': 40 | cfg['save_best_model'] = True 41 | elif args.save_best_model == 'False' or args.save_best_model == 'false': 42 | cfg['save_best_model'] = False 43 | else: 44 | raise NotImplementedError 45 | 46 | if (command == 'train') and cfg['parallel_grid'] is not None: 47 | print(cfg['parallel_grid']) 48 | for key in hyperparam_names: 49 | val = getattr(args, key) 50 | print(key, val) 51 | if val is None: 52 | if key in cfg['parallel_grid']: 53 | print(f"[Error] Please specify an option for `{args.exp_name}` experiment: --{key}") 54 | 
return 55 | else: 56 | continue 57 | else: 58 | if key not in cfg['parallel_grid']: 59 | print(f"[Error] Please specify only options in `parallel_grid` of `{args.exp_name}`; Not supported: --{key}") 60 | return 61 | cfg[key] = val 62 | 63 | #pp.pprint(cfg) 64 | cfg['__other_configs__'] = cfgs 65 | 66 | # Call command function 67 | command = getattr(getattr(commands, command), command) 68 | command(args.exp_name, cfg) 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('command', type=str) 73 | parser.add_argument('config', type=str, help='file path for YAML configure file') 74 | parser.add_argument('exp_name', type=str, help='specify the name of experiment') 75 | parser.add_argument('--accum_grad', type=int, default=1) 76 | parser.add_argument('--force_restart', action='store_true') 77 | parser.add_argument('--save_best_model', type=str, default=None) 78 | parser.add_argument('--output_dir', type=str, default=None) 79 | 80 | for key in hyperparam_names: 81 | parser.add_argument('--' + key, type=hyperparam_names[key], default=None) 82 | 83 | args = parser.parse_args() 84 | 85 | main(args) 86 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dchiji-ntt/iterand/90e0862269ba12df2c824d3f3b694b2b1ae02838/models/__init__.py -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dchiji-ntt/iterand/90e0862269ba12df2c824d3f3b694b2b1ae02838/models/networks/__init__.py -------------------------------------------------------------------------------- /models/networks/convs.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.networks.sparse_modules import SparseConv2d, SparseLinear 6 | 7 | 8 | class Conv4(nn.Module): 9 | def __init__(self, dataset_cfg, model_cfg, cfg=None): 10 | super(Conv4, self).__init__() 11 | 12 | assert dataset_cfg['image_size'] == 32 13 | assert dataset_cfg['num_channels'] == 3 14 | 15 | self.num_classes = dataset_cfg['num_classes'] 16 | self.factor = model_cfg['factor'] 17 | 18 | self.convs = nn.Sequential( 19 | SparseConv2d(3, int(64 * self.factor), kernel_size=3, 20 | stride=1, padding=1, cfg=cfg), 21 | nn.ReLU(), 22 | SparseConv2d(int(64 * self.factor), int(64 * self.factor), kernel_size=3, 23 | stride=1, padding=1, cfg=cfg), 24 | nn.ReLU(), 25 | nn.MaxPool2d((2, 2)), 26 | SparseConv2d(int(64 * self.factor), int(128 * self.factor), kernel_size=3, 27 | stride=1, padding=1, cfg=cfg), 28 | nn.ReLU(), 29 | SparseConv2d(int(128 * self.factor), int(128 * self.factor), kernel_size=3, 30 | stride=1, padding=1, cfg=cfg), 31 | nn.ReLU(), 32 | nn.MaxPool2d((2, 2)), 33 | ) 34 | 35 | if cfg['linear_sparsity'] is not None: 36 | linear_cfg = cfg.copy() 37 | linear_cfg['conv_sparsity'] = cfg['linear_sparsity'] 38 | else: 39 | linear_cfg = cfg 40 | 41 | self.linear = nn.Sequential( 42 | SparseConv2d(int(128 * self.factor) * 8 * 8, int(256 * self.factor), kernel_size=1, cfg=cfg), 43 | nn.ReLU(), 44 | SparseConv2d(int(256 * self.factor), int(256 * self.factor), kernel_size=1, cfg=cfg), 45 | nn.ReLU(), 46 | SparseConv2d(int(256 * self.factor), self.num_classes, kernel_size=1, 
cfg=linear_cfg) 47 | ) 48 | 49 | def forward(self, x): 50 | out = self.convs(x) 51 | out = out.view(out.size(0), int(128 * self.factor) * 8 * 8, 1, 1) 52 | out = self.linear(out) 53 | return out.squeeze() 54 | 55 | def rerandomize(self, mode, la, mu): 56 | for m in self.modules(): 57 | if type(m) is SparseConv2d or type(m) is SparseLinear: 58 | m.rerandomize(mode, la, mu) 59 | 60 | 61 | class Conv6(nn.Module): 62 | def __init__(self, dataset_cfg, model_cfg, cfg=None): 63 | super(Conv6, self).__init__() 64 | 65 | assert dataset_cfg['image_size'] == 32 66 | assert dataset_cfg['num_channels'] == 3 67 | 68 | self.num_classes = dataset_cfg['num_classes'] 69 | self.factor = model_cfg['factor'] 70 | 71 | self.convs = nn.Sequential( 72 | SparseConv2d(3, int(64 * self.factor), kernel_size=3, 73 | stride=1, padding=1, cfg=cfg), 74 | nn.ReLU(), 75 | SparseConv2d(int(64 * self.factor), int(64 * self.factor), kernel_size=3, 76 | stride=1, padding=1, cfg=cfg), 77 | nn.ReLU(), 78 | nn.MaxPool2d((2, 2)), 79 | SparseConv2d(int(64 * self.factor), int(128 * self.factor), kernel_size=3, 80 | stride=1, padding=1, cfg=cfg), 81 | nn.ReLU(), 82 | SparseConv2d(int(128 * self.factor), int(128 * self.factor), kernel_size=3, 83 | stride=1, padding=1, cfg=cfg), 84 | nn.ReLU(), 85 | nn.MaxPool2d((2, 2)), 86 | SparseConv2d(int(128 * self.factor), int(256 * self.factor), kernel_size=3, 87 | stride=1, padding=1, cfg=cfg), 88 | nn.ReLU(), 89 | SparseConv2d(int(256 * self.factor), int(256 * self.factor), kernel_size=3, 90 | stride=1, padding=1, cfg=cfg), 91 | nn.ReLU(), 92 | nn.MaxPool2d((2, 2)) 93 | ) 94 | 95 | if cfg['linear_sparsity'] is not None: 96 | linear_cfg = cfg.copy() 97 | linear_cfg['conv_sparsity'] = cfg['linear_sparsity'] 98 | else: 99 | linear_cfg = cfg 100 | 101 | self.linear = nn.Sequential( 102 | SparseConv2d(int(256 * self.factor) * 4 * 4, int(256 * self.factor), kernel_size=1, cfg=cfg), 103 | nn.ReLU(), 104 | SparseConv2d(int(256 * self.factor), int(256 * self.factor), kernel_size=1, cfg=cfg), 105 | nn.ReLU(), 106 | SparseConv2d(int(256 * self.factor), self.num_classes, kernel_size=1, cfg=linear_cfg) 107 | ) 108 | 109 | def forward(self, x): 110 | out = self.convs(x) 111 | out = out.view(out.size(0), int(256 * self.factor) * 4 * 4, 1, 1) 112 | out = self.linear(out) 113 | return out.squeeze() 114 | 115 | def rerandomize(self, mode, la, mu): 116 | for m in self.modules(): 117 | if type(m) is SparseConv2d or type(m) is SparseLinear: 118 | m.rerandomize(mode, la, mu) 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/networks/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation based on the code of https://github.com/kuangliu/pytorch-cifar/ 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from models.networks.sparse_modules import SparseConv2d, SparseLinear 9 | 10 | 11 | class ConvNormBlock(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, 13 | stride=1, padding=0, bias=True, cfg=None): 14 | super().__init__() 15 | 16 | self.conv = SparseConv2d(in_channels, out_channels, kernel_size=kernel_size, 17 | stride=stride, padding=padding, bias=bias, cfg=cfg) 18 | self.norm = nn.BatchNorm2d(out_channels, 19 | momentum=cfg['bn_momentum'], 20 | track_running_stats=cfg['bn_track_running_stats'], 21 | affine=cfg['bn_affine']) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | out = self.norm(out) 26 | return out 27 | 
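# BasicBlock and Bottleneck below are the standard ResNet residual blocks, rebuilt from
# ConvNormBlock so that every convolution (including the 1x1 shortcut) is a SparseConv2d
# and is therefore subject to the same score-based pruning mask.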
28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, in_planes, planes, stride, cfg=None): 33 | super(BasicBlock, self).__init__() 34 | self.convnb1 = ConvNormBlock(in_planes, planes, kernel_size=3, 35 | stride=stride, padding=1, bias=False, cfg=cfg) 36 | self.convnb2 = ConvNormBlock(planes, planes, kernel_size=3, 37 | stride=1, padding=1, bias=False, cfg=cfg) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride != 1 or in_planes != self.expansion*planes: 41 | self.shortcut = ConvNormBlock(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False, cfg=cfg) 42 | 43 | def forward(self, x): 44 | out = F.relu(self.convnb1(x)) 45 | out = self.convnb2(out) 46 | out += self.shortcut(x) 47 | out = F.relu(out) 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, in_planes, planes, stride, cfg=None): 55 | super(Bottleneck, self).__init__() 56 | self.convnb1 = ConvNormBlock(in_planes, planes, kernel_size=1, bias=False, cfg=cfg) 57 | self.convnb2 = ConvNormBlock(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, cfg=cfg) 58 | self.convnb3 = ConvNormBlock(planes, self.expansion * planes, kernel_size=1, bias=False, cfg=cfg) 59 | 60 | self.shortcut = nn.Sequential() 61 | if stride != 1 or in_planes != self.expansion*planes: 62 | self.shortcut = ConvNormBlock(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False, cfg=cfg) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.convnb1(x)) 66 | out = F.relu(self.convnb2(out)) 67 | out = self.convnb3(out) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, dataset_cfg, model_cfg, cfg): 75 | super(ResNet, self).__init__() 76 | block = globals()[model_cfg['block_class']] 77 | num_blocks = model_cfg['num_blocks'] 78 | factor = model_cfg['factor'] 79 | 80 | num_classes = dataset_cfg['num_classes'] 81 | self.in_channel = dataset_cfg['num_channels'] 82 | self.image_size = dataset_cfg['image_size'] 83 | self.in_planes = 64 84 | if self.image_size not in [28, 32, 64, 84, 224]: 85 | raise NotImplementedError 86 | 87 | if self.image_size in [224]: 88 | k1, s1, p1 = 7, 2, 3 89 | else: 90 | k1, s1, p1 = 3, 1, 1 91 | 92 | self.convnb1 = ConvNormBlock(self.in_channel, 64, kernel_size=k1, 93 | stride=s1, padding=p1, bias=False, cfg=cfg) 94 | 95 | self.layer1 = self._make_layer(block, int(64*factor), num_blocks[0], stride=1, cfg=cfg) 96 | self.layer2 = self._make_layer(block, int(128*factor), num_blocks[1], stride=2, cfg=cfg) 97 | self.layer3 = self._make_layer(block, int(256*factor), num_blocks[2], stride=2, cfg=cfg) 98 | self.layer4 = self._make_layer(block, int(512*factor), num_blocks[3], stride=2, cfg=cfg) 99 | 100 | self.linear = SparseLinear(int(512*factor*block.expansion), num_classes, cfg=cfg) 101 | 102 | def _make_layer(self, block, planes, num_blocks, stride, cfg): 103 | strides = [stride] + [1]*(num_blocks-1) 104 | layers = [] 105 | for stride in strides: 106 | layers.append(block(self.in_planes, planes, stride, cfg)) 107 | self.in_planes = planes * block.expansion 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | out = F.relu(self.convnb1(x)) 112 | if self.image_size in [84, 224]: 113 | out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1) 114 | elif self.image_size in [64]: 115 | out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1) 116 | out = self.layer1(out) 117 | out = self.layer2(out) 118 | out 
= self.layer3(out) 119 | out = self.layer4(out) 120 | if self.image_size in [224]: 121 | out = F.avg_pool2d(out, 7) 122 | elif self.image_size in [84]: 123 | out = F.avg_pool2d(out, 6) 124 | elif self.image_size in [64]: 125 | out = F.avg_pool2d(out, 4) 126 | else: 127 | out = F.avg_pool2d(out, 4) 128 | out = out.view(out.size(0), -1) 129 | out = self.linear(out) 130 | return out 131 | 132 | def rerandomize(self, mode, la, mu): 133 | for m in self.modules(): 134 | if type(m) is SparseConv2d or type(m) is SparseLinear: 135 | m.rerandomize(mode, la, mu) 136 | 137 | 138 | -------------------------------------------------------------------------------- /models/networks/sparse_modules.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | class GetSubnet(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, scores, sparsity, zeros, ones): 11 | k_val = percentile(scores, sparsity*100) 12 | out = torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device)) 13 | return out 14 | 15 | @staticmethod 16 | def backward(ctx, g): 17 | return g, None, None, None 18 | 19 | def percentile(t, q): 20 | k = 1 + round(.01 * float(q) * (t.numel() - 1)) 21 | return t.view(-1).kthvalue(k).values.item() 22 | 23 | 24 | class SparseModule(nn.Module): 25 | def init_param_(self, param, init_mode=None, scale=None): 26 | if init_mode == 'kaiming_normal': 27 | nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="relu") 28 | param.data *= scale 29 | elif init_mode == 'uniform(-1,1)': 30 | nn.init.uniform_(param, a=-1, b=1) 31 | param.data *= scale 32 | elif init_mode == 'kaiming_uniform': 33 | nn.init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu') 34 | param.data *= scale 35 | elif init_mode == 'signed_constant': 36 | # From github.com/allenai/hidden-networks 37 | fan = nn.init._calculate_correct_fan(param, 'fan_in') 38 | gain = nn.init.calculate_gain('relu') 39 | std = gain / math.sqrt(fan) 40 | nn.init.kaiming_normal_(param) # use only its sign 41 | param.data = param.data.sign() * std 42 | param.data *= scale 43 | else: 44 | raise NotImplementedError 45 | 46 | def rerandomize_(self, param, mask, mode=None, la=None, mu=None, 47 | init_mode=None, scale=None, param_twin=None): 48 | if param_twin is None: 49 | raise NotImplementedError 50 | else: 51 | param_twin = param_twin.to(param.device) 52 | 53 | with torch.no_grad(): 54 | if mode == 'bernoulli': 55 | assert (la is not None) and (mu is None) 56 | rnd = param_twin 57 | self.init_param_(rnd, init_mode=init_mode, scale=scale) 58 | ones = torch.ones(param.size()).to(param.device) 59 | b = torch.bernoulli(ones * la) 60 | 61 | t1 = param.data * mask 62 | t2 = param.data * (1 - mask) * (1 - b) 63 | t3 = rnd.data * (1 - mask) * b 64 | 65 | param.data = t1 + t2 + t3 66 | elif mode == 'manual': 67 | assert (la is not None) and (mu is not None) 68 | 69 | t1 = param.data * (1 - mask) 70 | t2 = param.data * mask 71 | 72 | rnd = param_twin 73 | self.init_param_(rnd, init_mode=init_mode, scale=scale) 74 | rnd *= (1 - mask) 75 | 76 | param.data = (t1*la + rnd.data*mu) + t2 77 | else: 78 | raise NotImplementedError 79 | 80 | 81 | class SparseConv2d(SparseModule): 82 | def __init__(self, in_ch, out_ch, **kwargs): 83 | super().__init__() 84 | 85 | self.in_ch = in_ch 86 | self.out_ch = out_ch 87 | 88 | self.kernel_size = kwargs['kernel_size'] 89 | self.stride = kwargs['stride'] if 'stride' in kwargs 
else 1 90 | self.padding = kwargs['padding'] if 'padding' in kwargs else 0 91 | self.bias_flag = kwargs['bias'] if 'bias' in kwargs else True 92 | self.padding_mode = kwargs['padding_mode'] if 'padding_mode' in kwargs else None 93 | 94 | cfg = kwargs['cfg'] 95 | self.sparsity = cfg['conv_sparsity'] 96 | self.init_mode = cfg['init_mode'] 97 | self.init_mode_mask = cfg['init_mode_mask'] 98 | self.init_scale = cfg['init_scale'] 99 | self.init_scale_score = cfg['init_scale_score'] 100 | self.rerand_rate = cfg['rerand_rate'] 101 | self.function = F.conv2d 102 | 103 | self.initialize_weights(2) 104 | 105 | def initialize_weights(self, convdim=None): 106 | if convdim == 1: 107 | self.weight = nn.Parameter(torch.ones(self.out_ch, self.in_ch, self.kernel_size)) 108 | elif convdim == 2: 109 | self.weight = nn.Parameter(torch.ones(self.out_ch, self.in_ch, self.kernel_size, self.kernel_size)) 110 | else: 111 | raise NotImplementedError 112 | 113 | self.weight_score = nn.Parameter(torch.ones(self.weight.size())) 114 | self.weight_score.is_score = True 115 | self.weight_score.sparsity = self.sparsity 116 | 117 | self.weight_twin = torch.zeros(self.weight.size()) 118 | self.weight_twin.requires_grad = False 119 | 120 | if self.bias_flag: 121 | self.bias = nn.Parameter(torch.zeros(self.out_ch)) 122 | else: 123 | self.bias = None 124 | 125 | self.init_param_(self.weight_score, init_mode=self.init_mode_mask, scale=self.init_scale_score) 126 | self.init_param_(self.weight, init_mode=self.init_mode, scale=self.init_scale) 127 | 128 | self.weight_zeros = torch.zeros(self.weight_score.size()) 129 | self.weight_ones = torch.ones(self.weight_score.size()) 130 | self.weight_zeros.requires_grad = False 131 | self.weight_ones.requires_grad = False 132 | 133 | def get_subnet(self, weight_score=None): 134 | if weight_score is None: 135 | weight_score = self.weight_score 136 | 137 | subnet = GetSubnet.apply(self.weight_score, self.sparsity, 138 | self.weight_zeros, self.weight_ones) 139 | return subnet 140 | 141 | def forward(self, input): 142 | subnet = self.get_subnet(self.weight_score) 143 | pruned_weight = self.weight * subnet 144 | ret = self.function( 145 | input, pruned_weight, self.bias, self.stride, self.padding, 146 | ) 147 | return ret 148 | 149 | def rerandomize(self, mode, la, mu): 150 | rate = self.rerand_rate 151 | mask = GetSubnet.apply(self.weight_score, self.sparsity * rate, 152 | self.weight_zeros, self.weight_ones) 153 | scale = self.init_scale 154 | self.rerandomize_(self.weight, mask, mode, la, mu, 155 | self.init_mode, scale, self.weight_twin) 156 | 157 | class SparseConv1d(SparseConv2d): 158 | def __init__(self, *args, **kwargs): 159 | super().__init__(*args, **kwargs) 160 | self.function = F.conv1d 161 | self.initialize_weights(1) 162 | 163 | 164 | class SparseLinear(SparseModule): 165 | def __init__(self, in_ch, out_ch, bias=True, cfg=None): 166 | super().__init__() 167 | 168 | if cfg['linear_sparsity'] is not None: 169 | self.sparsity = cfg['linear_sparsity'] 170 | else: 171 | self.sparsity = cfg['conv_sparsity'] 172 | 173 | if cfg['init_mode_linear'] is not None: 174 | self.init_mode = cfg['init_mode_linear'] 175 | else: 176 | self.init_mode = cfg['init_mode'] 177 | 178 | self.init_mode_mask = cfg['init_mode_mask'] 179 | self.init_scale = cfg['init_scale'] 180 | self.init_scale_score = cfg['init_scale_score'] 181 | self.rerand_rate = cfg['rerand_rate'] 182 | 183 | self.weight = nn.Parameter(torch.ones(out_ch, in_ch)) 184 | self.weight_score = nn.Parameter(torch.ones(self.weight.size())) 185 
| self.weight_score.is_score = True 186 | self.weight_score.sparsity = self.sparsity 187 | if bias: 188 | self.bias = nn.Parameter(torch.zeros(out_ch)) 189 | else: 190 | self.bias = None 191 | 192 | self.weight_twin = torch.zeros(self.weight.size()) 193 | self.weight_twin.requires_grad = False 194 | 195 | self.init_param_(self.weight_score, init_mode=self.init_mode_mask, scale=self.init_scale_score) 196 | self.init_param_(self.weight, init_mode=self.init_mode, scale=self.init_scale) 197 | 198 | self.weight_zeros = torch.zeros(self.weight_score.size()) 199 | self.weight_ones = torch.ones(self.weight_score.size()) 200 | self.weight_zeros.requires_grad = False 201 | self.weight_ones.requires_grad = False 202 | 203 | def forward(self, x, manual_mask=None): 204 | if manual_mask is None: 205 | subnet = GetSubnet.apply(self.weight_score, self.sparsity, 206 | self.weight_zeros, self.weight_ones) 207 | pruned_weight = self.weight * subnet 208 | else: 209 | pruned_weight = self.weight * manual_mask 210 | 211 | ret = F.linear(x, pruned_weight, self.bias) 212 | return ret 213 | 214 | def rerandomize(self, mode, la, mu, manual_mask=None): 215 | if manual_mask is None: 216 | rate = self.rerand_rate 217 | mask = GetSubnet.apply(self.weight_score, self.sparsity * rate, 218 | self.weight_zeros, self.weight_ones) 219 | else: 220 | mask = manual_mask 221 | 222 | scale = self.init_scale 223 | self.rerandomize_(self.weight, mask, mode, la, mu, 224 | self.init_mode, scale, self.weight_twin) 225 | 226 | -------------------------------------------------------------------------------- /models/supervised_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.datasets 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | from torchvision.transforms import ToTensor, Resize, Compose, ColorJitter, RandomResizedCrop, RandomHorizontalFlip, Normalize, CenterCrop, Pad 7 | from torch.nn import DataParallel 8 | from torch.optim.lr_scheduler import MultiStepLR 9 | 10 | from utils.subset_dataset import SubsetDataset, random_split 11 | import utils.datasets 12 | from utils.schedulers import CustomCosineLR 13 | from models.networks.resnet import ResNet 14 | from models.networks.convs import Conv6 15 | 16 | import random 17 | 18 | 19 | class SupervisedLearning(object): 20 | def __init__(self, outman, cfg, device, data_parallel): 21 | self.outman = outman 22 | self.cfg = cfg 23 | self.device = device 24 | self.data_parallel = data_parallel 25 | 26 | self.debug_max_iters = self.cfg['debug_max_iters'] 27 | self.train_augmentation = self.cfg['train_augmentation'] 28 | self.dataset_cfg = self.cfg['__other_configs__'][self.cfg['dataset.config_name']] 29 | 30 | self.model_cfg = self.cfg['__other_configs__'][self.cfg['model.config_name']] 31 | 32 | self.train_dataset, self.val_dataset, self.test_dataset = self._get_datasets() 33 | self.model = self._get_model().to(self.device) 34 | self.optimizer = self._get_optimizer() 35 | self.criterion = self._get_criterion() 36 | self.scheduler = self._get_scheduler() 37 | 38 | def train(self, epoch, total_iters, before_callback=None, after_callback=None): 39 | self.model.train() 40 | 41 | batch_size = self.cfg['batch_size'] 42 | num_workers = self.cfg['num_workers'] 43 | dataloader = DataLoader(self.train_dataset, batch_size=batch_size, 44 | shuffle=True, num_workers=num_workers) 45 | 46 | results = [] 47 | total_count = 0 48 | total_loss = 
0. 49 | correct = 0 50 | if self.debug_max_iters is None: 51 | iters_per_epoch = len(dataloader) 52 | else: 53 | iters_per_epoch = min(len(dataloader), self.debug_max_iters) 54 | 55 | # for the case of self.scheduler == CustomCosineLR 56 | step_before_train = hasattr(self.scheduler, "step_before_train") and self.scheduler.step_before_train 57 | if step_before_train: 58 | try: 59 | self.scheduler.step(epoch=epoch) 60 | except: 61 | self.scheduler.step() 62 | 63 | for _it, (inputs, targets) in enumerate(dataloader): 64 | if self.debug_max_iters is not None and _it >= self.debug_max_iters: 65 | break 66 | 67 | if before_callback is not None: 68 | before_callback(self.model, epoch, total_iters, iters_per_epoch) 69 | 70 | inputs, targets = inputs.to(self.device), targets.to(self.device) 71 | self.optimizer.zero_grad() 72 | outputs = self.model(inputs) 73 | loss = self.criterion(outputs, targets) 74 | loss.backward() 75 | self.optimizer.step() 76 | 77 | _, predicted = outputs.max(1) 78 | total_count += targets.size(0) 79 | correct += predicted.eq(targets).sum().item() 80 | 81 | mean_loss = loss.item() / targets.size(0) 82 | results.append({ 83 | 'mean_loss': mean_loss, 84 | }) 85 | 86 | total_loss += mean_loss 87 | total_iters += 1 88 | 89 | if after_callback is not None: 90 | after_callback(self.model, epoch, total_iters, iters_per_epoch) 91 | 92 | if not step_before_train: 93 | try: 94 | self.scheduler.step(epoch=epoch) 95 | except: 96 | self.scheduler.step() 97 | 98 | self.model.eval() 99 | 100 | return { 101 | 'iterations': total_iters, 102 | 'per_iteration': results, 103 | 'loss': total_loss / total_count, 104 | 'moving_accuracy': correct / total_count 105 | } 106 | 107 | def evaluate(self, dataset_type='val'): 108 | self.model.eval() 109 | 110 | batch_size = self.cfg['batch_size_eval'] 111 | num_workers = self.cfg['num_workers'] 112 | if dataset_type == 'val': 113 | dataloader = DataLoader(self.val_dataset, batch_size=batch_size, 114 | shuffle=True, num_workers=num_workers) 115 | elif dataset_type == 'test': 116 | dataloader = DataLoader(self.test_dataset, batch_size=batch_size, 117 | shuffle=True, num_workers=num_workers) 118 | else: 119 | raise NotImplementedError 120 | 121 | results = [] 122 | total_count = 0 123 | total_loss = 0. 
124 | correct = 0 125 | for _it, (inputs, targets) in enumerate(dataloader): 126 | if self.debug_max_iters is not None and _it >= self.debug_max_iters: 127 | break 128 | 129 | inputs, targets = inputs.to(self.device), targets.to(self.device) 130 | with torch.no_grad(): 131 | outputs = self.model(inputs) 132 | loss = self.criterion(outputs, targets) 133 | 134 | _, predicted = outputs.max(1) 135 | total_count += targets.size(0) 136 | correct += predicted.eq(targets).sum().item() 137 | 138 | total_loss += loss.item() / targets.size(0) 139 | return { 140 | 'loss': total_loss / total_count, 141 | 'accuracy': correct / total_count, 142 | } 143 | 144 | def _get_datasets(self): 145 | dataset_dir = self.cfg['dataset_dir'] 146 | max_size = self.cfg['max_train_dataset_size'] 147 | dataset_download = self.cfg['dataset_download'] 148 | dataset_classname = self.dataset_cfg['class'] 149 | data_type = self.dataset_cfg['data_type'] 150 | 151 | if dataset_classname in ['CIFAR10', 'CIFAR100', 'MNIST']: 152 | dataset_class = getattr(torchvision.datasets, dataset_classname) 153 | elif dataset_classname in ['ImageNet']: 154 | dataset_class = getattr(utils.datasets, dataset_classname) 155 | else: 156 | raise NotImplementedError 157 | 158 | if data_type == 'image': 159 | image_size = self.dataset_cfg['image_size'] 160 | train_val_split = self.dataset_cfg['train_val_split'] 161 | 162 | transform_train = self._create_transform(image_size, train=True) 163 | transform_val = self._create_transform(image_size, train=False) 164 | 165 | trainval_dataset = dataset_class(dataset_dir, 166 | train=True, 167 | transform=None, 168 | download=dataset_download) 169 | 170 | size = len(trainval_dataset) 171 | val_size = int(size * train_val_split) 172 | train_size = min(size - val_size, 173 | max_size if max_size is not None else size) 174 | gen = torch.Generator() 175 | gen.manual_seed(777) 176 | train_subset, val_subset, _ = random_split(trainval_dataset, 177 | [train_size, val_size, size-(train_size+val_size)], 178 | generator=gen) 179 | self.outman.print('Train/val dataset size:', size) 180 | self.outman.print('Train dataset size:', len(train_subset), 181 | ', Val dataset size:', len(val_subset)) 182 | 183 | train_dataset = SubsetDataset(train_subset, transform=transform_train) 184 | val_dataset = SubsetDataset(val_subset, transform=transform_val) 185 | test_dataset = dataset_class(dataset_dir, 186 | train=False, 187 | transform=transform_val, 188 | download=dataset_download) 189 | else: 190 | raise NotImplementedError 191 | 192 | return train_dataset, val_dataset, test_dataset 193 | 194 | def _create_transform(self, image_size, train=False): 195 | dataset_class = self.dataset_cfg['class'] 196 | norm_param = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 197 | jitter_param = dict(brightness=0.4, contrast=0.4, saturation=0.4) 198 | 199 | if train and self.train_augmentation: 200 | if dataset_class in ['MNIST']: 201 | train_transform = Compose([ 202 | Resize((image_size, image_size)), 203 | ToTensor(), 204 | Normalize((0.1307,), (0.3081,)), 205 | ]) 206 | else: 207 | train_transform = Compose([ 208 | RandomResizedCrop((image_size, image_size)), 209 | RandomHorizontalFlip(), 210 | ToTensor(), 211 | Normalize(**norm_param), 212 | ]) 213 | if self.cfg['padding_before_crop']: # this should be used in CIFAR-10 training 214 | train_transform.transforms.insert(0, Pad(4)) 215 | return train_transform 216 | else: 217 | if dataset_class in ['ImageNet']: 218 | return Compose([Resize(256), 219 | CenterCrop(image_size), 220 | 
ToTensor(), 221 | Normalize(**norm_param) 222 | ]) 223 | elif dataset_class in ['CIFAR10', 'CIFAR100']: 224 | return Compose([Resize((image_size, image_size)), 225 | ToTensor(), 226 | Normalize(**norm_param) 227 | ]) 228 | elif dataset_class in ['MNIST']: 229 | return Compose([Resize((image_size, image_size)), 230 | ToTensor(), 231 | Normalize((0.1307,), (0.3081,)), 232 | ]) 233 | else: 234 | raise NotImplementedError 235 | 236 | def _get_model(self, model_cfg=None): 237 | if model_cfg is None: 238 | model_cfg = self.model_cfg 239 | 240 | if model_cfg['class'] == 'ResNet': 241 | model = ResNet(self.dataset_cfg, model_cfg, self.cfg) 242 | elif model_cfg['class'] == 'Conv6': 243 | model = Conv6(self.dataset_cfg, model_cfg, self.cfg) 244 | else: 245 | raise NotImplementedError 246 | 247 | if self.data_parallel: 248 | gpu_ids = list(range(self.cfg['num_gpus'])) 249 | return DataParallel(model, gpu_ids) 250 | else: 251 | return model 252 | 253 | def _get_optimizer(self): 254 | optim_name = self.cfg['optimizer'] 255 | 256 | if self.cfg['train_mode'] == 'score_only': 257 | lr = self.cfg['lr'] 258 | weight_decay = self.cfg['weight_decay'] 259 | params = [param for param in self.model.parameters() 260 | if hasattr(param, 'is_score') and param.is_score] 261 | return self._new_optimizer(optim_name, params, lr, weight_decay) 262 | elif self.cfg['train_mode'] == 'normal': 263 | lr = self.cfg['lr'] 264 | weight_decay = self.cfg['weight_decay'] 265 | params = [param for param in self.model.parameters() 266 | if not (hasattr(param, 'is_score') and param.is_score)] 267 | return self._new_optimizer(optim_name, params, lr, weight_decay) 268 | else: 269 | raise NotImplementedError 270 | 271 | def _get_criterion(self): 272 | return nn.CrossEntropyLoss() 273 | 274 | def _new_optimizer(self, name, params, lr, weight_decay, momentum=0.9): 275 | if name == 'AdamW': 276 | return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay) 277 | elif name == 'SGD': 278 | return torch.optim.SGD(params, lr=lr, 279 | momentum=self.cfg['sgd_momentum'], weight_decay=weight_decay) 280 | else: 281 | raise NotImplementedError 282 | 283 | def _get_scheduler(self): 284 | class null_scheduler(object): 285 | def __init__(self, *args, **kwargs): 286 | return 287 | def step(self, *args, **kwargs): 288 | return 289 | def state_dict(self): 290 | return {} 291 | def load_state_dict(self, dic): 292 | return 293 | 294 | if self.cfg['lr_scheduler'] is None: 295 | return null_scheduler() 296 | elif self.cfg['lr_scheduler'] == 'CustomCosineLR': 297 | total_epoch = self.cfg['epoch'] 298 | init_lr = self.cfg['lr'] 299 | warmup_epochs = self.cfg['warmup_epochs'] 300 | ft_epochs = self.cfg['finetuning_epochs'] 301 | ft_lr = self.cfg['finetuning_lr'] 302 | return CustomCosineLR(self.optimizer, init_lr, total_epoch, warmup_epochs, ft_epochs, ft_lr) 303 | elif self.cfg['lr_scheduler'] == 'MultiStepLR': 304 | return MultiStepLR(self.optimizer, milestones=self.cfg['lr_milestones'], gamma=self.cfg['multisteplr_gamma']) 305 | else: 306 | raise NotImplementedError 307 | 308 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.1 2 | torchvision==0.6.1 3 | pyyaml==5.3.1 4 | pandas==1.2.0 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dchiji-ntt/iterand/90e0862269ba12df2c824d3f3b694b2b1ae02838/utils/__init__.py -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Iterator, Optional, Tuple 3 | import torch 4 | from torchvision.datasets.folder import ImageFolder 5 | import csv 6 | 7 | 8 | DIR_LABEL_CSV = "list/clist.csv" 9 | 10 | class ImageNet(ImageFolder): 11 | 12 | def __init__(self, root: str, train: bool = True, download: Optional[str] = None, **kwargs: Any) -> None: 13 | root = self.root = os.path.join(os.path.expanduser(root), 'imagenet') 14 | self.split = 'train' if train else 'val' 15 | self.split_folder = os.path.join(root, self.split) 16 | 17 | super(ImageNet, self).__init__(self.split_folder, **kwargs) 18 | 19 | self.root = root 20 | self.class_to_idx = self._get_corr() 21 | 22 | def _get_corr(self): 23 | csv_path = os.path.join(self.root, DIR_LABEL_CSV) 24 | with open(csv_path, newline='') as f: 25 | reader = csv.reader(f) 26 | dic = dict() 27 | for row in reader: 28 | if row[2] == "label": # ignore first line of csv 29 | continue 30 | dic[row[0]] = int(row[2]) 31 | return dic 32 | 33 | -------------------------------------------------------------------------------- /utils/filelock.py: -------------------------------------------------------------------------------- 1 | # The following code is from https://github.com/benediktschmitt/py-filelock/blob/master/filelock.py 2 | 3 | # This is free and unencumbered software released into the public domain. 4 | # 5 | # Anyone is free to copy, modify, publish, use, compile, sell, or 6 | # distribute this software, either in source code form or as a compiled 7 | # binary, for any purpose, commercial or non-commercial, and by any 8 | # means. 9 | # 10 | # In jurisdictions that recognize copyright laws, the author or authors 11 | # of this software dedicate any and all copyright interest in the 12 | # software to the public domain. We make this dedication for the benefit 13 | # of the public at large and to the detriment of our heirs and 14 | # successors. We intend this dedication to be an overt act of 15 | # relinquishment in perpetuity of all present and future rights to this 16 | # software under copyright law. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 19 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 20 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 21 | # IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 22 | # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 23 | # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 24 | # OTHER DEALINGS IN THE SOFTWARE. 25 | # 26 | # For more information, please refer to 27 | 28 | """ 29 | A platform independent file lock that supports the with-statement. 
30 | """ 31 | 32 | 33 | # Modules 34 | # ------------------------------------------------ 35 | import logging 36 | import os 37 | import threading 38 | import time 39 | try: 40 | import warnings 41 | except ImportError: 42 | warnings = None 43 | 44 | try: 45 | import msvcrt 46 | except ImportError: 47 | msvcrt = None 48 | 49 | try: 50 | import fcntl 51 | except ImportError: 52 | fcntl = None 53 | 54 | 55 | # Backward compatibility 56 | # ------------------------------------------------ 57 | try: 58 | TimeoutError 59 | except NameError: 60 | TimeoutError = OSError 61 | 62 | 63 | # Data 64 | # ------------------------------------------------ 65 | __all__ = [ 66 | "Timeout", 67 | "BaseFileLock", 68 | "WindowsFileLock", 69 | "UnixFileLock", 70 | "SoftFileLock", 71 | "FileLock" 72 | ] 73 | 74 | __version__ = "3.0.12" 75 | 76 | 77 | _logger = None 78 | def logger(): 79 | """Returns the logger instance used in this module.""" 80 | global _logger 81 | _logger = _logger or logging.getLogger(__name__) 82 | return _logger 83 | 84 | 85 | # Exceptions 86 | # ------------------------------------------------ 87 | class Timeout(TimeoutError): 88 | """ 89 | Raised when the lock could not be acquired in *timeout* 90 | seconds. 91 | """ 92 | 93 | def __init__(self, lock_file): 94 | """ 95 | """ 96 | #: The path of the file lock. 97 | self.lock_file = lock_file 98 | return None 99 | 100 | def __str__(self): 101 | temp = "The file lock '{}' could not be acquired."\ 102 | .format(self.lock_file) 103 | return temp 104 | 105 | 106 | # Classes 107 | # ------------------------------------------------ 108 | 109 | # This is a helper class which is returned by :meth:`BaseFileLock.acquire` 110 | # and wraps the lock to make sure __enter__ is not called twice when entering 111 | # the with statement. 112 | # If we would simply return *self*, the lock would be acquired again 113 | # in the *__enter__* method of the BaseFileLock, but not released again 114 | # automatically. 115 | # 116 | # :seealso: issue #37 (memory leak) 117 | class _Acquire_ReturnProxy(object): 118 | 119 | def __init__(self, lock): 120 | self.lock = lock 121 | return None 122 | 123 | def __enter__(self): 124 | return self.lock 125 | 126 | def __exit__(self, exc_type, exc_value, traceback): 127 | self.lock.release() 128 | return None 129 | 130 | 131 | class BaseFileLock(object): 132 | """ 133 | Implements the base class of a file lock. 134 | """ 135 | 136 | def __init__(self, lock_file, timeout = -1): 137 | """ 138 | """ 139 | # The path to the lock file. 140 | self._lock_file = lock_file 141 | 142 | # The file descriptor for the *_lock_file* as it is returned by the 143 | # os.open() function. 144 | # This file lock is only NOT None, if the object currently holds the 145 | # lock. 146 | self._lock_file_fd = None 147 | 148 | # The default timeout value. 149 | self.timeout = timeout 150 | 151 | # We use this lock primarily for the lock counter. 152 | self._thread_lock = threading.Lock() 153 | 154 | # The lock counter is used for implementing the nested locking 155 | # mechanism. Whenever the lock is acquired, the counter is increased and 156 | # the lock is only released, when this value is 0 again. 157 | self._lock_counter = 0 158 | return None 159 | 160 | @property 161 | def lock_file(self): 162 | """ 163 | The path to the lock file. 164 | """ 165 | return self._lock_file 166 | 167 | @property 168 | def timeout(self): 169 | """ 170 | You can set a default timeout for the filelock. 
It will be used as 171 | fallback value in the acquire method, if no timeout value (*None*) is 172 | given. 173 | 174 | If you want to disable the timeout, set it to a negative value. 175 | 176 | A timeout of 0 means, that there is exactly one attempt to acquire the 177 | file lock. 178 | 179 | .. versionadded:: 2.0.0 180 | """ 181 | return self._timeout 182 | 183 | @timeout.setter 184 | def timeout(self, value): 185 | """ 186 | """ 187 | self._timeout = float(value) 188 | return None 189 | 190 | # Platform dependent locking 191 | # -------------------------------------------- 192 | 193 | def _acquire(self): 194 | """ 195 | Platform dependent. If the file lock could be 196 | acquired, self._lock_file_fd holds the file descriptor 197 | of the lock file. 198 | """ 199 | raise NotImplementedError() 200 | 201 | def _release(self): 202 | """ 203 | Releases the lock and sets self._lock_file_fd to None. 204 | """ 205 | raise NotImplementedError() 206 | 207 | # Platform independent methods 208 | # -------------------------------------------- 209 | 210 | @property 211 | def is_locked(self): 212 | """ 213 | True, if the object holds the file lock. 214 | 215 | .. versionchanged:: 2.0.0 216 | 217 | This was previously a method and is now a property. 218 | """ 219 | return self._lock_file_fd is not None 220 | 221 | def acquire(self, timeout=None, poll_intervall=0.05): 222 | """ 223 | Acquires the file lock or fails with a :exc:`Timeout` error. 224 | 225 | .. code-block:: python 226 | 227 | # You can use this method in the context manager (recommended) 228 | with lock.acquire(): 229 | pass 230 | 231 | # Or use an equivalent try-finally construct: 232 | lock.acquire() 233 | try: 234 | pass 235 | finally: 236 | lock.release() 237 | 238 | :arg float timeout: 239 | The maximum time waited for the file lock. 240 | If ``timeout < 0``, there is no timeout and this method will 241 | block until the lock could be acquired. 242 | If ``timeout`` is None, the default :attr:`~timeout` is used. 243 | 244 | :arg float poll_intervall: 245 | We check once in *poll_intervall* seconds if we can acquire the 246 | file lock. 247 | 248 | :raises Timeout: 249 | if the lock could not be acquired in *timeout* seconds. 250 | 251 | .. versionchanged:: 2.0.0 252 | 253 | This method returns now a *proxy* object instead of *self*, 254 | so that it can be used in a with statement without side effects. 255 | """ 256 | # Use the default timeout, if no timeout is provided. 257 | if timeout is None: 258 | timeout = self.timeout 259 | 260 | # Increment the number right at the beginning. 261 | # We can still undo it, if something fails. 262 | with self._thread_lock: 263 | self._lock_counter += 1 264 | 265 | lock_id = id(self) 266 | lock_filename = self._lock_file 267 | start_time = time.time() 268 | try: 269 | while True: 270 | with self._thread_lock: 271 | if not self.is_locked: 272 | logger().debug('Attempting to acquire lock %s on %s', lock_id, lock_filename) 273 | self._acquire() 274 | 275 | if self.is_locked: 276 | logger().info('Lock %s acquired on %s', lock_id, lock_filename) 277 | break 278 | elif timeout >= 0 and time.time() - start_time > timeout: 279 | logger().debug('Timeout on acquiring lock %s on %s', lock_id, lock_filename) 280 | raise Timeout(self._lock_file) 281 | else: 282 | logger().debug( 283 | 'Lock %s not acquired on %s, waiting %s seconds ...', 284 | lock_id, lock_filename, poll_intervall 285 | ) 286 | time.sleep(poll_intervall) 287 | except: 288 | # Something did go wrong, so decrement the counter. 
289 | with self._thread_lock: 290 | self._lock_counter = max(0, self._lock_counter - 1) 291 | 292 | raise 293 | return _Acquire_ReturnProxy(lock = self) 294 | 295 | def release(self, force = False): 296 | """ 297 | Releases the file lock. 298 | 299 | Please note, that the lock is only completly released, if the lock 300 | counter is 0. 301 | 302 | Also note, that the lock file itself is not automatically deleted. 303 | 304 | :arg bool force: 305 | If true, the lock counter is ignored and the lock is released in 306 | every case. 307 | """ 308 | with self._thread_lock: 309 | 310 | if self.is_locked: 311 | self._lock_counter -= 1 312 | 313 | if self._lock_counter == 0 or force: 314 | lock_id = id(self) 315 | lock_filename = self._lock_file 316 | 317 | logger().debug('Attempting to release lock %s on %s', lock_id, lock_filename) 318 | self._release() 319 | self._lock_counter = 0 320 | logger().info('Lock %s released on %s', lock_id, lock_filename) 321 | 322 | return None 323 | 324 | def __enter__(self): 325 | self.acquire() 326 | return self 327 | 328 | def __exit__(self, exc_type, exc_value, traceback): 329 | self.release() 330 | return None 331 | 332 | def __del__(self): 333 | self.release(force = True) 334 | return None 335 | 336 | 337 | # Windows locking mechanism 338 | # ~~~~~~~~~~~~~~~~~~~~~~~~~ 339 | 340 | class WindowsFileLock(BaseFileLock): 341 | """ 342 | Uses the :func:`msvcrt.locking` function to hard lock the lock file on 343 | windows systems. 344 | """ 345 | 346 | def _acquire(self): 347 | open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC 348 | 349 | try: 350 | fd = os.open(self._lock_file, open_mode) 351 | except OSError: 352 | pass 353 | else: 354 | try: 355 | msvcrt.locking(fd, msvcrt.LK_NBLCK, 1) 356 | except (IOError, OSError): 357 | os.close(fd) 358 | else: 359 | self._lock_file_fd = fd 360 | return None 361 | 362 | def _release(self): 363 | fd = self._lock_file_fd 364 | self._lock_file_fd = None 365 | msvcrt.locking(fd, msvcrt.LK_UNLCK, 1) 366 | os.close(fd) 367 | 368 | try: 369 | os.remove(self._lock_file) 370 | # Probably another instance of the application 371 | # that acquired the file lock. 372 | except OSError: 373 | pass 374 | return None 375 | 376 | # Unix locking mechanism 377 | # ~~~~~~~~~~~~~~~~~~~~~~ 378 | 379 | class UnixFileLock(BaseFileLock): 380 | """ 381 | Uses the :func:`fcntl.flock` to hard lock the lock file on unix systems. 382 | """ 383 | 384 | def _acquire(self): 385 | open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC 386 | fd = os.open(self._lock_file, open_mode) 387 | 388 | try: 389 | fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) 390 | except (IOError, OSError): 391 | os.close(fd) 392 | else: 393 | self._lock_file_fd = fd 394 | return None 395 | 396 | def _release(self): 397 | # Do not remove the lockfile: 398 | # 399 | # https://github.com/benediktschmitt/py-filelock/issues/31 400 | # https://stackoverflow.com/questions/17708885/flock-removing-locked-file-without-race-condition 401 | fd = self._lock_file_fd 402 | self._lock_file_fd = None 403 | fcntl.flock(fd, fcntl.LOCK_UN) 404 | os.close(fd) 405 | return None 406 | 407 | # Soft lock 408 | # ~~~~~~~~~ 409 | 410 | class SoftFileLock(BaseFileLock): 411 | """ 412 | Simply watches the existence of the lock file. 
413 | """ 414 | 415 | def _acquire(self): 416 | open_mode = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC 417 | try: 418 | fd = os.open(self._lock_file, open_mode) 419 | except (IOError, OSError): 420 | pass 421 | else: 422 | self._lock_file_fd = fd 423 | return None 424 | 425 | def _release(self): 426 | os.close(self._lock_file_fd) 427 | self._lock_file_fd = None 428 | 429 | try: 430 | os.remove(self._lock_file) 431 | # The file is already deleted and that's what we want. 432 | except OSError: 433 | pass 434 | return None 435 | 436 | 437 | # Platform filelock 438 | # ~~~~~~~~~~~~~~~~~ 439 | 440 | #: Alias for the lock, which should be used for the current platform. On 441 | #: Windows, this is an alias for :class:`WindowsFileLock`, on Unix for 442 | #: :class:`UnixFileLock` and otherwise for :class:`SoftFileLock`. 443 | FileLock = None 444 | 445 | if msvcrt: 446 | FileLock = WindowsFileLock 447 | elif fcntl: 448 | FileLock = UnixFileLock 449 | else: 450 | FileLock = SoftFileLock 451 | 452 | if warnings is not None: 453 | warnings.warn("only soft file lock is available") 454 | 455 | -------------------------------------------------------------------------------- /utils/info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import json 4 | import argparse 5 | 6 | filter_keys = [ 7 | 'model.config_name', 8 | 'conv_sparsity', 9 | 'rerand_lambda', 10 | 'lr', 11 | 'seed', 12 | ] 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('dir', type=str) 16 | parser.add_argument('-l', '--loss_train', action='store_true') 17 | for key in filter_keys: 18 | parser.add_argument('--' + key, type=str, default=None) 19 | args = parser.parse_args() 20 | args_dic = vars(args) 21 | 22 | def filter(info): 23 | ret = False 24 | for key in filter_keys: 25 | if args_dic[key] is None: 26 | pass 27 | elif key + "_" + args_dic[key] + "--" in info['prefix']: 28 | pass 29 | else: 30 | ret = True 31 | return ret 32 | 33 | dir = args.dir 34 | infos = [] 35 | for fn in os.listdir(dir): 36 | full_fn = os.path.join(dir, fn) 37 | if not os.path.isfile(full_fn): 38 | continue 39 | if fn[:4] == 'info': 40 | with open(full_fn, 'r') as f: 41 | infostr = f.read() 42 | infos.append(json.loads(infostr)) 43 | 44 | 45 | if args.loss_train: 46 | infos = [x for x in infos if 'acc_train' in x] 47 | last_infos = sorted(infos, key=lambda x: -x['acc_train']) 48 | print("") 49 | print("Sorted by last train accs:") 50 | for info in last_infos: 51 | if filter(info): 52 | continue 53 | print("%.4f (%d)\t%s" % (info['acc_train'], info['epoch'], info['prefix'])) 54 | 55 | last_infos = sorted(infos, key=lambda x: -x['loss_train']) 56 | print("") 57 | print("Sorted by train loss:") 58 | for info in last_infos: 59 | if filter(info): 60 | continue 61 | print("%.10f (%d)\t%s" % (info['loss_train'], info['epoch'], info['prefix'])) 62 | else: 63 | accs = [] 64 | best_infos = sorted(infos, key=lambda x: -x['best_val']) 65 | print("") 66 | print("Sorted by best val accs:") 67 | for info in best_infos: 68 | if filter(info): 69 | continue 70 | print("%.4f (%d/%d)\t%s" % (info['best_val'], info['best_epoch'], info['epoch'], info['prefix'])) 71 | accs.append(info['best_val']) 72 | mean = sum(accs) / len(accs) 73 | print("Mean:", mean) 74 | 75 | 76 | accs = [] 77 | last_infos = sorted(infos, key=lambda x: -x['last_val']) 78 | print("") 79 | print("Sorted by last val accs:") 80 | for info in last_infos: 81 | if filter(info): 82 | continue 83 | print("%.4f (%d)\t%s" 
% (info['last_val'], info['epoch'], info['prefix'])) 84 | accs.append(info['last_val']) 85 | mean = sum(accs) / len(accs) 86 | print("Mean:", mean) 87 | -------------------------------------------------------------------------------- /utils/output_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pprint 4 | 5 | class OutputManager(object): 6 | def __init__(self, output_dir, name): 7 | self.output_dir = output_dir 8 | self.name = name 9 | self.save_dir = os.path.join(self.output_dir, name) 10 | 11 | if not os.path.exists(self.output_dir): 12 | try: 13 | os.makedirs(self.output_dir) 14 | except Exception as e: 15 | print('[OutputManager] Caught Exception:', e.args) 16 | 17 | if not os.path.exists(self.save_dir): 18 | try: 19 | os.makedirs(self.save_dir) 20 | except Exception as e: 21 | print('[OutputManager] Caught Exception:', e.args) 22 | 23 | def save_dict(self, dic, prefix="dump", ext="pth", name=None): 24 | filepath = self.get_abspath(prefix, ext, name) 25 | with open(filepath, 'wb') as f: 26 | torch.save(dic, f) 27 | 28 | def load_dict(self, prefix="dump", ext="pth", name=None): 29 | filepath = self.get_abspath(prefix, ext, name) 30 | return torch.load(filepath) 31 | 32 | def get_abspath(self, prefix, ext, name=None): 33 | if name is None: 34 | name = self.name 35 | return os.path.abspath(os.path.join(self.save_dir, f'{prefix}.{name}.{ext}')) 36 | 37 | def add_log(self): 38 | pass 39 | 40 | def print(self, *args, prefix=""): 41 | print(*args) 42 | print(*args, file=open(os.path.join(self.save_dir, f'{prefix}.{self.name}.out'), "a+")) 43 | 44 | def pprint(self, *args, prefix=""): 45 | s = pprint.pformat(*args, indent=1) 46 | self.print(s, prefix=prefix) 47 | 48 | if __name__ == '__main__': 49 | outman = OutputManager('test', 'outman') 50 | outman.print("a", "b", prefix="thisisprefix") 51 | outman.print("c", "d", prefix="thisisprefix") 52 | -------------------------------------------------------------------------------- /utils/pd_logger.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pickle 4 | import os 5 | 6 | class PDLogger(object): 7 | def __init__(self, filename=None): 8 | if filename is not None: 9 | with open(filename, "rb") as f: 10 | self.dfs = pickle.load(f) 11 | else: 12 | self.dfs = dict() 13 | 14 | self.filename = filename 15 | 16 | def set_filename(self, filename): 17 | self.filename = filename 18 | 19 | def save(self): 20 | if self.filename is None: 21 | raise Exception 22 | else: 23 | with open(self.filename, "wb") as f: 24 | pickle.dump(self.dfs, f) 25 | 26 | def load(self): 27 | if self.filename is None: 28 | raise Exception 29 | else: 30 | if os.path.getsize(self.filename) <= 0: 31 | raise Exception 32 | with open(self.filename, "rb") as f: 33 | self.dfs = pickle.load(f) 34 | 35 | def add(self, attr, value, index=None, columns=None): 36 | if attr in self.dfs: 37 | df = self.dfs[attr] 38 | if index[0] in df.index: 39 | print(f'[PDLogger] Warning: The results are already set at index={index[0]}.') 40 | return 41 | 42 | if columns is None: 43 | columns = df.columns 44 | df_new = pd.DataFrame(value, index=index, columns=columns) 45 | self.dfs[attr] = df.append(df_new) 46 | else: 47 | self.dfs[attr] = pd.DataFrame(value, index=index, columns=columns) 48 | 49 | def get_df(self, attr): 50 | return self.dfs[attr] 51 | -------------------------------------------------------------------------------- 
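Usage note for utils/pd_logger.py above: PDLogger keeps one pandas DataFrame per metric name in a dictionary and pickles the whole dictionary to disk. A minimal usage sketch follows, assuming the pinned pandas==1.2.0 (where DataFrame.append still exists) and execution from the repository root; the file name metrics.pkl and the metric name 'val_accuracy' are illustrative placeholders, not names used by the repository.

from utils.pd_logger import PDLogger   # assumes the repository root is on sys.path

logger = PDLogger()                     # start with an empty logger
logger.set_filename('metrics.pkl')      # illustrative output path

# Record one value per epoch; `index` is the epoch and `columns` labels the value.
for epoch, acc in enumerate([0.52, 0.61, 0.67]):
    logger.add('val_accuracy', [[acc]], index=[epoch], columns=['acc'])

logger.save()                           # pickle the dict of DataFrames to metrics.pkl
print(logger.get_df('val_accuracy'))    # inspect the accumulated table
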
/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # This scheduler is based on https://github.com/allenai/hidden-networks/blob/master/utils/schedulers.py 4 | 5 | class CustomCosineLR(object): 6 | def __init__(self, optimizer, init_lr, total_epoch, warmup_length, ft_length, ft_lr): 7 | self.optimizer = optimizer 8 | self.init_lr = init_lr 9 | self.total_epoch = total_epoch 10 | self.warmup_length = warmup_length 11 | self.ft_length = ft_length 12 | self.ft_lr = ft_lr 13 | self.step_before_train = True 14 | 15 | def step(self, epoch=None): 16 | assert epoch is not None 17 | 18 | if epoch < self.warmup_length: 19 | lr = _warmup_lr(self.init_lr, self.warmup_length, epoch) 20 | elif self.warmup_length <= epoch < self.total_epoch - self.ft_length: 21 | e = epoch - self.warmup_length 22 | es = self.total_epoch - self.warmup_length - self.ft_length 23 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * self.init_lr 24 | elif self.total_epoch - self.ft_length <= epoch < self.total_epoch: 25 | lr = self.ft_lr 26 | else: 27 | lr = None 28 | 29 | _assign_learning_rate(self.optimizer, lr) 30 | 31 | def state_dict(self): 32 | return {} 33 | 34 | def load_state_dict(self, dic): 35 | return 36 | 37 | def _assign_learning_rate(optimizer, new_lr=None): 38 | if new_lr is not None: 39 | for param_group in optimizer.param_groups: 40 | param_group["lr"] = new_lr 41 | else: 42 | pass 43 | 44 | def _warmup_lr(base_lr, warmup_length, epoch): 45 | return base_lr * (epoch + 1) / warmup_length 46 | 47 | -------------------------------------------------------------------------------- /utils/seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | def set_random_seed(seed): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | torch.cuda.manual_seed(seed) 10 | torch.cuda.manual_seed_all(seed) 11 | torch.backends.cudnn.deterministic = True 12 | torch.backends.cudnn.benchmark = False 13 | -------------------------------------------------------------------------------- /utils/subset_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import randperm 3 | from torch._utils import _accumulate 4 | from torch.utils.data import Dataset, Subset 5 | 6 | # From https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py 7 | def random_split(dataset, lengths, generator): 8 | if sum(lengths) != len(dataset): 9 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 10 | 11 | indices = randperm(sum(lengths), generator=generator).tolist() 12 | return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)] 13 | 14 | class SubsetDataset(Dataset): 15 | def __init__(self, subset, transform=None): 16 | self.subset = subset 17 | self.transform = transform 18 | 19 | def __getitem__(self, idx): 20 | x, y = self.subset[idx] 21 | if self.transform is not None: 22 | x = self.transform(x) 23 | return x, y 24 | 25 | def __len__(self): 26 | return len(self.subset) 27 | 28 | -------------------------------------------------------------------------------- /utils/sync_jobs.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from os.path import join 4 | import argparse 5 | import datetime 6 | import time 7 | import pathlib 8 | 
from subprocess import Popen 9 | 10 | class JobManager(object): 11 | interval = 60 * 5 12 | eps = 20 13 | 14 | def __init__(self, dir_path, job_id, sync_script_path=None): 15 | self.job_id = job_id 16 | self.dir_path = os.path.abspath(dir_path) 17 | if sync_script_path is not None: 18 | self.sync_script_path = os.path.abspath(sync_script_path) 19 | 20 | self.ts_dir_path = join(self.dir_path, 'timestamps') 21 | if not os.path.exists(self.ts_dir_path): 22 | try: 23 | os.makedirs(self.ts_dir_path) 24 | except Exception as e: 25 | print('[JobManager] Caught Exception:', e.args) 26 | self.my_ts_path = join(self.ts_dir_path, self.job_id) 27 | self.my_process = None 28 | 29 | def clear(self): 30 | current_ts = self._get_current_ts() 31 | 32 | ts_files = os.listdir(self.ts_dir_path) 33 | for f_name in ts_files: 34 | f_path = join(self.ts_dir_path, f_name) 35 | f = pathlib.Path(f_path) 36 | if not self._check_alive(f_path, current_ts): 37 | os.remove(f_path) 38 | 39 | def start(self): 40 | assert self.my_process is None 41 | self.my_process = Popen(['python3', self.sync_script_path, self.dir_path, self.job_id]) 42 | return True 43 | 44 | def stop(self): 45 | assert self.my_process is not None 46 | self.my_process.terminate() 47 | self.my_process = None 48 | 49 | def check_alive(self, job_id): 50 | current_ts = self._get_current_ts() 51 | return self._check_alive(join(self.ts_dir_path, job_id), current_ts) 52 | 53 | def _check_alive(self, path, current_ts): 54 | f = pathlib.Path(path) 55 | try: 56 | print("[DEBUG]") 57 | print(current_ts - f.stat().st_mtime) 58 | return current_ts - f.stat().st_mtime <= JobManager.interval + JobManager.eps 59 | except FileNotFoundError as e: 60 | return False 61 | 62 | def _get_current_ts(self): 63 | test_path = join(self.ts_dir_path, 'test') 64 | with open(test_path, 'w') as f: 65 | f.write('') 66 | return pathlib.Path(test_path).stat().st_mtime 67 | 68 | 69 | def main(dir_path, job_id): 70 | jman = JobManager(dir_path, job_id, None) 71 | print("[Sync] start: job_id=", job_id) 72 | while True: 73 | with open(jman.my_ts_path, 'w') as f: 74 | f.write("") 75 | time.sleep(JobManager.interval) 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('dir_path', type=str) 81 | parser.add_argument('job_id', type=str) 82 | 83 | args = parser.parse_args() 84 | main(args.dir_path, args.job_id) 85 | -------------------------------------------------------------------------------- /utils/test_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import json 4 | import argparse 5 | 6 | filter_keys = [ 7 | 'model.config_name', 8 | 'conv_sparsity', 9 | 'rerand_lambda', 10 | 'lr', 11 | 'seed', 12 | ] 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('dir', type=str) 16 | parser.add_argument('--epoch', type=str, required=True) 17 | for key in filter_keys: 18 | parser.add_argument('--' + key, type=str, default=None) 19 | args = parser.parse_args() 20 | args_dic = vars(args) 21 | 22 | def filter(info): 23 | ret = False 24 | for key in filter_keys: 25 | if args_dic[key] is None: 26 | pass 27 | elif key + "_" + args_dic[key] + "--" in info['prefix']: 28 | pass 29 | else: 30 | ret = True 31 | return ret 32 | 33 | dir = args.dir 34 | infos = [] 35 | test_prefix = f'test_epoch{args.epoch}' 36 | for fn in os.listdir(dir): 37 | full_fn = os.path.join(dir, fn) 38 | if not os.path.isfile(full_fn): 39 | continue 40 | if fn[:len(test_prefix)] == 
test_prefix: 41 | with open(full_fn, 'r') as f: 42 | infostr = f.read() 43 | infos.append(json.loads(infostr)) 44 | 45 | accs = [] 46 | last_infos = sorted(infos, key=lambda x: -x['accuracy']) 47 | print("") 48 | print("Sorted by test accs:") 49 | for info in last_infos: 50 | if filter(info): 51 | continue 52 | print("%.4f (%d)\t%s" % (info['accuracy'], info['epoch'], info['prefix'])) 53 | accs.append(info['accuracy']) 54 | mean = sum(accs) / len(accs) 55 | print("Mean:", mean) 56 | 57 | --------------------------------------------------------------------------------
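To tie the pieces above together, the following is a minimal sketch of how the sparse modules are exercised: it instantiates SparseLinear directly, runs a forward pass through the score-selected (pruned) weight, and then applies rerandomize in 'bernoulli' mode, as ResNet.rerandomize does for every sparse layer. The cfg dictionary lists only the keys that SparseLinear reads; the concrete values (0.5 sparsity, signed_constant init, rerand_rate 1.0, la 0.1) are illustrative assumptions rather than the defaults in config.yaml, and the script assumes it is run from the repository root.

import torch
from models.networks.sparse_modules import SparseLinear

# Minimal config containing only the keys SparseLinear.__init__ looks up.
cfg = {
    'linear_sparsity': 0.5,         # prune the lowest-scored 50% of weights
    'conv_sparsity': 0.5,           # fallback; unused here because linear_sparsity is set
    'init_mode_linear': None,       # fall back to 'init_mode' for the weights
    'init_mode': 'signed_constant',
    'init_mode_mask': 'kaiming_uniform',
    'init_scale': 1.0,
    'init_scale_score': 1.0,
    'rerand_rate': 1.0,
}

layer = SparseLinear(16, 4, bias=True, cfg=cfg)
x = torch.randn(8, 16)

out = layer(x)                      # forward pass through weight * subnet mask
print(out.shape)                    # torch.Size([8, 4])
print(layer.weight_score.is_score)  # True: only scores are optimized in 'score_only' mode

# Re-randomize: weights outside the kept subnetwork are replaced with freshly
# initialized values with probability la (mu is unused in 'bernoulli' mode).
layer.rerandomize('bernoulli', la=0.1, mu=None)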