├── LICENSE
├── README.md
├── config_task.py
├── decathlon_mean_std.pickle
├── download_data.sh
├── imdbfolder_coco.py
├── labels_test.zip
├── matconvnet
│   ├── README.md
│   ├── cnn_cifar.m
│   ├── cnn_resnet_preact_new_conv1.m
│   └── cnn_resnet_preact_reduce_stride.m
├── models.py
├── sgd.py
├── train_new_task_adapters.py
├── train_new_task_finetuning.py
├── train_new_task_from_scratch.py
└── utils_pytorch.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Parametric families of deep neural networks with residual adapters [PyTorch + MatConvNet] 2 | 3 | Backbone code for the papers: 4 | - NIPS 2017: "Learning multiple visual domains with residual adapters", https://papers.nips.cc/paper/6654-learning-multiple-visual-domains-with-residual-adapters.pdf 5 | - CVPR 2018: "Efficient parametrization of multi-domain deep neural networks", https://arxiv.org/pdf/1803.10082.pdf 6 | 7 | Page of our associated **Visual Domain Decathlon challenge** for multi-domain classification: http://www.robots.ox.ac.uk/~vgg/decathlon/ 8 | 9 | ## Abstract 10 | 11 | A practical limitation of deep neural networks is their high degree of specialization to a single task and visual domain. 12 | To overcome this limitation, in these papers we propose to consider instead universal parametric families of neural 13 | networks, which still contain specialized problem-specific models, but which differ only by a small number of parameters. 14 | We study different designs for such parametrizations, including 15 | series and parallel residual adapters. We show that, in order to maximize performance, it is necessary 16 | to adapt both shallow and deep layers of a deep network, 17 | but the required changes are very small. We also show that 18 | these universal parametrizations are very effective for transfer 19 | learning, where they outperform traditional fine-tuning 20 | techniques. 21 | 22 | ## Code 23 | 24 | ### Requirements 25 | - PyTorch 26 | - or MatConvNet with MATLAB 27 | 28 | ### Launching the code 29 | First download the data with ``download_data.sh /path/to/save/data/``, then copy ``decathlon_mean_std.pickle`` to the data folder. 30 | 31 | To train a dataset from scratch: 32 | 33 | ``CUDA_VISIBLE_DEVICES=2 python train_new_task_from_scratch.py --dataset cifar100 --wd3x3 1. --wd 5. --mode bn`` 34 | 35 | To train a dataset with parallel adapters added to a pretrained 'off-the-shelf' deep network: 36 | 37 | ``CUDA_VISIBLE_DEVICES=2 python train_new_task_adapters.py --dataset cifar100 --wd1x1 1. --wd 5. --mode parallel_adapters --source /path/to/net`` 38 | 39 | To train a dataset with series adapters added to a pretrained deep network (with adapters in it during pretraining): 40 | 41 | ``CUDA_VISIBLE_DEVICES=2 python train_new_task_adapters.py --dataset cifar100 --wd1x1 1. --wd 5. --mode series_adapters --source /path/to/net`` 42 | 43 | To train a dataset with series adapters added to a pretrained 'off-the-shelf' deep network: 44 | 45 | ``CUDA_VISIBLE_DEVICES=2 python train_new_task_adapters.py --dataset cifar100 --wd1x1 1. --wd 5. --mode series_adapters --source /path/to/net`` 46 | 47 | To train a dataset with standard fine-tuning from a pretrained deep network: 48 | 49 | ``CUDA_VISIBLE_DEVICES=2 python train_new_task_finetuning.py --dataset cifar100 --wd 5. --mode bn --source /path/to/net`` 50 | 51 | ### Pretrained networks 52 | We pretrained networks on ImageNet (with reduced resolution): 53 | - a ResNet-26 inspired by the original ResNet of [He, 16]: https://drive.google.com/open?id=1y7gz_9KfjY8O4Ue3yHE7SpwA90Ua1mbR 54 | - the same network with series adapters already in it: https://drive.google.com/open?id=1f1eBQY6eHm616SAt0UXxY9RldNM9XAHb 55 | 56 | ### Results of the commands above with the pretrained networks 57 | We train on CIFAR 100 and evaluate on the validation split: 58 | 59 | | | Val.
Acc. | 60 | | :------------ | :-------------: | 61 | | Scratch | 75.23 | 62 | | Parallel adapters | 80.61 | 63 | | Series adapters | 80.17 | 64 | | Series adapters (off the shelf) | 70.72 | 65 | | Normal finetuning | 78.40 | 66 | 67 | ## If you want to cite us 68 | 69 | For the Visual Domain Decathlon challenge and the series adapters: 70 | 71 | 72 | @inproceedings{Rebuffi17, 73 | author = "Rebuffi, S-A and Bilen, H. and Vedaldi, A.", 74 | title = "Learning multiple visual domains with residual adapters", 75 | booktitle = "Advances in Neural Information Processing Systems", 76 | year = "2017", 77 | } 78 | 79 | 80 | For the parallel adapters: 81 | 82 | 83 | @inproceedings{rebuffi-cvpr2018, 84 | author = {Sylvestre-Alvise Rebuffi and Hakan Bilen and Andrea Vedaldi}, 85 | title = {Efficient parametrization of multi-domain deep neural networks}, 86 | booktitle = {CVPR}, 87 | year = {2018}, 88 | } 89 | 90 | -------------------------------------------------------------------------------- /config_task.py: -------------------------------------------------------------------------------- 1 | # Sets which task is currently used for training/testing 2 | task = 0 3 | mode = 'normal' 4 | proj = '11' 5 | factor = 1. 6 | -------------------------------------------------------------------------------- /decathlon_mean_std.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srebuffi/residual_adapters/5c51aa3ff842e8d77f12dc7215750fe6adc0d0f7/decathlon_mean_std.pickle -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$1" 4 | 5 | wget http://www.robots.ox.ac.uk/~vgg/share/decathlon-1.0-devkit.tar.gz 6 | tar -xzvf decathlon-1.0-devkit.tar.gz 7 | rm decathlon-1.0-devkit.tar.gz 8 | 9 | cd decathlon-1.0 10 | cd data 11 | wget http://www.robots.ox.ac.uk/~vgg/share/decathlon-1.0-data.tar.gz 12 | tar -xzvf decathlon-1.0-data.tar.gz 13 | rm decathlon-1.0-data.tar.gz 14 | for filename in *.tar 15 | do 16 | tar -xvf "$filename" 17 | rm "$filename" 18 | done 19 | -------------------------------------------------------------------------------- /imdbfolder_coco.py: -------------------------------------------------------------------------------- 1 | # imdbfolder_coco.py 2 | # created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk] 3 | # Copyright © The University of Oxford, 2017-2020 4 | # This code is made available under the Apache v2.0 licence, see LICENSE.txt for details 5 | 6 | import torch.utils.data as data 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import torch 10 | import numpy as np 11 | import pickle 12 | import config_task 13 | from PIL import Image 14 | from pycocotools.coco import COCO 15 | import os 16 | import os.path 17 | 18 | IMG_EXTENSIONS = [ 19 | '.jpg', '.JPG', '.jpeg', '.JPEG', 20 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 21 | ] 22 | 23 | 24 | def pil_loader(path): 25 | return Image.open(path).convert('RGB') 26 | 27 | 28 | class ImageFolder(data.Dataset): 29 | 30 | def __init__(self, root, transform=None, target_transform=None, index=None, 31 | labels=None, imgs=None, loader=pil_loader, skip_label_indexing=0): 32 | 33 | if len(imgs) == 0: 34 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 35 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 36 | 37 | self.root = root 38 | if
index is not None: 39 | imgs = [imgs[i] for i in index] 40 | self.imgs = imgs 41 | if index is not None: 42 | if skip_label_indexing == 0: 43 | labels = [labels[i] for i in index] 44 | self.labels = labels 45 | self.transform = transform 46 | self.target_transform = target_transform 47 | self.loader = loader 48 | 49 | def __getitem__(self, index): 50 | path = self.imgs[index][0] 51 | target = self.labels[index] 52 | img = self.loader(path) 53 | if self.transform is not None: 54 | img = self.transform(img) 55 | if self.target_transform is not None: 56 | target = self.target_transform(target) 57 | 58 | return img, target 59 | 60 | def __len__(self): 61 | return len(self.imgs) 62 | 63 | 64 | def prepare_data_loaders(dataset_names, data_dir, imdb_dir, shuffle_train=True, index=None): 65 | train_loaders = [] 66 | val_loaders = [] 67 | num_classes = [] 68 | train = [0] 69 | val = [1] 70 | config_task.offset = [] 71 | 72 | imdb_names_train = [imdb_dir + '/' + dataset_names[i] + '_train.json' for i in range(len(dataset_names))] 73 | imdb_names_val = [imdb_dir + '/' + dataset_names[i] + '_val.json' for i in range(len(dataset_names))] 74 | imdb_names = [imdb_names_train, imdb_names_val] 75 | 76 | with open(data_dir + 'decathlon_mean_std.pickle', 'rb') as handle: 77 | dict_mean_std = pickle.load(handle) 78 | 79 | for i in range(len(dataset_names)): 80 | imgnames_train = [] 81 | imgnames_val = [] 82 | labels_train = [] 83 | labels_val = [] 84 | for itera1 in train+val: 85 | annFile = imdb_names[itera1][i] 86 | coco = COCO(annFile) 87 | imgIds = coco.getImgIds() 88 | annIds = coco.getAnnIds(imgIds=imgIds) 89 | anno = coco.loadAnns(annIds) 90 | images = coco.loadImgs(imgIds) 91 | timgnames = [img['file_name'] for img in images] 92 | timgnames_id = [img['id'] for img in images] 93 | labels = [int(ann['category_id'])-1 for ann in anno] 94 | min_lab = min(labels) 95 | labels = [lab - min_lab for lab in labels] 96 | max_lab = max(labels) 97 | 98 | imgnames = [] 99 | for j in range(len(timgnames)): 100 | imgnames.append((data_dir + '/' + timgnames[j],timgnames_id[j])) 101 | 102 | if itera1 in train: 103 | imgnames_train += imgnames 104 | labels_train += labels 105 | if itera1 in val: 106 | imgnames_val += imgnames 107 | labels_val += labels 108 | 109 | num_classes.append(int(max_lab+1)) 110 | config_task.offset.append(min_lab) 111 | means = dict_mean_std[dataset_names[i] + 'mean'] 112 | stds = dict_mean_std[dataset_names[i] + 'std'] 113 | 114 | 115 | if dataset_names[i] in ['gtsrb', 'omniglot','svhn']: # no horz flip 116 | transform_train = transforms.Compose([ 117 | transforms.Resize(72), 118 | transforms.CenterCrop(72), 119 | transforms.ToTensor(), 120 | transforms.Normalize(means, stds), 121 | ]) 122 | else: 123 | transform_train = transforms.Compose([ 124 | transforms.Resize(72), 125 | transforms.RandomCrop(64), 126 | transforms.RandomHorizontalFlip(), 127 | transforms.ToTensor(), 128 | transforms.Normalize(means, stds), 129 | ]) 130 | if dataset_names[i] in ['gtsrb', 'omniglot','svhn']: # no horz flip 131 | transform_test = transforms.Compose([ 132 | transforms.Resize(72), 133 | transforms.CenterCrop(72), 134 | transforms.ToTensor(), 135 | transforms.Normalize(means, stds), 136 | ]) 137 | else: 138 | transform_test = transforms.Compose([ 139 | transforms.Resize(72), 140 | transforms.CenterCrop(64), 141 | transforms.ToTensor(), 142 | transforms.Normalize(means, stds), 143 | ]) 144 | 145 | img_path = data_dir 146 | trainloader = torch.utils.data.DataLoader(ImageFolder(data_dir, transform_train, None, 
index, labels_train, imgnames_train), batch_size=128, shuffle=shuffle_train, num_workers=4, pin_memory=True) 147 | valloader = torch.utils.data.DataLoader(ImageFolder(data_dir, transform_test, None, None, labels_val, imgnames_val), batch_size=100, shuffle=False, num_workers=4, pin_memory=True) 148 | train_loaders.append(trainloader) 149 | val_loaders.append(valloader) 150 | 151 | return train_loaders, val_loaders, num_classes 152 | 153 | -------------------------------------------------------------------------------- /labels_test.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srebuffi/residual_adapters/5c51aa3ff842e8d77f12dc7215750fe6adc0d0f7/labels_test.zip -------------------------------------------------------------------------------- /matconvnet/README.md: -------------------------------------------------------------------------------- 1 | ## Backbone code for the parallel residual adapters with MatConvNet 2 | 3 | The ``cnn_cifar.m`` script adds parallel residual adapters to a ResNet-50 network pretrained on ImageNet and trains it on CIFAR-10/100. To adapt the architecture to the CIFAR input size, there are two options for ``opts.modelType``: 'new_conv1', which replaces the original conv1 by a 3x3 conv1 layer (trained together with the adapters), or 'reduce_stride', which preserves and freezes the original conv1 parameters but reduces its stride to 1. In both cases, the first maxpool layer is deleted. 4 | -------------------------------------------------------------------------------- /matconvnet/cnn_cifar.m: -------------------------------------------------------------------------------- 1 | function [net, info] = cnn_cifar(varargin) 2 | 3 | %Demonstrates ResNet (with preactivation) on: 4 | %CIFAR-10 and CIFAR-100 (tested for depth 164) 5 | % cnn_cifar('modelType', 'new_conv1','GPU', 4, 'batchSize', 128,'momentum', 0.95, 'weightDecay', 0.0001, 'Nclass', 10, 'learningRate', [0.1*ones(1,80) 0.01*ones(1,10) 0.001*ones(1,30)]) 6 | 7 | run(fullfile(fileparts(mfilename('fullpath')),'matconvnet-1.0-beta25','matlab', 'vl_setupnn.m')) ; 8 | 9 | opts.modelType = 'new_conv1' ; 10 | opts.GPU=[]; 11 | opts.batchSize=128; 12 | opts.weightDecay=0.0001; 13 | opts.momentum=0.9; 14 | opts.Nclass=10; 15 | opts.filterDepths = []; 16 | opts.learningRate = [0.01*ones(1,3) 0.1*ones(1,80) 0.01*ones(1,10) 0.001*ones(1,20)] ; 17 | [opts, varargin] = vl_argparse(opts, varargin) ; 18 | 19 | datas='cifar'; 20 | opts.expDir = sprintf('/scratch/shared/nfs1/srebuffi/MCN/%s_%d-%s', datas, opts.Nclass, opts.modelType); 21 | 22 | [opts, varargin] = vl_argparse(opts, varargin) ; 23 | 24 | opts.dataDir = fullfile(vl_rootnn, 'data', datas) ; 25 | opts.imdbPath = fullfile(opts.expDir, 'imdb.mat'); 26 | opts.whitenData = false ; 27 | opts.contrastNormalization = false ; 28 | opts.networkType = 'dagnn' ; 29 | opts.train = struct() ; 30 | opts = vl_argparse(opts, varargin) ; 31 | if ~isfield(opts.train, 'gpus'), opts.train.gpus = [opts.GPU]; end; 32 | 33 | % ------------------------------------------------------------------------- 34 | % Prepare model and data 35 | % ------------------------------------------------------------------------- 36 | 37 | switch opts.modelType 38 | case 'reduce_stride' 39 | net = cnn_resnet_preact_reduce_stride('Nclass', opts.Nclass); 40 | case 'new_conv1' 41 | net = cnn_resnet_preact_new_conv1('Nclass', opts.Nclass); 42 | otherwise 43 | error('Unknown model type ''%s''.', opts.modelType) ; 44 | end 45 |
46 | net.meta.trainOpts.learningRate=opts.learningRate; %update lr 47 | net.meta.trainOpts.batchSize = opts.batchSize; %batch size 48 | net.meta.trainOpts.weightDecay = opts.weightDecay; %weight decay 49 | net.meta.trainOpts.momentum = opts.momentum ; 50 | net.meta.trainOpts.numEpochs = numel(net.meta.trainOpts.learningRate); %update num. ep. 51 | 52 | if exist(opts.imdbPath, 'file') 53 | imdb = load(opts.imdbPath) ; 54 | else 55 | if opts.Nclass==10 && strcmp(datas,'cifar') 56 | imdb = getCifar10Imdb(opts) ; 57 | mkdir(opts.expDir) ; 58 | save(opts.imdbPath, '-struct', 'imdb') ; 59 | else 60 | imdb = getCifar100Imdb(opts) ; 61 | mkdir(opts.expDir) ; 62 | save(opts.imdbPath, '-struct', 'imdb') ; 63 | end 64 | end 65 | 66 | net.meta.classes.name = imdb.meta.classes(:)' ; 67 | 68 | % ------------------------------------------------------------------------- 69 | % Train 70 | % ------------------------------------------------------------------------- 71 | 72 | [net, info] = cnn_train_dag(net, imdb, getBatch(opts), ... 73 | 'expDir', opts.expDir, ... 74 | net.meta.trainOpts, ... 75 | opts.train, ... 76 | 'val', find(imdb.images.set == 3)) ; 77 | 78 | % ------------------------------------------------------------------------- 79 | function fn = getBatch(opts) 80 | % ------------------------------------------------------------------------- 81 | switch lower(opts.networkType) 82 | case 'simplenn' 83 | error('The simplenn structure is not supported for the ResNet architecture'); 84 | case 'dagnn' 85 | bopts = struct('numGpus', numel(opts.train.gpus)) ; 86 | fn = @(x,y) getDagNNBatch(bopts,x,y) ; 87 | end 88 | 89 | % ------------------------------------------------------------------------- 90 | function inputs = getDagNNBatch(opts, imdb, batch) 91 | % ------------------------------------------------------------------------- 92 | images = imdb.images.data(:,:,:,batch) ; 93 | labels = imdb.images.labels(1,batch) ; 94 | if rand > 0.5, images=fliplr(images) ; end 95 | images=cropRand(images) ; %random crop for all samples 96 | if opts.numGpus > 0 97 | images = gpuArray(images) ; 98 | end 99 | inputs = {'data', images, 'label', labels} ; 100 | 101 | % ------------------------------------------------------------------------- 102 | function imdb = getCifar10Imdb(opts) 103 | % ------------------------------------------------------------------------- 104 | unpackPath = fullfile(opts.dataDir, 'cifar-10-batches-mat'); 105 | files = [arrayfun(@(n) sprintf('data_batch_%d.mat', n), 1:5, 'UniformOutput', false) ... 
106 | {'test_batch.mat'}]; 107 | files = cellfun(@(fn) fullfile(unpackPath, fn), files, 'UniformOutput', false); 108 | file_set = uint8([ones(1, 5), 3]); 109 | 110 | if any(cellfun(@(fn) ~exist(fn, 'file'), files)) 111 | url = 'http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz' ; 112 | fprintf('downloading %s\n', url) ; 113 | untar(url, opts.dataDir) ; 114 | end 115 | 116 | data = cell(1, numel(files)); 117 | labels = cell(1, numel(files)); 118 | sets = cell(1, numel(files)); 119 | for fi = 1:numel(files) 120 | fd = load(files{fi}) ; 121 | data{fi} = permute(reshape(fd.data',32,32,3,[]),[2 1 3 4]) ; 122 | labels{fi} = fd.labels' + 1; % Index from 1 123 | sets{fi} = repmat(file_set(fi), size(labels{fi})); 124 | end 125 | 126 | set = cat(2, sets{:}); 127 | data = single(cat(4, data{:})); 128 | 129 | %pad the images to crop later 130 | data = padarray(data,[4,4],128,'both'); 131 | 132 | %remove mean (channels 1/2/3 are R/G/B) 133 | r = data(:,:,1,set == 1); 134 | g = data(:,:,2,set == 1); 135 | b = data(:,:,3,set == 1); 136 | meanCifar = [mean(r(:)), mean(g(:)), mean(b(:))]; 137 | data = bsxfun(@minus, data, reshape(meanCifar,1,1,3)); 138 | 139 | %divide by std 140 | stdCifar = [std(r(:)), std(g(:)), std(b(:))]; 141 | data = bsxfun(@times, data,reshape(1./stdCifar,1,1,3)) ; 142 | 143 | clNames = load(fullfile(unpackPath, 'batches.meta.mat')); 144 | 145 | imdb.images.data = data ; 146 | imdb.images.labels = single(cat(2, labels{:})) ; 147 | imdb.images.set = set; 148 | imdb.meta.sets = {'train', 'val', 'test'} ; 149 | imdb.meta.classes = clNames.label_names; 150 | 151 | % ------------------------------------------------------------------------- 152 | function imdb = getCifar100Imdb(opts) 153 | % ------------------------------------------------------------------------- 154 | unpackPath = fullfile(opts.dataDir, 'cifar-100-matlab'); 155 | files{1} = fullfile(unpackPath, 'train.mat'); 156 | files{2} = fullfile(unpackPath, 'test.mat'); 157 | %files{3} = fullfile(unpackPath, 'meta.mat'); 158 | file_set = uint8([1, 3]); 159 | 160 | if any(cellfun(@(fn) ~exist(fn, 'file'), files)) 161 | url = 'http://www.cs.toronto.edu/~kriz/cifar-100-matlab.tar.gz' ; 162 | fprintf('downloading %s\n', url) ; 163 | untar(url, opts.dataDir) ; 164 | end 165 | 166 | data = cell(1, numel(files)); 167 | labels = cell(1, numel(files)); 168 | sets = cell(1, numel(files)); 169 | for fi = 1:numel(files) 170 | fd = load(files{fi}) ; 171 | data{fi} = permute(reshape(fd.data',32,32,3,[]),[2 1 3 4]) ; 172 | labels{fi} = fd.fine_labels' + 1; % Index from 1 173 | sets{fi} = repmat(file_set(fi), size(labels{fi})); 174 | end 175 | 176 | set = cat(2, sets{:}); 177 | data = single(cat(4, data{:})); 178 | 179 | %pad the images to crop later 180 | data = padarray(data,[4,4],128,'both'); 181 | 182 | % remove mean (channels 1/2/3 are R/G/B) 183 | r = data(:,:,1,set == 1); 184 | g = data(:,:,2,set == 1); 185 | b = data(:,:,3,set == 1); 186 | meanCifar = [mean(r(:)), mean(g(:)), mean(b(:))]; 187 | data = bsxfun(@minus, data, reshape(meanCifar,1,1,3)); 188 | 189 | %divide by std 190 | stdCifar = [std(r(:)), std(g(:)), std(b(:))]; 191 | data = bsxfun(@times, data,reshape(1./stdCifar,1,1,3)) ; 192 | 193 | clNames = load(fullfile(unpackPath, 'meta.mat')); 194 | 195 | imdb.images.data = data ; 196 | imdb.images.labels = single(cat(2, labels{:})) ; 197 | imdb.images.set = set; 198 | imdb.meta.sets = {'train', 'val', 'test'} ; 199 | imdb.meta.classes = clNames.fine_label_names; --------------------------------------------------------------------------------
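Before the two MATLAB model builders that follow, a minimal PyTorch sketch of the construction they both perform may help readers following only the Python side: every pretrained 3x3 convolution is frozen and a trainable 1x1 convolution over the same input is summed with its output (the dagnn.Conv '_bis' and dagnn.Sum layers below). This sketch is illustrative only; the class name ParallelAdapter is not part of the repository.

import torch.nn as nn

class ParallelAdapter(nn.Module):
    # Illustrative sketch: a frozen pretrained 3x3 conv plus a trainable
    # 1x1 conv applied to the same input; the two outputs are summed,
    # mirroring the dagnn.Sum layers added by the scripts below.
    def __init__(self, conv3x3):
        super(ParallelAdapter, self).__init__()
        self.conv = conv3x3
        self.conv.weight.requires_grad = False  # freeze pretrained filters
        self.adapter = nn.Conv2d(conv3x3.in_channels, conv3x3.out_channels,
                                 kernel_size=1, stride=conv3x3.stride,
                                 padding=0, bias=False)  # trainable adapter

    def forward(self, x):
        return self.conv(x) + self.adapter(x)

Wrapping each 3x3 convolution of a pretrained ResNet with such a module, while keeping batch-norm layers trainable, reproduces the parallel configuration studied in the CVPR 2018 paper.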
/matconvnet/cnn_resnet_preact_new_conv1.m: -------------------------------------------------------------------------------- 1 | function net = cnn_resnet_preact_new_conv1(varargin) 2 | 3 | opts.Nclass=10; %number of classses (CIFAR-10 / CIFAR-100) 4 | opts = vl_argparse(opts, varargin) ; 5 | 6 | net = load('/scratch/shared/nfs1/srebuffi/MCN/imagenet-resnet-50-dag.mat') ; 7 | net = dagnn.DagNN.loadobj(net) ; 8 | 9 | % Freeze part of the network 10 | for l = 1:numel(net.layers) 11 | if ~isa(net.layers(l).block, 'dagnn.BatchNorm') & ~isempty(net.layers(l).params) 12 | for l1 = 1:numel(net.layers(l).params) 13 | k = net.getParamIndex(net.layers(l).params{l1}) ; 14 | net.params(k).learningRate = 0.0 ; 15 | end 16 | end 17 | end 18 | 19 | % Create a new conv1 20 | playerName = 'conv1'; 21 | playerIndex = net.getLayerIndex(playerName) ; 22 | player = net.layers(playerIndex) ; 23 | net.renameLayer(playerName, 'tmp') ; 24 | net.addLayer(playerName, ... 25 | dagnn.Conv('size', [3 3 3 64], ... 26 | 'stride', 1, .... 27 | 'pad', 1, ... 28 | 'hasBias', false), ... 29 | player.inputs, ... 30 | player.outputs, ... 31 | ['conv1' '_f']) ; 32 | net.removeLayer('tmp'); 33 | 34 | % Delete the first maxpool layer 35 | net.removeLayer('fc1000'); 36 | net.removeLayer('prob'); 37 | net.removeLayer('pool5'); 38 | layer = net.layers(net.getLayerIndex('pool1')) ; 39 | net.removeLayer('pool1') ; 40 | net.renameVar(layer.outputs{1}, layer.inputs{1}, 'quiet', true) ; 41 | 42 | 43 | params = net.layers(end).block.initParams() ; 44 | params = cellfun(@gather, params, 'UniformOutput', false) ; 45 | p = net.getParamIndex(net.layers(end).params) ; 46 | [net.params(p).value] = deal(params{:}) ; 47 | 48 | 49 | % Add the parallel residual adapters 50 | for l = 1:numel(net.layers) 51 | if isa(net.layers(l).block, 'dagnn.Conv') 52 | if net.layers(l).block.size(1) == 3 53 | net.layers(l).block 54 | net.addLayer([net.layers(l).name '_bis'], ... 55 | dagnn.Conv('size', [1 1 net.layers(l).block.size(3) net.layers(l).block.size(4)], ... 56 | 'stride', net.layers(l).block.stride(1), .... 57 | 'pad', 0, ... 58 | 'hasBias', false), ... 59 | net.layers(l).inputs, ... 60 | {[net.layers(l).name '_bis']}, ... 61 | [net.layers(l).name '_bis' '_f']) ; 62 | 63 | % Initialize params 64 | params = net.layers(end).block.initParams() ; 65 | params = cellfun(@gather, params, 'UniformOutput', false) ; 66 | p = net.getParamIndex(net.layers(end).params) ; 67 | [net.params(p).value] = deal(params{:}) ; 68 | 69 | net.addLayer([net.layers(l).name '_sum'] , ... 70 | dagnn.Sum(), ... 71 | {[net.layers(l).name '_bis'],net.layers(l).name}, ... 72 | [net.layers(l).name '_sum']) ; 73 | layers = {} ; 74 | for l2 = 1:numel(net.layers) 75 | if strcmp(net.layers(l2).inputs{1}, net.layers(l).name) 76 | % net.renameVar(net.layers(l2).inputs{1}, [net.layers(l).name '_sum'], 'quiet', true) ; 77 | net.setLayerInputs(net.layers(l2).name,{[net.layers(l).name '_sum']}); 78 | end 79 | end 80 | end 81 | end 82 | end 83 | 84 | net.addLayer('prediction_avg' , ... 85 | dagnn.Pooling('poolSize', [4 4], 'method', 'avg'), ... 86 | 'res5cx', ... 87 | 'prediction_avg') ; 88 | 89 | net.addLayer('prediction' , ... 90 | dagnn.Conv('size', [1 1 2048 opts.Nclass]), ... 91 | 'prediction_avg', ... 92 | 'prediction', ... 
93 | {'prediction_f', 'prediction_b'}) ; 94 | params = net.layers(end).block.initParams() ; 95 | params = cellfun(@gather, params, 'UniformOutput', false) ; 96 | p = net.getParamIndex(net.layers(end).params) ; 97 | [net.params(p).value] = deal(params{:}) ; 98 | %Modification from Andrea (similar to imagenet) 99 | f = net.getParamIndex(net.layers(end).params(1)) ; 100 | net.params(f).value = net.params(f).value /10; 101 | 102 | net.addLayer('loss', ... 103 | dagnn.Loss('loss', 'softmaxlog') ,... 104 | {'prediction', 'label'}, ... 105 | 'objective') ; 106 | 107 | net.addLayer('top1error', ... 108 | dagnn.Loss('loss', 'classerror'), ... 109 | {'prediction', 'label'}, ... 110 | 'error') ; 111 | 112 | 113 | %Meta parameters 114 | net.meta.inputSize = [32 32 3] ; 115 | net.meta.trainOpts.learningRate = [0.01*ones(1,2) 0.1*ones(1,80) 0.01*ones(1,40) 0.001*ones(1,40)] ; 116 | net.meta.trainOpts.weightDecay = 0.0001 ; 117 | net.meta.trainOpts.batchSize = 128 ; 118 | net.meta.trainOpts.momentum = 0.9 ; 119 | net.meta.trainOpts.numEpochs = numel(net.meta.trainOpts.learningRate) ; 120 | 121 | end 122 | -------------------------------------------------------------------------------- /matconvnet/cnn_resnet_preact_reduce_stride.m: -------------------------------------------------------------------------------- 1 | function net = cnn_resnet_preact_reduce_stride(varargin) 2 | 3 | opts.Nclass=10; %number of classses (CIFAR-10 / CIFAR-100) 4 | opts = vl_argparse(opts, varargin) ; 5 | 6 | net = load('/scratch/shared/nfs1/srebuffi/MCN/imagenet-resnet-50-dag.mat') ; 7 | net = dagnn.DagNN.loadobj(net) ; 8 | 9 | % Reduce the stride of conv1 10 | net.layers(1).block.stride = [1 1]; 11 | 12 | % Freeze layers of the network which are not BN layers 13 | for l = 1:numel(net.layers) 14 | if ~isa(net.layers(l).block, 'dagnn.BatchNorm') & ~isempty(net.layers(l).params) 15 | for l1 = 1:numel(net.layers(l).params) 16 | k = net.getParamIndex(net.layers(l).params{l1}) ; 17 | net.params(k).learningRate = 0.0 ; 18 | end 19 | end 20 | end 21 | 22 | % Delete the first maxpool layer 23 | net.removeLayer('fc1000'); 24 | net.removeLayer('prob'); 25 | net.removeLayer('pool5'); 26 | layer = net.layers(net.getLayerIndex('pool1')) ; 27 | net.removeLayer('pool1') ; 28 | net.renameVar(layer.outputs{1}, layer.inputs{1}, 'quiet', true) ; 29 | 30 | 31 | params = net.layers(end).block.initParams() ; 32 | params = cellfun(@gather, params, 'UniformOutput', false) ; 33 | p = net.getParamIndex(net.layers(end).params) ; 34 | [net.params(p).value] = deal(params{:}) ; 35 | 36 | % Add the parallel residual adapters 37 | for l = 1:numel(net.layers) 38 | if isa(net.layers(l).block, 'dagnn.Conv') 39 | if net.layers(l).block.size(1) == 3 40 | net.layers(l).block 41 | net.addLayer([net.layers(l).name '_bis'], ... 42 | dagnn.Conv('size', [1 1 net.layers(l).block.size(3) net.layers(l).block.size(4)], ... 43 | 'stride', net.layers(l).block.stride(1), .... 44 | 'pad', 0, ... 45 | 'hasBias', false), ... 46 | net.layers(l).inputs, ... 47 | {[net.layers(l).name '_bis']}, ... 48 | [net.layers(l).name '_bis' '_f']) ; 49 | 50 | % Initialize params 51 | params = net.layers(end).block.initParams() ; 52 | params = cellfun(@gather, params, 'UniformOutput', false) ; 53 | p = net.getParamIndex(net.layers(end).params) ; 54 | [net.params(p).value] = deal(params{:}) ; 55 | 56 | net.addLayer([net.layers(l).name '_sum'] , ... 57 | dagnn.Sum(), ... 58 | {[net.layers(l).name '_bis'],net.layers(l).name}, ... 
59 | [net.layers(l).name '_sum']) ; 60 | layers = {} ; 61 | for l2 = 1:numel(net.layers) 62 | if strcmp(net.layers(l2).inputs{1}, net.layers(l).name) 63 | % net.renameVar(net.layers(l2).inputs{1}, [net.layers(l).name '_sum'], 'quiet', true) ; 64 | net.setLayerInputs(net.layers(l2).name,{[net.layers(l).name '_sum']}); 65 | end 66 | end 67 | end 68 | end 69 | end 70 | 71 | net.addLayer('prediction_avg' , ... 72 | dagnn.Pooling('poolSize', [4 4], 'method', 'avg'), ... 73 | 'res5cx', ... 74 | 'prediction_avg') ; 75 | 76 | net.addLayer('prediction' , ... 77 | dagnn.Conv('size', [1 1 2048 opts.Nclass]), ... 78 | 'prediction_avg', ... 79 | 'prediction', ... 80 | {'prediction_f', 'prediction_b'}) ; 81 | params = net.layers(end).block.initParams() ; 82 | params = cellfun(@gather, params, 'UniformOutput', false) ; 83 | p = net.getParamIndex(net.layers(end).params) ; 84 | [net.params(p).value] = deal(params{:}) ; 85 | %Modification from Andrea (similar to imagenet) 86 | f = net.getParamIndex(net.layers(end).params(1)) ; 87 | net.params(f).value = net.params(f).value /10; 88 | 89 | net.addLayer('loss', ... 90 | dagnn.Loss('loss', 'softmaxlog') ,... 91 | {'prediction', 'label'}, ... 92 | 'objective') ; 93 | 94 | net.addLayer('top1error', ... 95 | dagnn.Loss('loss', 'classerror'), ... 96 | {'prediction', 'label'}, ... 97 | 'error') ; 98 | 99 | 100 | %Meta parameters 101 | net.meta.inputSize = [32 32 3] ; 102 | net.meta.trainOpts.learningRate = [0.01*ones(1,2) 0.1*ones(1,80) 0.01*ones(1,40) 0.001*ones(1,40)] ; 103 | net.meta.trainOpts.weightDecay = 0.0001 ; 104 | net.meta.trainOpts.batchSize = 128 ; 105 | net.meta.trainOpts.momentum = 0.9 ; 106 | net.meta.trainOpts.numEpochs = numel(net.meta.trainOpts.learningRate) ; 107 | 108 | end 109 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # models.py 2 | # created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk] 3 | # Copyright © The University of Oxford, 2017-2020 4 | # This code is made available under the Apache v2.0 licence, see LICENSE.txt for details 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | from torch.nn.parameter import Parameter 12 | import config_task 13 | import math 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | def conv1x1_fonc(in_planes, out_planes=None, stride=1, bias=False): 20 | if out_planes is None: 21 | return nn.Conv2d(in_planes, in_planes, kernel_size=1, stride=stride, padding=0, bias=bias) 22 | else: 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias) 24 | 25 | class conv1x1(nn.Module): 26 | 27 | def __init__(self, planes, out_planes=None, stride=1): 28 | super(conv1x1, self).__init__() 29 | if config_task.mode == 'series_adapters': 30 | self.conv = nn.Sequential(nn.BatchNorm2d(planes), conv1x1_fonc(planes)) 31 | elif config_task.mode == 'parallel_adapters': 32 | self.conv = conv1x1_fonc(planes, out_planes, stride) 33 | else: 34 | self.conv = conv1x1_fonc(planes) 35 | def forward(self, x): 36 | y = self.conv(x) 37 | if config_task.mode == 'series_adapters': 38 | y += x 39 | return y 40 | 41 | class conv_task(nn.Module): 42 | 43 | def __init__(self, in_planes, planes, stride=1, nb_tasks=1, is_proj=1, second=0): 44 | super(conv_task, self).__init__() 
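        # conv_task pairs one shared 3x3 convolution with per-task adaptation
        # modules. Depending on config_task.mode, the constructor below attaches
        # per-task (1x1 conv + BN) blocks after the shared conv ('series_adapters'),
        # per-task 1x1 convs on the block input in parallel plus per-task BNs
        # ('parallel_adapters'), or plain per-task BatchNorm layers only (default
        # 'bn' mode). config_task.task selects which task's branch runs in forward().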
45 | self.is_proj = is_proj 46 | self.second = second 47 | self.conv = conv3x3(in_planes, planes, stride) 48 | if config_task.mode == 'series_adapters' and is_proj: 49 | self.bns = nn.ModuleList([nn.Sequential(conv1x1(planes), nn.BatchNorm2d(planes)) for i in range(nb_tasks)]) 50 | elif config_task.mode == 'parallel_adapters' and is_proj: 51 | self.parallel_conv = nn.ModuleList([conv1x1(in_planes, planes, stride) for i in range(nb_tasks)]) 52 | self.bns = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(nb_tasks)]) 53 | else: 54 | self.bns = nn.ModuleList([nn.BatchNorm2d(planes) for i in range(nb_tasks)]) 55 | 56 | def forward(self, x): 57 | task = config_task.task 58 | y = self.conv(x) 59 | if self.second == 0: 60 | if config_task.isdropout1: 61 | x = F.dropout2d(x, p=0.5, training = self.training) 62 | else: 63 | if config_task.isdropout2: 64 | x = F.dropout2d(x, p=0.5, training = self.training) 65 | if config_task.mode == 'parallel_adapters' and self.is_proj: 66 | y = y + self.parallel_conv[task](x) 67 | y = self.bns[task](y) 68 | 69 | return y 70 | 71 | # No projection: identity shortcut 72 | class BasicBlock(nn.Module): 73 | expansion = 1 74 | 75 | def __init__(self, in_planes, planes, stride=1, shortcut=0, nb_tasks=1): 76 | super(BasicBlock, self).__init__() 77 | self.conv1 = conv_task(in_planes, planes, stride, nb_tasks, is_proj=int(config_task.proj[0])) 78 | self.conv2 = nn.Sequential(nn.ReLU(True), conv_task(planes, planes, 1, nb_tasks, is_proj=int(config_task.proj[1]), second=1)) 79 | self.shortcut = shortcut 80 | if self.shortcut == 1: 81 | self.avgpool = nn.AvgPool2d(2) 82 | 83 | def forward(self, x): 84 | residual = x 85 | y = self.conv1(x) 86 | y = self.conv2(y) 87 | if self.shortcut == 1: 88 | residual = self.avgpool(x) 89 | residual = torch.cat((residual, residual*0),1) 90 | y += residual 91 | y = F.relu(y) 92 | return y 93 | 94 | 95 | class ResNet(nn.Module): 96 | def __init__(self, block, nblocks, num_classes=[10]): 97 | super(ResNet, self).__init__() 98 | nb_tasks = len(num_classes) 99 | blocks = [block, block, block] 100 | factor = config_task.factor 101 | self.in_planes = int(32*factor) 102 | self.pre_layers_conv = conv_task(3,int(32*factor), 1, nb_tasks) 103 | self.layer1 = self._make_layer(blocks[0], int(64*factor), nblocks[0], stride=2, nb_tasks=nb_tasks) 104 | self.layer2 = self._make_layer(blocks[1], int(128*factor), nblocks[1], stride=2, nb_tasks=nb_tasks) 105 | self.layer3 = self._make_layer(blocks[2], int(256*factor), nblocks[2], stride=2, nb_tasks=nb_tasks) 106 | self.end_bns = nn.ModuleList([nn.Sequential(nn.BatchNorm2d(int(256*factor)),nn.ReLU(True)) for i in range(nb_tasks)]) 107 | self.avgpool = nn.AdaptiveAvgPool2d(1) 108 | self.linears = nn.ModuleList([nn.Linear(int(256*factor), num_classes[i]) for i in range(nb_tasks)]) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, block, planes, nblocks, stride=1, nb_tasks=1): 119 | shortcut = 0 120 | if stride != 1 or self.in_planes != planes * block.expansion: 121 | shortcut = 1 122 | layers = [] 123 | layers.append(block(self.in_planes, planes, stride, shortcut, nb_tasks=nb_tasks)) 124 | self.in_planes = planes * block.expansion 125 | for i in range(1, nblocks): 126 | layers.append(block(self.in_planes, planes, nb_tasks=nb_tasks)) 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.pre_layers_conv(x) 131 | task = config_task.task 132 | x = self.layer1(x) 133 | x = self.layer2(x) 134 | x = self.layer3(x) 135 | x = self.end_bns[task](x) 136 | x = self.avgpool(x) 137 | x = x.view(x.size(0), -1) 138 | x = self.linears[task](x) 139 | return x 140 | 141 | 142 | def resnet26(num_classes=10, blocks=BasicBlock): 143 | return ResNet(blocks, [4,4,4],num_classes) 144 | 145 | 146 | -------------------------------------------------------------------------------- /sgd.py: -------------------------------------------------------------------------------- 1 | # sgd.py 2 | # created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk] 3 | # Copyright © The University of Oxford, 2017-2020 4 | # This code is made available under the Apache v2.0 licence, see LICENSE.txt for details 5 | 6 | import torch 7 | import math 8 | import torch.nn.functional as F 9 | import config_task 10 | 11 | class SGD(torch.optim.Optimizer): 12 | def __init__(self, params, lr=0.1, momentum=0, dampening=0, 13 | weight_decay=0, nesterov=False): 14 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 15 | weight_decay=weight_decay, nesterov=nesterov) 16 | if nesterov and (momentum <= 0 or dampening != 0): 17 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 18 | super(SGD, self).__init__(params, defaults) 19 | 20 | def __setstate__(self, state): 21 | super(SGD, self).__setstate__(state) 22 | for group in self.param_groups: 23 | group.setdefault('nesterov', False) 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | for group in self.param_groups: 31 | weight_decay = group['weight_decay'] 32 | momentum = group['momentum'] 33 | dampening = group['dampening'] 34 | nesterov = group['nesterov'] 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | d_p = p.grad.data 40 | param_state = self.state[p] 41 | siz = p.grad.size() 42 | if len(siz) > 3: 43 | if siz[2] == 3: 44 | weight_decay = config_task.decay3x3[config_task.task] 45 | elif siz[2] == 1: 46 | weight_decay = config_task.decay1x1[config_task.task] 47 | if weight_decay != 0: 48 | d_p.add_(weight_decay, p.data) 49 | if momentum != 0: 50 | if 'momentum_buffer' not in param_state: 51 | buf = param_state['momentum_buffer'] = d_p.clone() 52 | else: 53 | buf = param_state['momentum_buffer'] 54 | buf.mul_(momentum).add_(1 - dampening, d_p) 55 | if nesterov: 56 | d_p = d_p.add(momentum, buf) 57 | else: 58 | d_p = buf 59 | 60 | p.data.add_(-group['lr'], d_p) 61 | 62 | return loss 63 | 64 | -------------------------------------------------------------------------------- /train_new_task_adapters.py: -------------------------------------------------------------------------------- 1 | # train_new_task_adapters.py 2 | # created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk] 3 | # Copyright © The University of Oxford, 
2017-2020 4 | # This code is made available under the Apache v2.0 licence, see LICENSE.txt for details 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | import models 14 | import os 15 | import time 16 | import argparse 17 | import numpy as np 18 | 19 | from torch.autograd import Variable 20 | 21 | import imdbfolder_coco as imdbfolder 22 | import config_task 23 | import utils_pytorch 24 | import sgd 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch Residual Adapters training') 27 | parser.add_argument('--dataset', default='cifar100', nargs='+', help='Task(s) to be trained') 28 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') 29 | parser.add_argument('--wd', default=1., type=float, help='weight decay for the classification layer') 30 | parser.add_argument('--wd3x3', default=1., type=float, nargs='+', help='weight decay for the 3x3') 31 | parser.add_argument('--wd1x1', default=1., type=float, nargs='+', help='weight decay for the 1x1') 32 | parser.add_argument('--nb_epochs', default=120, type=int, help='nb epochs') 33 | parser.add_argument('--step1', default=80, type=int, help='nb epochs before first lr decrease') 34 | parser.add_argument('--step2', default=100, type=int, help='nb epochs before second lr decrease') 35 | parser.add_argument('--mode', default='parallel_adapters', type=str, help='Task adaptation mode') 36 | parser.add_argument('--proj', default='11', type=str, help='Position of the adaptation module') 37 | parser.add_argument('--dropout', default='00', type=str, help='Position of dropouts') 38 | parser.add_argument('--expdir', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/tmp/', help='Save folder') 39 | parser.add_argument('--datadir', default='/scratch/local/ramdisk/srebuffi/decathlon/', help='folder containing data folder') 40 | parser.add_argument('--imdbdir', default='/scratch/local/ramdisk/srebuffi/decathlon/annotations/', help='annotation folder') 41 | parser.add_argument('--source', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/C100_alone/checkpoint/ckptpost11bnresidual11cifar1000.000180607060.t7', type=str, help='Network source') 42 | parser.add_argument('--seed', default=0, type=int, help='seed') 43 | parser.add_argument('--factor', default='1.', type=float, help='Width factor of the network') 44 | args = parser.parse_args() 45 | args.archi ='default' 46 | config_task.mode = args.mode 47 | config_task.proj = args.proj 48 | config_task.factor = args.factor 49 | args.use_cuda = torch.cuda.is_available() 50 | if type(args.dataset) is str: 51 | args.dataset = [args.dataset] 52 | 53 | if type(args.wd3x3) is float: 54 | args.wd3x3 = [args.wd3x3] 55 | 56 | if type(args.wd1x1) is float: 57 | args.wd1x1 = [args.wd1x1] 58 | 59 | if not os.path.isdir(args.expdir): 60 | os.mkdir(args.expdir) 61 | 62 | config_task.decay3x3 = np.array(args.wd3x3) * 0.0001 63 | config_task.decay1x1 = np.array(args.wd1x1) * 0.0001 64 | args.wd = args.wd * 0.0001 65 | 66 | args.ckpdir = args.expdir + '/checkpoint/' 67 | args.svdir = args.expdir + '/results/' 68 | 69 | if not os.path.isdir(args.ckpdir): 70 | os.mkdir(args.ckpdir) 71 | 72 | if not os.path.isdir(args.svdir): 73 | os.mkdir(args.svdir) 74 | 75 | config_task.isdropout1 = (args.dropout[0] == '1') 76 | config_task.isdropout2 = (args.dropout[1] == '1') 77 | 78 | 
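# Note: the per-task decays set above (config_task.decay3x3 / decay1x1, scaled
# by 1e-4) are consumed by the custom optimizer in sgd.py, which picks the decay
# for each parameter from its kernel size (3x3 vs 1x1 convolutions); args.wd
# applies to the remaining parameters such as the classification layer.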
##################################### 79 | 80 | # Prepare data loaders 81 | train_loaders, val_loaders, num_classes = imdbfolder.prepare_data_loaders(args.dataset,args.datadir,args.imdbdir,True) 82 | args.num_classes = num_classes 83 | 84 | # Load checkpoint and initialize the networks with the weights of a pretrained network 85 | print('==> Resuming from checkpoint..') 86 | checkpoint = torch.load(args.source) 87 | net_old = checkpoint['net'] 88 | net = models.resnet26(num_classes) 89 | store_data = [] 90 | for name, m in net_old.named_modules(): 91 | if isinstance(m, nn.Conv2d) and (m.kernel_size[0]==3): 92 | store_data.append(m.weight.data) 93 | 94 | element = 0 95 | for name, m in net.named_modules(): 96 | if isinstance(m, nn.Conv2d) and (m.kernel_size[0]==3): 97 | m.weight.data = store_data[element] 98 | element += 1 99 | 100 | store_data = [] 101 | store_data_bias = [] 102 | store_data_rm = [] 103 | store_data_rv = [] 104 | names = [] 105 | 106 | for name, m in net_old.named_modules(): 107 | if isinstance(m, nn.BatchNorm2d) and 'bns.' in name: 108 | names.append(name) 109 | store_data.append(m.weight.data) 110 | store_data_bias.append(m.bias.data) 111 | store_data_rm.append(m.running_mean) 112 | store_data_rv.append(m.running_var) 113 | 114 | # Special case to copy the weight for the BN layers when the target and source networks have not the same number of BNs 115 | import re 116 | condition_bn = 'noproblem' 117 | if len(names) != 51 and args.mode == 'series_adapters': 118 | condition_bn ='bns.....conv' 119 | 120 | for id_task in range(len(num_classes)): 121 | element = 0 122 | for name, m in net.named_modules(): 123 | if isinstance(m, nn.BatchNorm2d) and 'bns.'+str(id_task) in name and not re.search(condition_bn,name): 124 | m.weight.data = store_data[element].clone() 125 | m.bias.data = store_data_bias[element].clone() 126 | m.running_var = store_data_rv[element].clone() 127 | m.running_mean = store_data_rm[element].clone() 128 | element += 1 129 | 130 | #net.linears[0].weight.data = net_old.linears[0].weight.data 131 | #net.linears[0].bias.data = net_old.linears[0].bias.data 132 | 133 | del net_old 134 | 135 | start_epoch = 0 136 | best_acc = 0 # best test accuracy 137 | results = np.zeros((4,start_epoch+args.nb_epochs,len(args.num_classes))) 138 | all_tasks = range(len(args.dataset)) 139 | np.random.seed(1993) 140 | 141 | if args.use_cuda: 142 | net.cuda() 143 | cudnn.benchmark = True 144 | 145 | 146 | # Freeze 3*3 convolution layers 147 | for name, m in net.named_modules(): 148 | if isinstance(m, nn.Conv2d) and (m.kernel_size[0]==3): 149 | m.weight.requires_grad = False 150 | 151 | 152 | args.criterion = nn.CrossEntropyLoss() 153 | optimizer = sgd.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.wd) 154 | 155 | 156 | print("Start training") 157 | for epoch in range(start_epoch, start_epoch+args.nb_epochs): 158 | training_tasks = utils_pytorch.adjust_learning_rate_and_learning_taks(optimizer, epoch, args) 159 | st_time = time.time() 160 | 161 | # Training and validation 162 | train_acc, train_loss = utils_pytorch.train(epoch, train_loaders, training_tasks, net, args, optimizer) 163 | test_acc, test_loss, best_acc = utils_pytorch.test(epoch,val_loaders, all_tasks, net, best_acc, args, optimizer) 164 | 165 | # Record statistics 166 | for i in range(len(training_tasks)): 167 | current_task = training_tasks[i] 168 | results[0:2,epoch,current_task] = [train_loss[i],train_acc[i]] 169 | for i in all_tasks: 170 | results[2:4,epoch,i] = 
[test_loss[i],test_acc[i]] 171 | np.save(args.svdir+'/results_'+'adapt'+str(args.seed)+args.dropout+args.mode+args.proj+''.join(args.dataset)+'wd3x3_'+str(args.wd3x3)+'_wd1x1_'+str(args.wd1x1)+str(args.wd)+str(args.nb_epochs)+str(args.step1)+str(args.step2),results) 172 | print('Epoch lasted {0}'.format(time.time()-st_time)) 173 | 174 | -------------------------------------------------------------------------------- /train_new_task_finetuning.py: -------------------------------------------------------------------------------- 1 | # train_new_task_finetuning.py 2 | # created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk] 3 | # Copyright © The University of Oxford, 2017-2020 4 | # This code is made available under the Apache v2.0 licence, see LICENSE.txt for details 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | import models 14 | import os 15 | import time 16 | import argparse 17 | import numpy as np 18 | 19 | from torch.autograd import Variable 20 | 21 | import imdbfolder_coco as imdbfolder 22 | import config_task 23 | import utils_pytorch 24 | import sgd 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch Residual Adapters training') 27 | parser.add_argument('--dataset', default='cifar100', nargs='+', help='Task(s) to be trained') 28 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') 29 | parser.add_argument('--wd', default=1., type=float, help='weight decay for the classification layer') 30 | parser.add_argument('--wd3x3', default=1., type=float, nargs='+', help='weight decay for the 3x3') 31 | parser.add_argument('--wd1x1', default=1., type=float, nargs='+', help='weight decay for the 1x1') 32 | parser.add_argument('--nb_epochs', default=120, type=int, help='nb epochs') 33 | parser.add_argument('--step1', default=80, type=int, help='nb epochs before first lr decrease') 34 | parser.add_argument('--step2', default=100, type=int, help='nb epochs before second lr decrease') 35 | parser.add_argument('--mode', default='parallel_adapters', type=str, help='Task adaptation mode') 36 | parser.add_argument('--proj', default='11', type=str, help='Position of the adaptation module') 37 | parser.add_argument('--dropout', default='00', type=str, help='Position of dropouts') 38 | parser.add_argument('--expdir', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/tmp/', help='Save folder') 39 | parser.add_argument('--datadir', default='/scratch/local/ramdisk/srebuffi/decathlon/', help='folder containing data folder') 40 | parser.add_argument('--imdbdir', default='/scratch/local/ramdisk/srebuffi/decathlon/annotations/', help='annotation folder') 41 | parser.add_argument('--source', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/C100_alone/checkpoint/ckptpost11bnresidual11cifar1000.000180607060.t7', type=str, help='Network source') 42 | parser.add_argument('--seed', default=0, type=int, help='seed') 43 | parser.add_argument('--factor', default='1.', type=float, help='Width factor of the network') 44 | args = parser.parse_args() 45 | args.archi ='default' 46 | config_task.mode = args.mode 47 | config_task.proj = args.proj 48 | config_task.factor = args.factor 49 | args.use_cuda = torch.cuda.is_available() 50 | if type(args.dataset) is str: 51 | args.dataset = [args.dataset] 52 | 53 | if type(args.wd3x3) is float: 54 | args.wd3x3 = [args.wd3x3] 55 | 56 | if 
--------------------------------------------------------------------------------
/train_new_task_finetuning.py:
--------------------------------------------------------------------------------
# train_new_task_finetuning.py
# created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk]
# Copyright © The University of Oxford, 2017-2020
# This code is made available under the Apache v2.0 licence, see LICENSE for details

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import models
import os
import time
import argparse
import numpy as np

import imdbfolder_coco as imdbfolder
import config_task
import utils_pytorch
import sgd

parser = argparse.ArgumentParser(description='PyTorch Residual Adapters training')
parser.add_argument('--dataset', default='cifar100', nargs='+', help='Task(s) to be trained')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--wd', default=1., type=float, help='weight decay for the classification layer (in units of 1e-4)')
parser.add_argument('--wd3x3', default=1., type=float, nargs='+', help='weight decay for the 3x3 kernels (in units of 1e-4)')
parser.add_argument('--wd1x1', default=1., type=float, nargs='+', help='weight decay for the 1x1 kernels (in units of 1e-4)')
parser.add_argument('--nb_epochs', default=120, type=int, help='number of epochs')
parser.add_argument('--step1', default=80, type=int, help='number of epochs before first lr decrease')
parser.add_argument('--step2', default=100, type=int, help='number of epochs before second lr decrease')
parser.add_argument('--mode', default='parallel_adapters', type=str, help='Task adaptation mode')
parser.add_argument('--proj', default='11', type=str, help='Position of the adaptation module')
parser.add_argument('--dropout', default='00', type=str, help='Position of dropouts')
parser.add_argument('--expdir', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/tmp/', help='Save folder')
parser.add_argument('--datadir', default='/scratch/local/ramdisk/srebuffi/decathlon/', help='folder containing data folder')
parser.add_argument('--imdbdir', default='/scratch/local/ramdisk/srebuffi/decathlon/annotations/', help='annotation folder')
parser.add_argument('--source', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/C100_alone/checkpoint/ckptpost11bnresidual11cifar1000.000180607060.t7', type=str, help='Network source')
parser.add_argument('--seed', default=0, type=int, help='seed')
parser.add_argument('--factor', default=1., type=float, help='Width factor of the network')
args = parser.parse_args()
args.archi = 'default'
config_task.mode = args.mode
config_task.proj = args.proj
config_task.factor = args.factor
args.use_cuda = torch.cuda.is_available()

# argparse yields a scalar for a single value and a list with nargs='+';
# normalise everything to lists.
if isinstance(args.dataset, str):
    args.dataset = [args.dataset]

if isinstance(args.wd3x3, float):
    args.wd3x3 = [args.wd3x3]

if isinstance(args.wd1x1, float):
    args.wd1x1 = [args.wd1x1]

if not os.path.isdir(args.expdir):
    os.mkdir(args.expdir)

# Weight decays are specified in units of 1e-4.
config_task.decay3x3 = np.array(args.wd3x3) * 0.0001
config_task.decay1x1 = np.array(args.wd1x1) * 0.0001
args.wd = args.wd * 0.0001

args.ckpdir = args.expdir + '/checkpoint/'
args.svdir = args.expdir + '/results/'

if not os.path.isdir(args.ckpdir):
    os.mkdir(args.ckpdir)

if not os.path.isdir(args.svdir):
    os.mkdir(args.svdir)

config_task.isdropout1 = (args.dropout[0] == '1')
config_task.isdropout2 = (args.dropout[1] == '1')

#####################################

# Prepare data loaders
train_loaders, val_loaders, num_classes = imdbfolder.prepare_data_loaders(args.dataset, args.datadir, args.imdbdir, True)
args.num_classes = num_classes

# Load checkpoint and initialize the network with the weights of a pretrained network
print('==> Resuming from checkpoint..')
checkpoint = torch.load(args.source)
net_old = checkpoint['net']  # the checkpoint pickles the full network object
net = models.resnet26(num_classes)

# Copy the shared 3x3 convolution weights from the pretrained network,
# relying on both networks enumerating these modules in the same order.
store_data = []
for name, m in net_old.named_modules():
    if isinstance(m, nn.Conv2d) and (m.kernel_size[0] == 3):
        store_data.append(m.weight.data)

element = 0
for name, m in net.named_modules():
    if isinstance(m, nn.Conv2d) and (m.kernel_size[0] == 3):
        m.weight.data = store_data[element]
        element += 1

# Replicate the pretrained batch-norm parameters and running statistics
# (branch 'bns.0') into every task-specific branch 'bns.<id_task>'.
store_data = []
store_data_bias = []
store_data_rm = []
store_data_rv = []
for name, m in net_old.named_modules():
    if isinstance(m, nn.BatchNorm2d) and 'bns.0' in name:
        store_data.append(m.weight.data)
        store_data_bias.append(m.bias.data)
        store_data_rm.append(m.running_mean)
        store_data_rv.append(m.running_var)

for id_task in range(len(num_classes)):
    element = 0
    for name, m in net.named_modules():
        if isinstance(m, nn.BatchNorm2d) and 'bns.' + str(id_task) in name:
            m.weight.data = store_data[element].clone()
            m.bias.data = store_data_bias[element].clone()
            m.running_var = store_data_rv[element].clone()
            m.running_mean = store_data_rm[element].clone()
            element += 1

#net.linears[0].weight.data = net_old.linears[0].weight.data
#net.linears[0].bias.data = net_old.linears[0].bias.data

del net_old

start_epoch = 0
best_acc = 0  # best test accuracy, summed over tasks
results = np.zeros((4, start_epoch + args.nb_epochs, len(args.num_classes)))
all_tasks = range(len(args.dataset))
np.random.seed(1993)

if args.use_cuda:
    net.cuda()
    cudnn.benchmark = True

args.criterion = nn.CrossEntropyLoss()
optimizer = sgd.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.wd)

print("Start training")
for epoch in range(start_epoch, start_epoch + args.nb_epochs):
    training_tasks = utils_pytorch.adjust_learning_rate_and_learning_taks(optimizer, epoch, args)
    st_time = time.time()

    # Training and validation
    train_acc, train_loss = utils_pytorch.train(epoch, train_loaders, training_tasks, net, args, optimizer)
    test_acc, test_loss, best_acc = utils_pytorch.test(epoch, val_loaders, all_tasks, net, best_acc, args, optimizer)

    # Record statistics: rows 0-1 hold train loss/acc, rows 2-3 test loss/acc
    for i in range(len(training_tasks)):
        current_task = training_tasks[i]
        results[0:2, epoch, current_task] = [train_loss[i], train_acc[i]]
    for i in all_tasks:
        results[2:4, epoch, i] = [test_loss[i], test_acc[i]]
    np.save(args.svdir+'/results_'+'adapt'+str(args.seed)+args.dropout+args.mode+args.proj+''.join(args.dataset)+'wd3x3_'+str(args.wd3x3)+'_wd1x1_'+str(args.wd1x1)+str(args.wd)+str(args.nb_epochs)+str(args.step1)+str(args.step2), results)
    print('Epoch lasted {0}'.format(time.time() - st_time))

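The checkpoint surgery above copies every shared 3x3 convolution from the pretrained network into the fresh multi-task model, then replicates the source's 'bns.0' batch-norm parameters and running statistics into each task-specific branch. A generic sketch of the same pattern, assuming only the naming convention used above and that both networks traverse these modules in the same order (note that the substring test 'bns.' + str(k) in name, kept from the original, would also match 'bns.10' and beyond for more than ten tasks):

import torch.nn as nn

def transplant_shared_weights(src, dst, num_tasks):
    # Shared 3x3 convolutions: copied once, in module-traversal order.
    convs = [m.weight.data for _, m in src.named_modules()
             if isinstance(m, nn.Conv2d) and m.kernel_size[0] == 3]
    it = iter(convs)
    for _, m in dst.named_modules():
        if isinstance(m, nn.Conv2d) and m.kernel_size[0] == 3:
            m.weight.data.copy_(next(it))
    # Task-0 batch norms of the source, replicated into every task branch.
    bns = [(m.weight.data, m.bias.data, m.running_mean, m.running_var)
           for name, m in src.named_modules()
           if isinstance(m, nn.BatchNorm2d) and 'bns.0' in name]
    for k in range(num_tasks):
        it = iter(bns)
        for name, m in dst.named_modules():
            if isinstance(m, nn.BatchNorm2d) and 'bns.' + str(k) in name:
                w, b, rm, rv = next(it)
                m.weight.data.copy_(w)
                m.bias.data.copy_(b)
                m.running_mean.copy_(rm)
                m.running_var.copy_(rv)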
--------------------------------------------------------------------------------
/train_new_task_from_scratch.py:
--------------------------------------------------------------------------------
# train_new_task_from_scratch.py
# created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk]
# Copyright © The University of Oxford, 2017-2020
# This code is made available under the Apache v2.0 licence, see LICENSE for details

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import models
import os
import time
import argparse
import numpy as np

import imdbfolder_coco as imdbfolder
import config_task
import utils_pytorch
import sgd

parser = argparse.ArgumentParser(description='PyTorch Residual Adapters training')
parser.add_argument('--dataset', default='cifar100', nargs='+', help='Task(s) to be trained')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--wd', default=1., type=float, help='weight decay for the classification layer (in units of 1e-4)')
parser.add_argument('--wd3x3', default=1., type=float, nargs='+', help='weight decay for the 3x3 kernels (in units of 1e-4)')
parser.add_argument('--wd1x1', default=1., type=float, nargs='+', help='weight decay for the 1x1 kernels (in units of 1e-4)')
parser.add_argument('--nb_epochs', default=120, type=int, help='number of epochs')
parser.add_argument('--step1', default=80, type=int, help='number of epochs before first lr decrease')
parser.add_argument('--step2', default=100, type=int, help='number of epochs before second lr decrease')
parser.add_argument('--mode', default='parallel_adapters', type=str, help='Task adaptation mode')
parser.add_argument('--proj', default='11', type=str, help='Position of the adaptation module')
parser.add_argument('--dropout', default='00', type=str, help='Position of dropouts')
parser.add_argument('--expdir', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/tmp/', help='Save folder')
parser.add_argument('--datadir', default='/scratch/local/ramdisk/srebuffi/decathlon/', help='folder containing data folder')
parser.add_argument('--imdbdir', default='/scratch/local/ramdisk/srebuffi/decathlon/annotations/', help='annotation folder')
parser.add_argument('--source', default='/scratch/shared/nfs1/srebuffi/exp/dem_learning/C100_alone/checkpoint/ckptpost11bnresidual11cifar1000.000180607060.t7', type=str, help='Network source (unused in this script)')
parser.add_argument('--seed', default=0, type=int, help='seed')
parser.add_argument('--factor', default=1., type=float, help='Width factor of the network')
args = parser.parse_args()
args.archi = 'default'
config_task.mode = args.mode
config_task.proj = args.proj
config_task.factor = args.factor
args.use_cuda = torch.cuda.is_available()

# argparse yields a scalar for a single value and a list with nargs='+';
# normalise everything to lists.
if isinstance(args.dataset, str):
    args.dataset = [args.dataset]

if isinstance(args.wd3x3, float):
    args.wd3x3 = [args.wd3x3]

if isinstance(args.wd1x1, float):
    args.wd1x1 = [args.wd1x1]

if not os.path.isdir(args.expdir):
    os.mkdir(args.expdir)

# Weight decays are specified in units of 1e-4.
config_task.decay3x3 = np.array(args.wd3x3) * 0.0001
config_task.decay1x1 = np.array(args.wd1x1) * 0.0001
args.wd = args.wd * 0.0001

args.ckpdir = args.expdir + '/checkpoint/'
args.svdir = args.expdir + '/results/'

if not os.path.isdir(args.ckpdir):
    os.mkdir(args.ckpdir)

if not os.path.isdir(args.svdir):
    os.mkdir(args.svdir)

config_task.isdropout1 = (args.dropout[0] == '1')
config_task.isdropout2 = (args.dropout[1] == '1')

#####################################

# Prepare data loaders
train_loaders, val_loaders, num_classes = imdbfolder.prepare_data_loaders(args.dataset, args.datadir, args.imdbdir, True)
args.num_classes = num_classes

# Create the network (randomly initialised: no pretrained checkpoint is loaded here)
net = models.resnet26(num_classes)

start_epoch = 0
best_acc = 0  # best test accuracy, summed over tasks
results = np.zeros((4, start_epoch + args.nb_epochs, len(args.num_classes)))
all_tasks = range(len(args.dataset))
np.random.seed(1993)

if args.use_cuda:
    net.cuda()
    cudnn.benchmark = True

args.criterion = nn.CrossEntropyLoss()
optimizer = sgd.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.wd)

print("Start training")
for epoch in range(start_epoch, start_epoch + args.nb_epochs):
    training_tasks = utils_pytorch.adjust_learning_rate_and_learning_taks(optimizer, epoch, args)
    st_time = time.time()

    # Training and validation
    train_acc, train_loss = utils_pytorch.train(epoch, train_loaders, training_tasks, net, args, optimizer)
    test_acc, test_loss, best_acc = utils_pytorch.test(epoch, val_loaders, all_tasks, net, best_acc, args, optimizer)

    # Record statistics: rows 0-1 hold train loss/acc, rows 2-3 test loss/acc
    for i in range(len(training_tasks)):
        current_task = training_tasks[i]
        results[0:2, epoch, current_task] = [train_loss[i], train_acc[i]]
    for i in all_tasks:
        results[2:4, epoch, i] = [test_loss[i], test_acc[i]]
    np.save(args.svdir+'/results_'+'adapt'+str(args.seed)+args.dropout+args.mode+args.proj+''.join(args.dataset)+'wd3x3_'+str(args.wd3x3)+'_wd1x1_'+str(args.wd1x1)+str(args.wd)+str(args.nb_epochs)+str(args.step1)+str(args.step2), results)
    print('Epoch lasted {0}'.format(time.time() - st_time))

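This script is the no-transfer baseline: the pipeline is identical to the finetuning script, but the network starts from random initialisation. Note that all three decay flags are given in units of 1e-4: the scaled --wd goes straight into the optimizer, while the per-kernel-size decays are stashed in config_task, presumably consumed by the custom sgd.py (not shown in this section). A quick sketch of the arithmetic with the defaults:

import numpy as np

wd, wd3x3, wd1x1 = 1., [1.], [1.]    # the argparse defaults
print(wd * 0.0001)                   # 0.0001 -> sgd.SGD(weight_decay=...)
print(np.array(wd3x3) * 0.0001)      # [0.0001] -> config_task.decay3x3
print(np.array(wd1x1) * 0.0001)      # [0.0001] -> config_task.decay1x1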
--------------------------------------------------------------------------------
/utils_pytorch.py:
--------------------------------------------------------------------------------
# utils_pytorch.py
# created by Sylvestre-Alvise Rebuffi [srebuffi@robots.ox.ac.uk]
# Copyright © The University of Oxford, 2017-2020
# This code is made available under the Apache v2.0 licence, see LICENSE for details

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import os
import time
import numpy as np
import config_task

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate_and_learning_taks(optimizer, epoch, args):
    """Decay the initial LR by 10x at args.step1 and by 100x at args.step2,
    and return the indices of the tasks to train this epoch."""
    if epoch >= args.step2:
        lr = args.lr * 0.01
    elif epoch >= args.step1:
        lr = args.lr * 0.1
    else:
        lr = args.lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Return training tasks
    return range(len(args.dataset))

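With the defaults (lr=0.1, step1=80, step2=100, nb_epochs=120) the schedule above gives three plateaus: 0.1 until epoch 79, 0.01 until epoch 99, then 0.001 for the rest of training. A self-contained check of that behaviour (a sketch mirroring the function above, using tolerance-based comparisons because the scaled values are not exact binary floats):

def lr_at(epoch, lr=0.1, step1=80, step2=100):
    # Mirrors adjust_learning_rate_and_learning_taks above.
    if epoch >= step2:
        return lr * 0.01
    if epoch >= step1:
        return lr * 0.1
    return lr

assert lr_at(0) == 0.1 and lr_at(79) == 0.1
assert abs(lr_at(80) - 0.01) < 1e-12 and abs(lr_at(99) - 0.01) < 1e-12
assert abs(lr_at(100) - 0.001) < 1e-12 and abs(lr_at(119) - 0.001) < 1e-12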
# Training
def train(epoch, tloaders, tasks, net, args, optimizer, list_criterion=None):
    print('\nEpoch: %d' % epoch)
    net.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = [AverageMeter() for i in tasks]
    top1 = [AverageMeter() for i in tasks]
    end = time.time()

    loaders = [tloaders[i] for i in tasks]
    min_len_loader = np.min([len(i) for i in loaders])
    train_iter = [iter(i) for i in loaders]

    for batch_idx in range(min_len_loader * len(tasks)):
        config_task.first_batch = (batch_idx == 0)
        # Round-robin over the tasks: consecutive batches come from different loaders
        current_task_index = batch_idx % len(tasks)
        inputs, targets = next(train_iter[current_task_index])  # Python 3: next(), not .next()
        config_task.task = tasks[current_task_index]
        # measure data loading time
        data_time.update(time.time() - end)
        if args.use_cuda:
            # non_blocking replaces the pre-0.4 async= keyword (async is reserved in Python 3.7+)
            inputs, targets = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = args.criterion(outputs, targets)
        # measure accuracy and record loss
        losses[current_task_index].update(loss.item(), targets.size(0))  # .item() replaces loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        correct = predicted.eq(targets.data).cpu().sum().item()
        top1[current_task_index].update(correct * 100. / targets.size(0), targets.size(0))
        # apply gradients
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 50 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
                  epoch, batch_idx, min_len_loader * len(tasks), batch_time=batch_time,
                  data_time=data_time))
            for i in range(len(tasks)):
                print('Task {0} : Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc {top1.val:.3f} ({top1.avg:.3f})'.format(tasks[i], loss=losses[i], top1=top1[i]))

    return [top1[i].avg for i in range(len(tasks))], [losses[i].avg for i in range(len(tasks))]

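train() above interleaves the tasks batch by batch rather than training them one after another: batch_idx % len(tasks) picks the loader, and the global config_task.task tells the model code (in models.py, not shown here) which task-specific branch to use. Stripped of the model, the interleaving looks like this sketch, with lists standing in for DataLoaders:

# Stand-ins for per-task DataLoaders; any iterables work for the sketch.
loaders = {'task_a': ['a0', 'a1', 'a2'], 'task_b': ['b0', 'b1', 'b2']}
tasks = list(loaders)
iters = [iter(loaders[t]) for t in tasks]
min_len = min(len(loaders[t]) for t in tasks)

for batch_idx in range(min_len * len(tasks)):
    current = batch_idx % len(tasks)    # round-robin task selection
    batch = next(iters[current])
    print(tasks[current], batch)        # a0, b0, a1, b1, a2, b2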
def test(epoch, loaders, all_tasks, net, best_acc, args, optimizer):
    net.eval()
    losses = [AverageMeter() for i in all_tasks]
    top1 = [AverageMeter() for i in all_tasks]
    print('Epoch: [{0}]'.format(epoch))
    with torch.no_grad():  # replaces the pre-0.4 volatile=True Variables
        for itera in range(len(all_tasks)):
            i = all_tasks[itera]
            config_task.task = i
            for batch_idx, (inputs, targets) in enumerate(loaders[i]):
                if args.use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                outputs = net(inputs)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                loss = args.criterion(outputs, targets)

                losses[itera].update(loss.item(), targets.size(0))
                _, predicted = torch.max(outputs.data, 1)
                correct = predicted.eq(targets.data).cpu().sum().item()
                top1[itera].update(correct * 100. / targets.size(0), targets.size(0))

            print('Task {0} : Test Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Test Acc {top1.val:.3f} ({top1.avg:.3f})'.format(i, loss=losses[itera], top1=top1[itera]))

    # Save a checkpoint whenever the accuracy summed over all tasks improves.
    acc = np.sum([top1[i].avg for i in range(len(all_tasks))])
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net,
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, args.ckpdir+'/ckpt'+config_task.mode+args.archi+args.proj+''.join(args.dataset)+'.t7')
        best_acc = acc

    return [top1[i].avg for i in range(len(all_tasks))], [losses[i].avg for i in range(len(all_tasks))], best_acc

--------------------------------------------------------------------------------
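test() above checkpoints the full pickled network under 'net' (not a state_dict), together with the summed accuracy and the epoch. Loading it back therefore requires models.py to be importable, since unpickling resolves the model classes by name. A minimal sketch, with a placeholder for the ckpt...t7 name built in test():

import torch
import models  # needed: the pickled object references classes defined in models.py

checkpoint = torch.load('PATH_TO_CHECKPOINT.t7')  # placeholder path
net = checkpoint['net']
print('saved at epoch', checkpoint['epoch'], 'with summed acc', checkpoint['acc'])
net.eval()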