├── .gitignore ├── LICENSE ├── README.md ├── auto_lambda.py ├── create_dataset.py ├── create_network.py ├── dataset └── cityscapes_preprocess.py ├── trainer_cifar.py ├── trainer_cifar_single.py ├── trainer_dense.py ├── trainer_dense_single.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | .DS_Store 3 | __pycache__/* 4 | panoptic_parts/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | AUTO-LAMBDA SOFTWARE 2 | 3 | LICENCE AGREEMENT 4 | 5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College 6 | London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY 7 | ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING 8 | AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE. 9 | BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE 10 | TERMS OF THE AGREEMENT. 11 | 12 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) 13 | 14 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully 15 | paid-up, royalty free, non-transferable, non-sub- licensable licence (the 16 | “Licence”) to use the elastic fusion source code, including any modification, 17 | part or derivative (the “Software”). 18 | 19 | Ownership and Licence. Your rights to use and download the Software onto your 20 | computer, and all other copies that You are authorised to make, are specified 21 | in this Agreement. However, we (or our licensors) retain all rights, including 22 | but not limited to all copyright and other intellectual property rights 23 | anywhere in the world, in the Software not expressly granted to You in this 24 | Agreement. 25 | 26 | 2. Permitted use of the Licence: 27 | 28 | (a) You may download and install the Software onto one computer or server for 29 | use in accordance with Clause 2(b) of this Agreement provided that You ensure 30 | that the Software is not accessible by other users unless they have themselves 31 | accepted the terms of this licence agreement. 32 | 33 | (b) You may use the Software solely for non-commercial, internal or academic 34 | research purposes and only in accordance with the terms of this Agreement. You 35 | may not use the Software for commercial purposes, including but not limited to 36 | (1) integration of all or part of the source code or the Software into a 37 | product for sale or licence by or on behalf of You to third parties or (2) use 38 | of the Software or any derivative of it for research to develop software 39 | products for sale or licence to a third party or (3) use of the Software or any 40 | derivative of it for research to develop non-software products for sale or 41 | licence to a third party, or (4) use of the Software to provide any service to 42 | an external organisation for which payment is received. 43 | 44 | Should You wish to use the Software for commercial purposes, You shall 45 | email researchcontracts.engineering@imperial.ac.uk . 46 | 47 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, 48 | provided that each copy is kept in your possession and provided You reproduce 49 | our copyright notice (set out in Schedule 1) on each copy. 50 | 51 | (d) Transfer and sub-licensing. 
You may not rent, lend, or lease the Software 52 | and You may not transmit, transfer or sub-license this licence to use the 53 | Software or any of your rights or obligations under this Agreement to another 54 | party. 55 | 56 | (e) Identity of Licensee. The licence granted herein is personal to You. You 57 | shall not permit any third party to access, modify or otherwise use the 58 | Software nor shall You access modify or otherwise use the Software on behalf of 59 | any third party. If You wish to obtain a licence for mutiple users or a site 60 | licence for the Software please contact us 61 | at researchcontracts.engineering@imperial.ac.uk . 62 | 63 | (f) Publications and presentations. You may make public, results or data 64 | obtained from, dependent on or arising from research carried out using the 65 | Software, provided that any such presentation or publication identifies the 66 | Software as the source of the results or the data, including the Copyright 67 | Notice given in each element of the Software, and stating that the Software has 68 | been made available for use by You under licence from Imperial College London 69 | and You provide a copy of any such publication to Imperial College London. 70 | 71 | 3. Prohibited Uses. You may not, without written permission from us 72 | at researchcontracts.engineering@imperial.ac.uk : 73 | 74 | (a) Use, copy, modify, merge, or transfer copies of the Software or any 75 | documentation provided by us which relates to the Software except as provided 76 | in this Agreement; 77 | 78 | (b) Use any back-up or archival copies of the Software (or allow anyone else to 79 | use such copies) for any purpose other than to replace the original copy in the 80 | event it is destroyed or becomes defective; or 81 | 82 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner 83 | decode the Software for any reason. 84 | 85 | 4. Warranty Disclaimer 86 | 87 | (a) Disclaimer. The Software has been developed for research purposes only. You 88 | acknowledge that we are providing the Software to You under this licence 89 | agreement free of charge and on condition that the disclaimer set out below 90 | shall apply. We do not represent or warrant that the Software as to: (i) the 91 | quality, accuracy or reliability of the Software; (ii) the suitability of the 92 | Software for any particular use or for use under any specific conditions; and 93 | (iii) whether use of the Software will infringe third-party rights. 94 | 95 | You acknowledge that You have reviewed and evaluated the Software to determine 96 | that it meets your needs and that You assume all responsibility and liability 97 | for determining the suitability of the Software as fit for your particular 98 | purposes and requirements. Subject to Clause 4(b), we exclude and expressly 99 | disclaim all express and implied representations, warranties, conditions and 100 | terms not stated herein (including the implied conditions or warranties of 101 | satisfactory quality, merchantable quality, merchantability and fitness for 102 | purpose). 103 | 104 | (b) Savings. 
Some jurisdictions may imply warranties, conditions or terms or 105 | impose obligations upon us which cannot, in whole or in part, be excluded, 106 | restricted or modified or otherwise do not allow the exclusion of implied 107 | warranties, conditions or terms, in which case the above warranty disclaimer 108 | and exclusion will only apply to You to the extent permitted in the relevant 109 | jurisdiction and does not in any event exclude any implied warranties, 110 | conditions or terms which may not under applicable law be excluded. 111 | 112 | (c) Imperial College London disclaims all responsibility for the use which is 113 | made of the Software and any liability for the outcomes arising from using the 114 | Software. 115 | 116 | 5. Limitation of Liability 117 | 118 | (a) You acknowledge that we are providing the Software to You under this 119 | licence agreement free of charge and on condition that the limitation of 120 | liability set out below shall apply. Accordingly, subject to Clause 5(b), we 121 | exclude all liability whether in contract, tort, negligence or otherwise, in 122 | respect of the Software and/or any related documentation provided to You by us 123 | including, but not limited to, liability for loss or corruption of data, loss 124 | of contracts, loss of income, loss of profits, loss of cover and any 125 | consequential or indirect loss or damage of any kind arising out of or in 126 | connection with this licence agreement, however caused. This exclusion shall 127 | apply even if we have been advised of the possibility of such loss or damage. 128 | 129 | (b) You agree to indemnify Imperial College London and hold it harmless from 130 | and against any and all claims, damages and liabilities asserted by third 131 | parties (including claims for negligence) which arise directly or indirectly 132 | from the use of the Software or any derivative of it or the sale of any 133 | products based on the Software. You undertake to make no liability claim 134 | against any employee, student, agent or appointee of Imperial College London, 135 | in connection with this Licence or the Software. 136 | 137 | (c) Nothing in this Agreement shall have the effect of excluding or limiting 138 | our statutory liability. 139 | 140 | (d) Some jurisdictions do not allow these limitations or exclusions either 141 | wholly or in part, and, to that extent, they may not apply to you. Nothing in 142 | this licence agreement will affect your statutory rights or other relevant 143 | statutory provisions which cannot be excluded, restricted or modified, and its 144 | terms and conditions must be read and construed subject to any such statutory 145 | rights and/or provisions. 146 | 147 | 6. Confidentiality. You agree not to disclose any confidential information 148 | provided to You by us pursuant to this Agreement to any third party without our 149 | prior written consent. The obligations in this Clause 6 shall survive the 150 | termination of this Agreement for any reason. 151 | 152 | 7. Termination. 153 | 154 | (a) We may terminate this licence agreement and your right to use the Software 155 | at any time with immediate effect upon written notice to You. 156 | 157 | (b) This licence agreement and your right to use the Software automatically 158 | terminate if You: 159 | 160 | (i) fail to comply with any provisions of this Agreement; or 161 | 162 | (ii) destroy the copies of the Software in your possession, or voluntarily 163 | return the Software to us. 
164 | 165 | (c) Upon termination You will destroy all copies of the Software. 166 | 167 | (d) Otherwise, the restrictions on your rights to use the Software will expire 168 | 10 (ten) years after first use of the Software under this licence agreement. 169 | 170 | 8. Miscellaneous Provisions. 171 | 172 | (a) This Agreement will be governed by and construed in accordance with the 173 | substantive laws of England and Wales whose courts shall have exclusive 174 | jurisdiction over all disputes which may arise between us. 175 | 176 | (b) This is the entire agreement between us relating to the Software, and 177 | supersedes any prior purchase order, communications, advertising or 178 | representations concerning the Software. 179 | 180 | (c) No change or modification of this Agreement will be valid unless it is in 181 | writing, and is signed by us. 182 | 183 | (d) The unenforceability or invalidity of any part of this Agreement will not 184 | affect the enforceability or validity of the remaining parts. 185 | 186 | BSD Elements of the Software 187 | 188 | For BSD elements of the Software, the following terms shall apply: 189 | Copyright as indicated in the header of the individual element of the Software. 190 | All rights reserved. 191 | 192 | Redistribution and use in source and binary forms, with or without 193 | modification, are permitted provided that the following conditions are met: 194 | 195 | 1. Redistributions of source code must retain the above copyright notice, this 196 | list of conditions and the following disclaimer. 197 | 198 | 2. Redistributions in binary form must reproduce the above copyright notice, 199 | this list of conditions and the following disclaimer in the documentation 200 | and/or other materials provided with the distribution. 201 | 202 | 3. Neither the name of the copyright holder nor the names of its contributors 203 | may be used to endorse or promote products derived from this software without 204 | specific prior written permission. 205 | 206 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 207 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 208 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 209 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 210 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 211 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 212 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 213 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 214 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 215 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 216 | 217 | SCHEDULE 1 218 | 219 | The Software 220 | 221 | Auto-lambda is a multi-task and auxiliary learning optimisation framework, which automatically and dynamically 222 | update task weightings during training based on the choice of the primary tasks. 223 | 224 | • Shikun Liu, Stephen James, Andrew J. Davison, Edward Johns. 225 | Auto-Lambda: Disentangling Dynamic Task Relationships. Transactions on Machine Learning Research, 2022 226 | _________________________ 227 | 228 | Acknowledgments 229 | 230 | If you use the software, you should reference the following paper in any 231 | publication: 232 | 233 | • Shikun Liu, Stephen James, Andrew J. Davison, Edward Johns. 
234 | Auto-Lambda: Disentangling Dynamic Task Relationships. Transactions on Machine Learning Research, 2022
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Auto-Lambda
2 | This repository contains the source code of Auto-Lambda and baselines from the paper, [Auto-Lambda: Disentangling Dynamic Task Relationships](https://arxiv.org/abs/2202.03091).
3 | 
4 | We encourage readers to check out our [project page](https://shikun.io/projects/auto-lambda), which includes further discussions and insights not covered in the technical paper.
5 | 
6 | ## Multi-task Methods
7 | We implemented all weighting-based and gradient-based baselines presented in the paper on two types of computer vision task: dense prediction (NYUv2 and CityScapes) and multi-domain classification (CIFAR-100).
8 | 
9 | Specifically, we provide implementations of the following multi-task optimisation methods:
10 | 
11 | ### Weighting-based:
12 | - **Equal** - All task weightings are 1. `--weight equal`
13 | - **Uncertainty** - [https://arxiv.org/abs/1705.07115](https://arxiv.org/abs/1705.07115) `--weight uncert`
14 | - **Dynamic Weight Average** - [https://arxiv.org/abs/1803.10704](https://arxiv.org/abs/1803.10704) `--weight dwa`
15 | - **Auto-Lambda** - Our approach. `--weight autol`
16 | 
17 | ### Gradient-based:
18 | - **GradDrop** - [https://arxiv.org/abs/2010.06808](https://arxiv.org/abs/2010.06808) `--grad_method graddrop`
19 | - **PCGrad** - [https://arxiv.org/abs/2001.06782](https://arxiv.org/abs/2001.06782) `--grad_method pcgrad`
20 | - **CAGrad** - [https://arxiv.org/abs/2110.14048](https://arxiv.org/abs/2110.14048) `--grad_method cagrad`
21 | 
22 | *Note: Applying a combination of a weighting-based and a gradient-based method can further improve performance.*
23 | 
24 | ## Datasets
25 | We apply the same data pre-processing as in our previous project, [MTAN](https://github.com/lorenmt/mtan), which experimented on:
26 | 
27 | - [**NYUv2 [3 Tasks]**](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0) - 13 Class Segmentation + Depth Estimation + Surface Normal Prediction. [288 x 384] Resolution.
28 | - [**CityScapes [3 Tasks]**](https://www.dropbox.com/sh/qk3cr18d55d08gj/AAA5OCTPNFDEDk5fZsmCfmrAa?dl=0) - 19 Class Segmentation + 10 Class Part Segmentation + Disparity (Inverse Depth) Estimation. [256 x 512] Resolution.
29 | 
30 | *Note: We have included a new task, [Part Segmentation](https://github.com/pmeletis/panoptic_parts), for the CityScapes dataset. Please run `pip install panoptic_parts` before running the CityScapes experiments. The pre-processing script for CityScapes is also included in the `dataset` folder.*
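To sanity-check the pre-processed data, both datasets can be loaded directly through the classes in `create_dataset.py`. The snippet below is a minimal, illustrative sketch only; it assumes the NYUv2 download has been extracted to `dataset/nyuv2` (this path is an assumption — point `root` at wherever you placed the data):

```python
from torch.utils.data import DataLoader

from create_dataset import NYUv2

# assumed location of the pre-processed NYUv2 data
train_set = NYUv2(root='dataset/nyuv2', train=True, augmentation=True)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

im, targets = next(iter(train_loader))           # images are normalised to [-1, 1]
print(im.shape)                                  # torch.Size([4, 3, 288, 384])
print({k: v.shape for k, v in targets.items()})  # 'seg', 'depth', 'normal', 'noise' labels
```

`CityScapes` from the same file follows the same interface (`root`, `train`, `augmentation`) and additionally requires `panoptic_parts` for decoding the part segmentation labels.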
31 | 
32 | 
33 | ## Experiments
34 | All experiments were written in `PyTorch 1.7`, and each training script accepts a set of flags (hyper-parameters). We briefly introduce the most important flags below.
35 | 
36 | | Flag Name     | Usage                                                                                                                          | Comments                                                                             |
37 | |---------------|--------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------|
38 | | `network`     | choose multi-task network: `split, mtan`                                                                                       | both architectures are based on ResNet-50; only available in dense prediction tasks  |
39 | | `dataset`     | choose dataset: `nyuv2, cityscapes`                                                                                            | only available in dense prediction tasks                                             |
40 | | `weight`      | choose weighting-based method: `equal, uncert, dwa, autol`                                                                     | only `autol` behaves differently when set to different primary tasks                 |
41 | | `grad_method` | choose gradient-based method: `graddrop, pcgrad, cagrad`                                                                       | `weight` and `grad_method` can be applied together                                   |
42 | | `task`        | choose primary tasks: `seg, depth, normal` for NYUv2, `seg, part_seg, disp` for CityScapes, or `all` for all 3 standard tasks  | only available in dense prediction tasks                                             |
43 | | `with_noise`  | toggle on to add a noise prediction task during training (to evaluate robustness in the auxiliary learning setting)            | only available in dense prediction tasks                                             |
44 | | `subset_id`   | choose the domain ID for CIFAR-100; choose `-1` for the multi-task learning setting                                            | only available in CIFAR-100 tasks                                                    |
45 | | `autol_init`  | initialisation of Auto-Lambda, default `0.1`                                                                                   | only available when applying Auto-Lambda                                             |
46 | | `autol_lr`    | learning rate of Auto-Lambda, default `1e-4` for NYUv2 and `3e-5` for CityScapes                                               | only available when applying Auto-Lambda                                             |
47 | 
48 | Training Auto-Lambda in Multi-task / Auxiliary Learning Mode:
49 | ```
50 | python trainer_dense.py --dataset [nyuv2, cityscapes] --task [PRIMARY_TASK] --weight autol --gpu 0   # for the NYUv2 or CityScapes dataset
51 | python trainer_cifar.py --subset_id [PRIMARY_DOMAIN_ID] --weight autol --gpu 0                       # for the CIFAR-100 dataset
52 | ```
53 | 
54 | Training in Single-task Learning Mode:
55 | ```
56 | python trainer_dense_single.py --dataset [nyuv2, cityscapes] --task [PRIMARY_TASK] --gpu 0   # for the NYUv2 or CityScapes dataset
57 | python trainer_cifar_single.py --subset_id [PRIMARY_DOMAIN_ID] --gpu 0                       # for the CIFAR-100 dataset
58 | ```
59 | 
60 | *Note: All experiments in the original paper were trained from scratch without pre-training.*
61 | 
62 | ## Benchmark
63 | For the 3 standard NYUv2 tasks (without the noise prediction task) in the multi-task learning setting with the Split architecture, refer to the results below.
64 | 
65 | | Method               | Type  | Sem. Seg. (mIoU) | Depth (aErr.) | Normal (mDist.) | Delta MTL |
66 | |----------------------|-------|------------------|---------------|-----------------|-----------|
67 | | Single               | -     | 43.37            | 52.24         | 22.40           | -         |
68 | | Equal                | W     | 44.64            | 43.32         | 24.48           | +3.57%    |
69 | | DWA                  | W     | 45.14            | 43.06         | 24.17           | +4.58%    |
70 | | GradDrop             | G     | 45.39            | 43.23         | 24.18           | +4.65%    |
71 | | PCGrad               | G     | 45.15            | 42.38         | 24.13           | +5.09%    |
72 | | Uncertainty          | W     | 45.98            | 41.26         | 24.09           | +6.50%    |
73 | | CAGrad               | G     | 46.14            | 41.91         | 23.52           | +7.05%    |
74 | | Auto-Lambda          | W     | 47.17            | 40.97         | 23.68           | +8.21%    |
75 | | Auto-Lambda + CAGrad | W + G | 48.26            | 39.82         | 22.81           | +11.07%   |
76 | 
77 | *Note: The results were averaged across three random seeds. You should expect an error range of less than +/-1%.*
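The `Delta MTL` column reports the average relative improvement over the single-task baseline, counting a task as positive when it moves in the right direction (higher mIoU, lower aErr. and mDist.). Below is a minimal sketch of that computation, which reproduces the `Equal` row above; the `delta_mtl` helper is illustrative only and not part of this repository:

```python
def delta_mtl(single, multi, higher_better):
    """Average relative improvement (in %) over the single-task baselines."""
    rel = [(m - s) / s if hb else (s - m) / s
           for s, m, hb in zip(single, multi, higher_better)]
    return 100. * sum(rel) / len(rel)

# single-task: 43.37 mIoU / 52.24 aErr. / 22.40 mDist.; Equal weighting: 44.64 / 43.32 / 24.48
print(delta_mtl([43.37, 52.24, 22.40], [44.64, 43.32, 24.48], [True, False, False]))  # approx. +3.57
```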
78 | 
79 | ## Citation
80 | If you found this code/work useful in your own research, please consider citing the following:
81 | 
82 | ```
83 | @article{liu2022auto_lambda,
84 |   title={Auto-Lambda: Disentangling Dynamic Task Relationships},
85 |   author={Liu, Shikun and James, Stephen and Davison, Andrew J and Johns, Edward},
86 |   journal={Transactions on Machine Learning Research},
87 |   year={2022}
88 | }
89 | ```
90 | 
91 | ## Acknowledgement
92 | We would like to thank [@Cranial-XIX](https://github.com/Cranial-XIX) for his clean implementation of the gradient-based optimisation methods.
93 | 
94 | ## Contact
95 | If you have any questions, please contact `sk.lorenmt@gmail.com`.
96 | 
--------------------------------------------------------------------------------
/auto_lambda.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from utils import *
3 | 
4 | 
5 | class AutoLambda:
6 |     def __init__(self, model, device, train_tasks, pri_tasks, weight_init=0.1):
7 |         self.model = model
8 |         self.model_ = copy.deepcopy(model)
9 |         self.meta_weights = torch.tensor([weight_init] * len(train_tasks), requires_grad=True, device=device)
10 |         self.train_tasks = train_tasks
11 |         self.pri_tasks = pri_tasks
12 | 
13 |     def virtual_step(self, train_x, train_y, alpha, model_optim):
14 |         """
15 |         Compute unrolled network theta' (virtual step)
16 |         """
17 | 
18 |         # forward & compute loss
19 |         if type(train_x) == list:  # multi-domain setting [many-to-many]
20 |             train_pred = [self.model(x, t) for t, x in enumerate(train_x)]
21 |         else:  # single-domain setting [one-to-many]
22 |             train_pred = self.model(train_x)
23 | 
24 |         train_loss = self.model_fit(train_pred, train_y)
25 | 
26 |         loss = sum([w * train_loss[i] for i, w in enumerate(self.meta_weights)])
27 | 
28 |         # compute gradient
29 |         gradients = torch.autograd.grad(loss, self.model.parameters())
30 | 
31 |         # do virtual step (gradient update): theta' = theta - alpha * d/dtheta [ sum_i lambda_i * L_i(f_theta(x_i), y_i) ]
32 |         with torch.no_grad():
33 |             for weight, weight_, grad in zip(self.model.parameters(), self.model_.parameters(), gradients):
34 |                 if 'momentum' in model_optim.param_groups[0].keys():  # used in SGD with momentum
35 |                     m = model_optim.state[weight].get('momentum_buffer', 0.)
* model_optim.param_groups[0]['momentum'] 36 | else: 37 | m = 0 38 | weight_.copy_(weight - alpha * (m + grad + model_optim.param_groups[0]['weight_decay'] * weight)) 39 | 40 | def unrolled_backward(self, train_x, train_y, val_x, val_y, alpha, model_optim): 41 | """ 42 | Compute un-rolled loss and backward its gradients 43 | """ 44 | 45 | # do virtual step (calc theta`) 46 | self.virtual_step(train_x, train_y, alpha, model_optim) 47 | 48 | # define weighting for primary tasks (with binary weights) 49 | pri_weights = [] 50 | for t in self.train_tasks: 51 | if t in self.pri_tasks: 52 | pri_weights += [1.0] 53 | else: 54 | pri_weights += [0.0] 55 | 56 | # compute validation data loss on primary tasks 57 | if type(val_x) == list: 58 | val_pred = [self.model_(x, t) for t, x in enumerate(val_x)] 59 | else: 60 | val_pred = self.model_(val_x) 61 | val_loss = self.model_fit(val_pred, val_y) 62 | loss = sum([w * val_loss[i] for i, w in enumerate(pri_weights)]) 63 | 64 | # compute hessian via finite difference approximation 65 | model_weights_ = tuple(self.model_.parameters()) 66 | d_model = torch.autograd.grad(loss, model_weights_, allow_unused=True) 67 | hessian = self.compute_hessian(d_model, train_x, train_y) 68 | 69 | # update final gradient = - alpha * hessian 70 | with torch.no_grad(): 71 | for mw, h in zip([self.meta_weights], hessian): 72 | mw.grad = - alpha * h 73 | 74 | def compute_hessian(self, d_model, train_x, train_y): 75 | norm = torch.cat([w.view(-1) for w in d_model]).norm() 76 | eps = 0.01 / norm 77 | 78 | # \theta+ = \theta + eps * d_model 79 | with torch.no_grad(): 80 | for p, d in zip(self.model.parameters(), d_model): 81 | p += eps * d 82 | 83 | if type(train_x) == list: 84 | train_pred = [self.model(x, t) for t, x in enumerate(train_x)] 85 | else: 86 | train_pred = self.model(train_x) 87 | train_loss = self.model_fit(train_pred, train_y) 88 | loss = sum([w * train_loss[i] for i, w in enumerate(self.meta_weights)]) 89 | d_weight_p = torch.autograd.grad(loss, self.meta_weights) 90 | 91 | # \theta- = \theta - eps * d_model 92 | with torch.no_grad(): 93 | for p, d in zip(self.model.parameters(), d_model): 94 | p -= 2 * eps * d 95 | 96 | if type(train_x) == list: 97 | train_pred = [self.model(x, t) for t, x in enumerate(train_x)] 98 | else: 99 | train_pred = self.model(train_x) 100 | train_loss = self.model_fit(train_pred, train_y) 101 | loss = sum([w * train_loss[i] for i, w in enumerate(self.meta_weights)]) 102 | d_weight_n = torch.autograd.grad(loss, self.meta_weights) 103 | 104 | # recover theta 105 | with torch.no_grad(): 106 | for p, d in zip(self.model.parameters(), d_model): 107 | p += eps * d 108 | 109 | hessian = [(p - n) / (2. 
* eps) for p, n in zip(d_weight_p, d_weight_n)] 110 | return hessian 111 | 112 | def model_fit(self, pred, targets): 113 | """ 114 | define task specific losses 115 | """ 116 | loss = [compute_loss(pred[i], targets[task_id], task_id) for i, task_id in enumerate(self.train_tasks)] 117 | return loss 118 | -------------------------------------------------------------------------------- /create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import torch 5 | import fnmatch 6 | 7 | import numpy as np 8 | import panoptic_parts as pp 9 | import torch.utils.data as data 10 | import matplotlib.pylab as plt 11 | import torchvision.transforms as transforms 12 | import torchvision.transforms.functional as transforms_f 13 | 14 | from PIL import Image 15 | from torchvision.datasets import CIFAR100 16 | 17 | 18 | class DataTransform(object): 19 | def __init__(self, scales, crop_size, is_disparity=False): 20 | self.scales = scales 21 | self.crop_size = crop_size 22 | self.is_disparity = is_disparity 23 | 24 | def __call__(self, data_dict): 25 | if type(self.scales) == tuple: 26 | # Continuous range of scales 27 | sc = np.random.uniform(*self.scales) 28 | 29 | elif type(self.scales) == list: 30 | # Fixed range of scales 31 | sc = random.sample(self.scales, 1)[0] 32 | 33 | raw_h, raw_w = data_dict['im'].shape[-2:] 34 | resized_size = [int(raw_h * sc), int(raw_w * sc)] 35 | i, j, h, w = 0, 0, 0, 0 # initialise cropping coordinates 36 | flip_prop = random.random() 37 | 38 | for task in data_dict: 39 | if len(data_dict[task].shape) == 2: # make sure single-channel labels are in the same size [H, W, 1] 40 | data_dict[task] = data_dict[task].unsqueeze(0) 41 | 42 | # Resize based on randomly sampled scale 43 | if task in ['im', 'noise']: 44 | data_dict[task] = transforms_f.resize(data_dict[task], resized_size, Image.BILINEAR) 45 | elif task in ['normal', 'depth', 'seg', 'part_seg', 'disp']: 46 | data_dict[task] = transforms_f.resize(data_dict[task], resized_size, Image.NEAREST) 47 | 48 | # Add padding if crop size is smaller than the resized size 49 | if self.crop_size[0] > resized_size[0] or self.crop_size[1] > resized_size[1]: 50 | right_pad, bottom_pad = max(self.crop_size[1] - resized_size[1], 0), max(self.crop_size[0] - resized_size[0], 0) 51 | if task in ['im']: 52 | data_dict[task] = transforms_f.pad(data_dict[task], padding=(0, 0, right_pad, bottom_pad), 53 | padding_mode='reflect') 54 | elif task in ['seg', 'part_seg', 'disp']: 55 | data_dict[task] = transforms_f.pad(data_dict[task], padding=(0, 0, right_pad, bottom_pad), 56 | fill=-1, padding_mode='constant') # -1 will be ignored in loss 57 | elif task in ['normal', 'depth', 'noise']: 58 | data_dict[task] = transforms_f.pad(data_dict[task], padding=(0, 0, right_pad, bottom_pad), 59 | fill=0, padding_mode='constant') # 0 will be ignored in loss 60 | 61 | # Random Cropping 62 | if i + j + h + w == 0: # only run once 63 | i, j, h, w = transforms.RandomCrop.get_params(data_dict[task], output_size=self.crop_size) 64 | data_dict[task] = transforms_f.crop(data_dict[task], i, j, h, w) 65 | 66 | # Random Flip 67 | if flip_prop > 0.5: 68 | data_dict[task] = torch.flip(data_dict[task], dims=[2]) 69 | if task == 'normal': 70 | data_dict[task][0, :, :] = - data_dict[task][0, :, :] 71 | 72 | # Final Check: 73 | if task == 'depth': 74 | data_dict[task] = data_dict[task] / sc 75 | 76 | if task == 'disp': # disparity is inverse depth 77 | data_dict[task] = data_dict[task] * 
sc 78 | 79 | if task in ['seg', 'part_seg']: 80 | data_dict[task] = data_dict[task].squeeze(0) 81 | return data_dict 82 | 83 | 84 | class NYUv2(data.Dataset): 85 | """ 86 | NYUv2 dataset, 3 tasks + 1 generated useless task 87 | Included tasks: 88 | 1. Semantic Segmentation, 89 | 2. Depth prediction, 90 | 3. Surface Normal prediction, 91 | 4. Noise prediction [to test auxiliary learning, purely conflict gradients] 92 | """ 93 | def __init__(self, root, train=True, augmentation=False): 94 | self.train = train 95 | self.root = os.path.expanduser(root) 96 | self.augmentation = augmentation 97 | 98 | # read the data file 99 | if train: 100 | self.data_path = root + '/train' 101 | else: 102 | self.data_path = root + '/val' 103 | 104 | # calculate data length 105 | self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.npy')) 106 | self.noise = torch.rand(self.data_len, 1, 288, 384) 107 | 108 | def __getitem__(self, index): 109 | # load data from the pre-processed npy files 110 | image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0)).float() 111 | semantic = torch.from_numpy(np.load(self.data_path + '/label/{:d}.npy'.format(index))).long() 112 | depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0)).float() 113 | normal = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/normal/{:d}.npy'.format(index)), -1, 0)).float() 114 | noise = self.noise[index].float() 115 | 116 | data_dict = {'im': image, 'seg': semantic, 'depth': depth, 'normal': normal, 'noise': noise} 117 | 118 | # apply data augmentation if required 119 | if self.augmentation: 120 | data_dict = DataTransform(crop_size=[288, 384], scales=[1.0, 1.2, 1.5])(data_dict) 121 | 122 | im = 2. * data_dict.pop('im') - 1. # normalised to [-1, 1] 123 | return im, data_dict 124 | 125 | def __len__(self): 126 | return self.data_len 127 | 128 | 129 | class CityScapes(data.Dataset): 130 | """ 131 | CityScapes dataset, 3 tasks + 1 generated useless task 132 | Included tasks: 133 | 1. Semantic Segmentation, 134 | 2. Part Segmentation, 135 | 3. Disparity Estimation (Inverse Depth), 136 | 4. 
Noise prediction [to test auxiliary learning, purely conflict gradients] 137 | """ 138 | def __init__(self, root, train=True, augmentation=False): 139 | self.train = train 140 | self.root = os.path.expanduser(root) 141 | self.augmentation = augmentation 142 | 143 | # read the data file 144 | if train: 145 | self.data_path = root + '/train' 146 | else: 147 | self.data_path = root + '/val' 148 | 149 | # calculate data length 150 | self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.png')) 151 | self.noise = torch.rand(self.data_len, 1, 256, 256) if self.train else torch.rand(self.data_len, 1, 256, 512) 152 | 153 | def __getitem__(self, index): 154 | # load data from the pre-processed npy files 155 | image = torch.from_numpy(np.moveaxis(plt.imread(self.data_path + '/image/{:d}.png'.format(index)), -1, 0)).float() 156 | disparity = cv2.imread(self.data_path + '/depth/{:d}.png'.format(index), cv2.IMREAD_UNCHANGED).astype(np.float32) 157 | disparity = torch.from_numpy(self.map_disparity(disparity)).unsqueeze(0).float() 158 | seg = np.array(Image.open(self.data_path + '/seg/{:d}.png'.format(index)), dtype=float) 159 | seg = torch.from_numpy(self.map_seg_label(seg)).long() 160 | part_seg = np.array(Image.open(self.data_path + '/part_seg/{:d}.tif'.format(index))) 161 | part_seg = torch.from_numpy(self.map_part_seg_label(part_seg)).long() 162 | noise = self.noise[index].float() 163 | 164 | data_dict = {'im': image, 'seg': seg, 'part_seg': part_seg, 'disp': disparity, 'noise': noise} 165 | 166 | # apply data augmentation if required 167 | if self.augmentation: 168 | data_dict = DataTransform(crop_size=[256, 256], scales=[1.0])(data_dict) 169 | 170 | im = 2. * data_dict.pop('im') - 1. # normalised to [-1, 1] 171 | return im, data_dict 172 | 173 | def map_seg_label(self, mask): 174 | # source: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 175 | mask_map = np.zeros_like(mask) 176 | mask_map[np.isin(mask, [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30])] = -1 177 | mask_map[np.isin(mask, [7])] = 0 178 | mask_map[np.isin(mask, [8])] = 1 179 | mask_map[np.isin(mask, [11])] = 2 180 | mask_map[np.isin(mask, [12])] = 3 181 | mask_map[np.isin(mask, [13])] = 4 182 | mask_map[np.isin(mask, [17])] = 5 183 | mask_map[np.isin(mask, [19])] = 6 184 | mask_map[np.isin(mask, [20])] = 7 185 | mask_map[np.isin(mask, [21])] = 8 186 | mask_map[np.isin(mask, [22])] = 9 187 | mask_map[np.isin(mask, [23])] = 10 188 | mask_map[np.isin(mask, [24])] = 11 189 | mask_map[np.isin(mask, [25])] = 12 190 | mask_map[np.isin(mask, [26])] = 13 191 | mask_map[np.isin(mask, [27])] = 14 192 | mask_map[np.isin(mask, [28])] = 15 193 | mask_map[np.isin(mask, [31])] = 16 194 | mask_map[np.isin(mask, [32])] = 17 195 | mask_map[np.isin(mask, [33])] = 18 196 | return mask_map 197 | 198 | def map_part_seg_label(self, mask): 199 | # https://panoptic-parts.readthedocs.io/en/stable/api_and_code.html 200 | # https://arxiv.org/abs/2004.07944 201 | mask = pp.decode_uids(mask, return_sids_pids=True)[-1] 202 | mask_map = np.zeros_like(mask) # background 203 | mask_map[np.isin(mask, [2401, 2501])] = 1 # human/rider torso 204 | mask_map[np.isin(mask, [2402, 2502])] = 2 # human/rider head 205 | mask_map[np.isin(mask, [2403, 2503])] = 3 # human/rider arms 206 | mask_map[np.isin(mask, [2404, 2504])] = 4 # human/rider legs 207 | mask_map[np.isin(mask, [2601, 2701, 2801])] = 5 # car/truck/bus windows 208 | mask_map[np.isin(mask, [2602, 2702, 2802])] = 6 # car/truck/bus wheels 209 | 
mask_map[np.isin(mask, [2603, 2703, 2803])] = 7 # car/truck/bus lights 210 | mask_map[np.isin(mask, [2604, 2704, 2804])] = 8 # car/truck/bus license_plate 211 | mask_map[np.isin(mask, [2605, 2705, 2805])] = 9 # car/truck/bus chassis 212 | return mask_map 213 | 214 | def map_disparity(self, disparity): 215 | # https://github.com/mcordts/cityscapesScripts/issues/55#issuecomment-411486510 216 | # remap invalid points to -1 (not to conflict with 0, infinite depth, such as sky) 217 | disparity[disparity == 0] = -1 218 | # reduce by a factor of 4 based on the rescaled resolution 219 | disparity[disparity > -1] = (disparity[disparity > -1] - 1) / (256 * 4) 220 | return disparity 221 | 222 | def __len__(self): 223 | return self.data_len 224 | 225 | 226 | class CIFAR100MTL(CIFAR100): 227 | """ 228 | CIFAR100 dataset, 20 tasks (grouped by coarse labels) 229 | Each task is a 5-label classification, with 2500 training and 500 testing number of data for each task. 230 | Modified from https://pytorch.org/docs/stable/torchvision/datasets.html 231 | """ 232 | def __init__(self, root, subset_id=0, train=True, transform=None, target_transform=None, download=False): 233 | super(CIFAR100MTL, self).__init__(root, train, transform, target_transform, download) 234 | # define coarse label maps 235 | coarse_labels = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 236 | 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 237 | 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 238 | 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 239 | 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 240 | 16, 4, 17, 4, 2, 0, 17, 4, 18, 17, 241 | 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 242 | 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 243 | 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 244 | 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]) 245 | 246 | self.coarse_targets = coarse_labels[self.targets] 247 | 248 | # filter the data and targets for the desired subset 249 | self.data = self.data[self.coarse_targets == subset_id] 250 | self.targets = np.array(self.targets)[self.coarse_targets == subset_id] 251 | 252 | # remap fine labels into 5-class classification 253 | self.targets = np.unique(self.targets, return_inverse=True)[1] 254 | 255 | # update semantic classes 256 | self.class_dict = { 257 | "aquatic mammals": ["beaver", "dolphin", "otter", "seal", "whale"], 258 | "fish": ["aquarium_fish", "flatfish", "ray", "shark", "trout"], 259 | "flowers": ["orchid", "poppy", "rose", "sunflower", "tulip"], 260 | "food containers": ["bottle", "bowl", "can", "cup", "plate"], 261 | "fruit and vegetables": ["apple", "mushroom", "orange", "pear", "sweet_pepper"], 262 | "household electrical device": ["clock", "computer_keyboard", "lamp", "telephone", "television"], 263 | "household furniture": ["bed", "chair", "couch", "table", "wardrobe"], 264 | "insects": ["bee", "beetle", "butterfly", "caterpillar", "cockroach"], 265 | "large carnivores": ["bear", "leopard", "lion", "tiger", "wolf"], 266 | "large man-made outdoor things": ["bridge", "castle", "house", "road", "skyscraper"], 267 | "large natural outdoor scenes": ["cloud", "forest", "mountain", "plain", "sea"], 268 | "large omnivores and herbivores": ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"], 269 | "medium-sized mammals": ["fox", "porcupine", "possum", "raccoon", "skunk"], 270 | "non-insect invertebrates": ["crab", "lobster", "snail", "spider", "worm"], 271 | "people": ["baby", "boy", "girl", "man", "woman"], 272 | "reptiles": ["crocodile", "dinosaur", "lizard", "snake", "turtle"], 273 | "small mammals": ["hamster", "mouse", "rabbit", "shrew", "squirrel"], 274 | "trees": 
["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"], 275 | "vehicles 1": ["bicycle", "bus", "motorcycle", "pickup_truck", "train"], 276 | "vehicles 2": ["lawn_mower", "rocket", "streetcar", "tank", "tractor"], 277 | } 278 | 279 | self.subset_class = list(self.class_dict.keys())[subset_id] 280 | self.classes = self.class_dict[self.subset_class] 281 | 282 | -------------------------------------------------------------------------------- /create_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models.resnet as resnet 5 | 6 | 7 | # -------------------------------------------------------------------------------- 8 | # Define DeepLab Modules 9 | # -------------------------------------------------------------------------------- 10 | class DeepLabHead(nn.Sequential): 11 | def __init__(self, in_channels, num_classes): 12 | super(DeepLabHead, self).__init__( 13 | ASPP(in_channels, [12, 24, 36]), 14 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 15 | nn.BatchNorm2d(256), 16 | nn.ReLU(), 17 | nn.Conv2d(256, num_classes, 1) 18 | ) 19 | 20 | 21 | class ASPPConv(nn.Sequential): 22 | def __init__(self, in_channels, out_channels, dilation): 23 | modules = [ 24 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 25 | nn.BatchNorm2d(out_channels), 26 | nn.ReLU() 27 | ] 28 | super(ASPPConv, self).__init__(*modules) 29 | 30 | 31 | class ASPPPooling(nn.Sequential): 32 | def __init__(self, in_channels, out_channels): 33 | super(ASPPPooling, self).__init__( 34 | nn.AdaptiveAvgPool2d(1), 35 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 36 | nn.BatchNorm2d(out_channels), 37 | nn.ReLU() 38 | ) 39 | 40 | def forward(self, x): 41 | size = x.shape[-2:] 42 | x = super(ASPPPooling, self).forward(x) 43 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 44 | 45 | 46 | class ASPP(nn.Module): 47 | def __init__(self, in_channels, atrous_rates): 48 | super(ASPP, self).__init__() 49 | out_channels = 256 50 | modules = [nn.Sequential( 51 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 52 | nn.BatchNorm2d(out_channels), 53 | nn.ReLU() 54 | )] 55 | 56 | rate1, rate2, rate3 = tuple(atrous_rates) 57 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 58 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 59 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 60 | modules.append(ASPPPooling(in_channels, out_channels)) 61 | 62 | self.convs = nn.ModuleList(modules) 63 | self.project = nn.Sequential( 64 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 65 | nn.BatchNorm2d(out_channels), 66 | nn.ReLU(), 67 | ) 68 | 69 | def forward(self, x): 70 | res = [] 71 | for conv in self.convs: 72 | res.append(conv(x)) 73 | res = torch.cat(res, dim=1) 74 | return self.project(res) 75 | 76 | 77 | class ResnetDilated(nn.Module): 78 | def __init__(self, orig_resnet, dilate_scale=8): 79 | super(ResnetDilated, self).__init__() 80 | from functools import partial 81 | 82 | if dilate_scale == 8: 83 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 84 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 85 | elif dilate_scale == 16: 86 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) 87 | 88 | # take pre-defined ResNet, except AvgPool and FC 89 | self.conv1 = orig_resnet.conv1 90 | self.bn1 = orig_resnet.bn1 91 | self.relu1 = 
orig_resnet.relu 92 | 93 | self.maxpool = orig_resnet.maxpool 94 | self.layer1 = orig_resnet.layer1 95 | self.layer2 = orig_resnet.layer2 96 | self.layer3 = orig_resnet.layer3 97 | self.layer4 = orig_resnet.layer4 98 | 99 | def _nostride_dilate(self, m, dilate): 100 | classname = m.__class__.__name__ 101 | if classname.find('Conv') != -1: 102 | if m.stride == (2, 2): 103 | m.stride = (1, 1) 104 | if m.kernel_size == (3, 3): 105 | m.dilation = (dilate // 2, dilate // 2) 106 | m.padding = (dilate // 2, dilate // 2) 107 | else: 108 | if m.kernel_size == (3, 3): 109 | m.dilation = (dilate, dilate) 110 | m.padding = (dilate, dilate) 111 | 112 | def forward(self, x): 113 | x = self.relu1(self.bn1(self.conv1(x))) 114 | x = self.maxpool(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | return x 121 | 122 | 123 | # -------------------------------------------------------------------------------- 124 | # Define MTAN DeepLab 125 | # -------------------------------------------------------------------------------- 126 | class MTANDeepLabv3(nn.Module): 127 | def __init__(self, tasks): 128 | super(MTANDeepLabv3, self).__init__() 129 | backbone = ResnetDilated(resnet.resnet50()) 130 | ch = [256, 512, 1024, 2048] 131 | 132 | self.tasks = tasks 133 | self.shared_conv = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu1, backbone.maxpool) 134 | 135 | # We will apply the attention over the last bottleneck layer in the ResNet. 136 | self.shared_layer1_b = backbone.layer1[:-1] 137 | self.shared_layer1_t = backbone.layer1[-1] 138 | 139 | self.shared_layer2_b = backbone.layer2[:-1] 140 | self.shared_layer2_t = backbone.layer2[-1] 141 | 142 | self.shared_layer3_b = backbone.layer3[:-1] 143 | self.shared_layer3_t = backbone.layer3[-1] 144 | 145 | self.shared_layer4_b = backbone.layer4[:-1] 146 | self.shared_layer4_t = backbone.layer4[-1] 147 | 148 | # Define task specific attention modules using a similar bottleneck design in residual block 149 | self.encoder_att_1 = nn.ModuleList([self.att_layer(ch[0], ch[0] // 4, ch[0]) for _ in self.tasks]) 150 | self.encoder_att_2 = nn.ModuleList([self.att_layer(2 * ch[1], ch[1] // 4, ch[1]) for _ in self.tasks]) 151 | self.encoder_att_3 = nn.ModuleList([self.att_layer(2 * ch[2], ch[2] // 4, ch[2]) for _ in self.tasks]) 152 | self.encoder_att_4 = nn.ModuleList([self.att_layer(2 * ch[3], ch[3] // 4, ch[3]) for _ in self.tasks]) 153 | 154 | # Define task shared attention encoders using residual bottleneck layers 155 | self.encoder_block_att_1 = self.conv_layer(ch[0], ch[1] // 4) 156 | self.encoder_block_att_2 = self.conv_layer(ch[1], ch[2] // 4) 157 | self.encoder_block_att_3 = self.conv_layer(ch[2], ch[3] // 4) 158 | 159 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2) 160 | 161 | # Define task-specific decoders using ASPP modules 162 | self.decoders = nn.ModuleList([DeepLabHead(ch[-1], self.tasks[t]) for t in self.tasks]) 163 | 164 | def att_layer(self, in_channel, intermediate_channel, out_channel): 165 | return nn.Sequential( 166 | nn.Conv2d(in_channels=in_channel, out_channels=intermediate_channel, kernel_size=1, padding=0), 167 | nn.BatchNorm2d(intermediate_channel), 168 | nn.ReLU(inplace=True), 169 | nn.Conv2d(in_channels=intermediate_channel, out_channels=out_channel, kernel_size=1, padding=0), 170 | nn.BatchNorm2d(out_channel), 171 | nn.Sigmoid() 172 | ) 173 | 174 | def conv_layer(self, in_channel, out_channel): 175 | downsample = nn.Sequential(nn.Conv2d(in_channel, 4 * out_channel, 
kernel_size=1, stride=1, bias=False), 176 | nn.BatchNorm2d(4 * out_channel)) 177 | return resnet.Bottleneck(in_channel, out_channel, downsample=downsample) 178 | 179 | def forward(self, x): 180 | _, _, im_h, im_w = x.shape 181 | 182 | # Shared convolution 183 | x = self.shared_conv(x) 184 | 185 | # Shared ResNet block 1 186 | u_1_b = self.shared_layer1_b(x) 187 | u_1_t = self.shared_layer1_t(u_1_b) 188 | 189 | # Shared ResNet block 2 190 | u_2_b = self.shared_layer2_b(u_1_t) 191 | u_2_t = self.shared_layer2_t(u_2_b) 192 | 193 | # Shared ResNet block 3 194 | u_3_b = self.shared_layer3_b(u_2_t) 195 | u_3_t = self.shared_layer3_t(u_3_b) 196 | 197 | # Shared ResNet block 4 198 | u_4_b = self.shared_layer4_b(u_3_t) 199 | u_4_t = self.shared_layer4_t(u_4_b) 200 | 201 | # Attention block 1 -> Apply attention over last residual block 202 | a_1_mask = [att_i(u_1_b) for att_i in self.encoder_att_1] # Generate task specific attention map 203 | a_1 = [a_1_mask_i * u_1_t for a_1_mask_i in a_1_mask] # Apply task specific attention map to shared features 204 | a_1 = [self.down_sampling(self.encoder_block_att_1(a_1_i)) for a_1_i in a_1] 205 | 206 | # Attention block 2 -> Apply attention over last residual block 207 | a_2_mask = [att_i(torch.cat((u_2_b, a_1_i), dim=1)) for a_1_i, att_i in zip(a_1, self.encoder_att_2)] 208 | a_2 = [a_2_mask_i * u_2_t for a_2_mask_i in a_2_mask] 209 | a_2 = [self.encoder_block_att_2(a_2_i) for a_2_i in a_2] 210 | 211 | # Attention block 3 -> Apply attention over last residual block 212 | a_3_mask = [att_i(torch.cat((u_3_b, a_2_i), dim=1)) for a_2_i, att_i in zip(a_2, self.encoder_att_3)] 213 | a_3 = [a_3_mask_i * u_3_t for a_3_mask_i in a_3_mask] 214 | a_3 = [self.encoder_block_att_3(a_3_i) for a_3_i in a_3] 215 | 216 | # Attention block 4 -> Apply attention over last residual block (without final encoder) 217 | a_4_mask = [att_i(torch.cat((u_4_b, a_3_i), dim=1)) for a_3_i, att_i in zip(a_3, self.encoder_att_4)] 218 | a_4 = [a_4_mask_i * u_4_t for a_4_mask_i in a_4_mask] 219 | 220 | # Task specific decoders 221 | out = [0 for _ in self.tasks] 222 | for i, t in enumerate(self.tasks): 223 | out[i] = F.interpolate(self.decoders[i](a_4[i]), size=[im_h, im_w], mode='bilinear', align_corners=True) 224 | if t == 'normal': 225 | out[i] = out[i] / torch.norm(out[i], p=2, dim=1, keepdim=True) 226 | return out 227 | 228 | def shared_modules(self): 229 | return [self.shared_conv, 230 | self.shared_layer1_b, 231 | self.shared_layer1_t, 232 | self.shared_layer2_b, 233 | self.shared_layer2_t, 234 | self.shared_layer3_b, 235 | self.shared_layer3_t, 236 | self.shared_layer4_b, 237 | self.shared_layer4_t, 238 | self.encoder_block_att_1, 239 | self.encoder_block_att_2, 240 | self.encoder_block_att_3] 241 | 242 | def zero_grad_shared_modules(self): 243 | for mm in self.shared_modules(): 244 | mm.zero_grad() 245 | 246 | 247 | # -------------------------------------------------------------------------------- 248 | # Define Split DeepLab 249 | # -------------------------------------------------------------------------------- 250 | class MTLDeepLabv3(nn.Module): 251 | def __init__(self, tasks): 252 | super(MTLDeepLabv3, self).__init__() 253 | backbone = ResnetDilated(resnet.resnet50()) 254 | ch = [256, 512, 1024, 2048] 255 | 256 | self.tasks = tasks 257 | 258 | self.shared_conv = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu1, backbone.maxpool) 259 | self.shared_layer1 = backbone.layer1 260 | self.shared_layer2 = backbone.layer2 261 | self.shared_layer3 = backbone.layer3 262 | 
self.shared_layer4 = backbone.layer4 263 | 264 | # Define task-specific decoders using ASPP modules 265 | self.decoders = nn.ModuleList([DeepLabHead(ch[-1], self.tasks[t]) for t in self.tasks]) 266 | 267 | def forward(self, x): 268 | _, _, im_h, im_w = x.shape 269 | 270 | # Shared convolution 271 | x = self.shared_conv(x) 272 | x = self.shared_layer1(x) 273 | x = self.shared_layer2(x) 274 | x = self.shared_layer3(x) 275 | x = self.shared_layer4(x) 276 | 277 | # Task specific decoders 278 | out = [0 for _ in self.tasks] 279 | for i, t in enumerate(self.tasks): 280 | out[i] = F.interpolate(self.decoders[i](x), size=[im_h, im_w], mode='bilinear', align_corners=True) 281 | if t == 'normal': 282 | out[i] = out[i] / torch.norm(out[i], p=2, dim=1, keepdim=True) 283 | return out 284 | 285 | def shared_modules(self): 286 | return [self.shared_conv, 287 | self.shared_layer1, 288 | self.shared_layer2, 289 | self.shared_layer3, 290 | self.shared_layer4] 291 | 292 | def zero_grad_shared_modules(self): 293 | for mm in self.shared_modules(): 294 | mm.zero_grad() 295 | 296 | 297 | # -------------------------------------------------------------------------------- 298 | # Define VGG-16 (for CIFAR-100 experiments) 299 | # -------------------------------------------------------------------------------- 300 | class ConditionalBatchNorm2d(nn.Module): 301 | def __init__(self, num_features, num_classes): 302 | super().__init__() 303 | self.num_features = num_features 304 | self.bn_list = nn.ModuleList() 305 | 306 | for i in range(num_classes): 307 | self.bn_list.append(nn.BatchNorm2d(num_features)) 308 | 309 | def forward(self, x, y): 310 | out = self.bn_list[y](x) 311 | return out 312 | 313 | 314 | class MTLVGG16(nn.Module): 315 | def __init__(self, num_tasks): 316 | super(MTLVGG16, self).__init__() 317 | filter = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512] 318 | self.num_tasks = num_tasks 319 | 320 | # define VGG-16 block 321 | network_layers = [] 322 | channel_in = 3 323 | for ch in filter: 324 | if ch == 'M': 325 | network_layers += [nn.MaxPool2d(2, 2)] 326 | else: 327 | network_layers += [nn.Conv2d(channel_in, ch, kernel_size=3, padding=1), 328 | ConditionalBatchNorm2d(ch, num_tasks), 329 | nn.ReLU(inplace=True)] 330 | channel_in = ch 331 | 332 | self.network_block = nn.Sequential(*network_layers) 333 | 334 | # define classifiers here 335 | self.classifier = nn.ModuleList() 336 | for i in range(num_tasks): 337 | self.classifier.append(nn.Sequential(nn.Linear(filter[-1], 5))) 338 | 339 | def forward(self, x, task_id): 340 | for layer in self.network_block: 341 | if isinstance(layer, ConditionalBatchNorm2d): 342 | x = layer(x, task_id) 343 | else: 344 | x = layer(x) 345 | 346 | x = F.adaptive_avg_pool2d(x, 1) 347 | pred = self.classifier[task_id](x.view(x.shape[0], -1)) 348 | return pred 349 | 350 | 351 | -------------------------------------------------------------------------------- /dataset/cityscapes_preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from PIL import Image 4 | 5 | root = 'dataset' 6 | 7 | im_train_root = root + 'cityscapes/train/image' 8 | im_val_root = root + '/cityscapes/val/image' 9 | label_train_root = root + '/cityscapes/train/seg' 10 | label_val_root = root + '/cityscapes/val/seg' 11 | part_train_root = root + '/cityscapes/train/part_seg' 12 | part_val_root = root + '/cityscapes/val/part_seg' 13 | depth_train_root = root + '/cityscapes/train/depth' 14 | depth_val_root 
= root + '/cityscapes/val/depth' 15 | 16 | os.makedirs(im_train_root) 17 | os.makedirs(im_val_root) 18 | os.makedirs(label_train_root) 19 | os.makedirs(label_val_root) 20 | os.makedirs(part_train_root) 21 | os.makedirs(part_val_root) 22 | os.makedirs(depth_train_root) 23 | os.makedirs(depth_val_root) 24 | 25 | 26 | # Images 27 | train_im_list = glob.glob(root + '/leftImg8bit_trainvaltest/leftImg8bit/train/*') 28 | counter = 0 29 | for city in train_im_list: 30 | im_list = glob.glob(city + '/*.png') 31 | im_list.sort() 32 | for i in im_list: 33 | im = Image.open(i) 34 | im = im.resize((512, 256)) 35 | im.save(im_train_root + '/{}.png'.format(counter)) 36 | counter += 1 37 | print('Training RGB images processing has completed.') 38 | 39 | 40 | val_im_list = glob.glob(root + '/leftImg8bit_trainvaltest/leftImg8bit/val/*') 41 | counter = 0 42 | for city in val_im_list: 43 | im_list = glob.glob(city + '/*.png') 44 | im_list.sort() 45 | for i in im_list: 46 | im = Image.open(i) 47 | im = im.resize((512, 256)) 48 | im.save(im_val_root + '/{}.png'.format(counter)) 49 | counter += 1 50 | print('Validation RGB images processing has completed.') 51 | 52 | 53 | # Disparity 54 | train_im_list = glob.glob(root + '/disparity_trainvaltest/disparity/train/*') 55 | counter = 0 56 | for city in train_im_list: 57 | im_list = glob.glob(city + '/*.png') 58 | im_list.sort() 59 | for i in im_list: 60 | im = Image.open(i) 61 | im = im.resize((512, 256), resample=Image.NEAREST) 62 | im.save(depth_train_root + '/{}.png'.format(counter)) 63 | counter += 1 64 | print('Training depth images processing has completed.') 65 | 66 | 67 | val_im_list = glob.glob(root + '/disparity_trainvaltest/disparity/val/*') 68 | counter = 0 69 | for city in val_im_list: 70 | im_list = glob.glob(city + '/*.png') 71 | im_list.sort() 72 | for i in im_list: 73 | im = Image.open(i) 74 | im = im.resize((512, 256), resample=Image.NEAREST) 75 | im.save(depth_val_root + '/{}.png'.format(counter)) 76 | counter += 1 77 | print('Validation depth images processing has completed.') 78 | 79 | 80 | # Segmentation 81 | counter = 0 82 | train_label_list = glob.glob(root + '/gtFine_trainvaltest/gtFine/train/*') 83 | for city in train_label_list: 84 | label_list = glob.glob(city + '/*_labelIds.png') 85 | label_list.sort() 86 | for l in label_list: 87 | im = Image.open(l) 88 | im = im.resize((512, 256), resample=Image.NEAREST) 89 | im.save(label_train_root + '/{}.png'.format(counter)) 90 | counter += 1 91 | print('Training Label images processing has completed.') 92 | 93 | 94 | counter = 0 95 | val_label_list = glob.glob(root + '/gtFine_trainvaltest/gtFine/val/*') 96 | for city in val_label_list: 97 | label_list = glob.glob(city + '/*_labelIds.png') 98 | label_list.sort() 99 | for l in label_list: 100 | im = Image.open(l) 101 | im = im.resize((512, 256), resample=Image.NEAREST) 102 | im.save(label_val_root + '/{}.png'.format(counter)) 103 | counter += 1 104 | print('Validation Label images processing has completed.') 105 | 106 | 107 | # Part Segmentation 108 | counter = 0 109 | train_label_list = glob.glob(root + '/gtFinePanopticParts_trainval/gtFinePanopticParts/train/*') 110 | for city in train_label_list: 111 | label_list = glob.glob(city + '/*.tif') 112 | label_list.sort() 113 | for l in label_list: 114 | im = Image.open(l) 115 | im = im.resize((512, 256), resample=Image.NEAREST) 116 | im.save(part_train_root + '{}.tif'.format(counter)) 117 | counter += 1 118 | print('Training Label images processing has completed.') 119 | 120 | counter = 0 121 | 
val_label_list = glob.glob(root + '/gtFinePanopticParts_trainval/gtFinePanopticParts/val/*') 122 | for city in val_label_list: 123 | label_list = glob.glob(city + '/*.tif') 124 | label_list.sort() 125 | for l in label_list: 126 | im = Image.open(l) 127 | im = im.resize((512, 256), resample=Image.NEAREST) 128 | im.save(part_val_root + '/{}.tif'.format(counter)) 129 | counter += 1 130 | print('Validation Label images processing has completed.') 131 | 132 | -------------------------------------------------------------------------------- /trainer_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.optim as optim 4 | import torch.utils.data.sampler as sampler 5 | 6 | from auto_lambda import AutoLambda 7 | from create_network import * 8 | from create_dataset import * 9 | from utils import * 10 | 11 | parser = argparse.ArgumentParser(description='Multi-task/Auxiliary Learning: CIFAR-100') 12 | parser.add_argument('--mode', default='none', type=str) 13 | parser.add_argument('--port', default='none', type=str) 14 | 15 | parser.add_argument('--weight', default='equal', type=str, help='multi-task weighting: equal, dwa, uncert, autol') 16 | parser.add_argument('--gpu', default=0, type=int, help='gpu ID') 17 | parser.add_argument('--autol_init', default=0.1, type=float, help='initialisation for auto-lambda') 18 | parser.add_argument('--autol_lr', default=3e-4, type=float, help='learning rate for auto-lambda') 19 | parser.add_argument('--subset_id', default=0, type=int, help='domain id for cifar-100, -1 for MTL mode') 20 | parser.add_argument('--seed', default=0, type=int, help='random seed ID') 21 | 22 | opt = parser.parse_args() 23 | 24 | torch.manual_seed(opt.seed) 25 | np.random.seed(opt.seed) 26 | random.seed(opt.seed) 27 | 28 | # create logging folder to store training weights and losses 29 | if not os.path.exists('logging'): 30 | os.makedirs('logging') 31 | 32 | # define model, optimiser and scheduler 33 | device = torch.device("cuda:{}".format(opt.gpu) if torch.cuda.is_available() else "cpu") 34 | model = MTLVGG16(num_tasks=20).to(device) 35 | train_tasks = {'class_{}'.format(i): 5 for i in range(20)} 36 | pri_tasks = {'class_{}'.format(opt.subset_id): 5} if opt.subset_id >= 0 else train_tasks 37 | 38 | total_epoch = 200 39 | 40 | if opt.weight == 'autol': 41 | params = model.parameters() 42 | autol = AutoLambda(model, device, train_tasks, pri_tasks, opt.autol_init) 43 | meta_weight_ls = np.zeros([total_epoch, len(train_tasks)], dtype=np.float32) 44 | meta_optimizer = optim.Adam([autol.meta_weights], lr=opt.autol_lr) 45 | 46 | elif opt.weight in ['dwa', 'equal']: 47 | T = 2.0 # temperature used in dwa 48 | lambda_weight = np.ones([total_epoch, len(train_tasks)], dtype=np.float32) 49 | params = model.parameters() 50 | 51 | elif opt.weight == 'uncert': 52 | logsigma = torch.tensor([-0.7] * len(train_tasks), requires_grad=True, device=device) 53 | params = list(model.parameters()) + [logsigma] 54 | logsigma_ls = np.zeros([total_epoch, len(train_tasks)], dtype=np.float32) 55 | 56 | optimizer = optim.SGD(params, lr=0.1, weight_decay=5e-4, momentum=0.9) 57 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch) 58 | 59 | # define dataset 60 | trans_train = transforms.Compose([ 61 | transforms.RandomCrop(32, padding=4), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 65 | ]) 66 | 67 | trans_test = 
transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 70 | ]) 71 | 72 | train_sets = [CIFAR100MTL(root='dataset', train=True, transform=trans_train, subset_id=i) for i in range(20)] 73 | if opt.subset_id >= 0: 74 | test_set = CIFAR100MTL(root='dataset', train=False, transform=trans_test, subset_id=opt.subset_id) 75 | else: 76 | test_sets = [CIFAR100MTL(root='dataset', train=False, transform=trans_test, subset_id=i) for i in range(20)] 77 | 78 | batch_size = 32 79 | 80 | train_loaders = [torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2) 81 | for train_set in train_sets] 82 | 83 | # a copy of train_loader with different data order, used for Auto-Lambda meta-update 84 | if opt.weight == 'autol': 85 | val_loaders = [torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2) 86 | for train_set in train_sets] 87 | 88 | if opt.subset_id >= 0: 89 | test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=2) 90 | else: 91 | test_loaders = [torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True, num_workers=2) 92 | for test_set in test_sets] 93 | 94 | 95 | # Train and evaluate multi-task network 96 | if opt.subset_id >= 0: 97 | print('CIFAR-100 | Training Task: All Domains | Primary Task: {} in Multi-task / Auxiliary Learning Mode with VGG-16' 98 | .format(test_set.subset_class.title())) 99 | else: 100 | print('CIFAR-100 | Training Task: All Domains | Primary Task: All Domains in Multi-task / Auxiliary Learning Mode with VGG16') 101 | 102 | print('Applying Multi-task Methods: Weighting-based: {}' 103 | .format(opt.weight.title())) 104 | 105 | train_batch = len(train_loaders[0]) 106 | test_batch = len(test_loader) if opt.subset_id >= 0 else len(test_loaders[0]) 107 | train_metric = TaskMetric(train_tasks, pri_tasks, batch_size, total_epoch, 'cifar100') 108 | if opt.subset_id >= 0: 109 | test_metric = TaskMetric(train_tasks, pri_tasks, batch_size, total_epoch, 'cifar100') 110 | else: 111 | test_metric = TaskMetric(train_tasks, pri_tasks, batch_size, total_epoch, 'cifar100', include_mtl=True) 112 | 113 | for index in range(total_epoch): 114 | 115 | # apply Dynamic Weight Average 116 | if opt.weight == 'dwa': 117 | if index == 0 or index == 1: 118 | lambda_weight[index, :] = 1.0 119 | else: 120 | w = [] 121 | for i, t in enumerate(train_tasks): 122 | w += [train_metric.metric[t][index - 1, 0] / train_metric.metric[t][index - 2, 0]] 123 | w = torch.softmax(torch.tensor(w) / T, dim=0) 124 | lambda_weight[index] = len(train_tasks) * w.numpy() 125 | 126 | # evaluating train data 127 | model.train() 128 | train_datasets = [iter(train_loader) for train_loader in train_loaders] 129 | if opt.weight == 'autol': 130 | val_datasets = [iter(val_loader) for val_loader in val_loaders] 131 | for k in range(train_batch): 132 | train_datas = [] 133 | train_targets = {} 134 | for t in range(20): 135 | train_data, train_target = train_datasets[t].next() 136 | train_datas += [train_data.to(device)] 137 | train_targets['class_{}'.format(t)] = train_target.to(device) 138 | 139 | if opt.weight == 'autol': 140 | val_datas = [] 141 | val_targets = {} 142 | for t in range(20): 143 | val_data, val_target = val_datasets[t].next() 144 | val_datas += [val_data.to(device)] 145 | val_targets['class_{}'.format(t)] = val_target.to(device) 146 | 147 | meta_optimizer.zero_grad() 148 | 
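            # Auto-Lambda meta-update (implemented in auto_lambda.py): roughly, a
            # virtual one-step SGD update of the network is taken on the training
            # loss under the current task weights, the primary-task loss is then
            # evaluated on the re-shuffled validation batch with those looked-ahead
            # weights, and the gradient w.r.t. the meta-weights is obtained by
            # backpropagating through the unrolled step before meta_optimizer applies it.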
autol.unrolled_backward(train_datas, train_targets, val_datas, val_targets, 149 | scheduler.get_last_lr()[0], optimizer) 150 | meta_optimizer.step() 151 | 152 | optimizer.zero_grad() 153 | 154 | train_pred = [model(train_data, t) for t, train_data in enumerate(train_datas)] 155 | train_loss = [compute_loss(train_pred[t], train_targets[task_id], task_id) for t, task_id in enumerate(train_targets)] 156 | 157 | if opt.weight in ['equal', 'dwa']: 158 | loss = sum(w * train_loss[i] for i, w in enumerate(lambda_weight[index])) 159 | 160 | if opt.weight == 'autol': 161 | loss = sum(w * train_loss[i] for i, w in enumerate(autol.meta_weights)) 162 | 163 | if opt.weight == 'uncert': 164 | loss = sum(1 / (2 * torch.exp(w)) * train_loss[i] + w / 2 for i, w in enumerate(logsigma)) 165 | 166 | loss.backward() 167 | optimizer.step() 168 | 169 | train_metric.update_metric(train_pred, train_targets, train_loss) 170 | 171 | train_str = train_metric.compute_metric(only_pri=True) 172 | train_metric.reset() 173 | 174 | # evaluating test data 175 | model.eval() 176 | with torch.no_grad(): 177 | if opt.subset_id >= 0: 178 | test_dataset = iter(test_loader) 179 | for k in range(test_batch): 180 | test_data, test_target = test_dataset.next() 181 | test_data = test_data.to(device) 182 | test_target = test_target.to(device) 183 | 184 | test_pred = model(test_data, opt.subset_id) 185 | test_loss = F.cross_entropy(test_pred, test_target) 186 | 187 | test_metric.update_metric([test_pred], {'class_{}'.format(opt.subset_id): test_target}, [test_loss]) 188 | else: 189 | test_datasets = [iter(test_loader) for test_loader in test_loaders] 190 | for k in range(test_batch): 191 | test_datas = [] 192 | test_targets = {} 193 | for t in range(20): 194 | test_data, test_target = test_datasets[t].next() 195 | test_datas += [test_data.to(device)] 196 | test_targets['class_{}'.format(t)] = test_target.to(device) 197 | test_pred = [model(test_data, t) for t, test_data in enumerate(test_datas)] 198 | test_loss = [compute_loss(test_pred[t], test_targets[task_id], task_id) for t, task_id in enumerate(test_targets)] 199 | test_metric.update_metric(test_pred, test_targets, test_loss) 200 | 201 | test_str = test_metric.compute_metric(only_pri=True) 202 | test_metric.reset() 203 | 204 | scheduler.step() 205 | 206 | if opt.subset_id >= 0: 207 | print('Epoch {:04d} | TRAIN:{} || TEST:{} | Best: {} {:.4f}' 208 | .format(index, train_str, test_str, test_set.subset_class.title(), 209 | test_metric.get_best_performance('class_{}'.format(opt.subset_id)))) 210 | else: 211 | print('Epoch {:04d} | TRAIN:{} || TEST:{} | Best: All {:.4f}' 212 | .format(index, train_str, test_str, test_metric.get_best_performance('all'))) 213 | 214 | if opt.weight == 'autol': 215 | meta_weight_ls[index] = autol.meta_weights.detach().cpu() 216 | dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric, 217 | 'weight': meta_weight_ls} 218 | 219 | print(get_weight_str_ranked(meta_weight_ls[index], list(train_sets[0].class_dict.keys()), 4)) 220 | 221 | if opt.weight in ['dwa', 'equal']: 222 | dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric, 223 | 'weight': lambda_weight} 224 | print(get_weight_str_ranked(lambda_weight[index], list(train_sets[0].class_dict.keys()), 4)) 225 | 226 | if opt.weight == 'uncert': 227 | logsigma_ls[index] = logsigma.detach().cpu() 228 | dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric, 229 | 'weight': logsigma_ls} 230 | print(get_weight_str_ranked(1 / (2 * 
np.exp(logsigma_ls[index])), list(train_sets[0].class_dict.keys()), 4)) 231 | 232 | np.save('logging/mtl_cifar_{}_{}_{}.npy'.format(opt.subset_id, opt.weight, opt.seed), dict) 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /trainer_cifar_single.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.optim as optim 4 | import torch.utils.data.sampler as sampler 5 | 6 | from create_network import * 7 | from create_dataset import * 8 | from utils import * 9 | 10 | parser = argparse.ArgumentParser(description='Single-task: Split') 11 | parser.add_argument('--mode', default='none', type=str) 12 | parser.add_argument('--port', default='none', type=str) 13 | 14 | parser.add_argument('--gpu', default=0, type=int, help='gpu ID') 15 | parser.add_argument('--seed', default=0, type=int, help='gpu ID') 16 | parser.add_argument('--subset_id', default=0, type=int, help='mtan') 17 | 18 | opt = parser.parse_args() 19 | 20 | torch.manual_seed(opt.seed) 21 | np.random.seed(opt.seed) 22 | random.seed(opt.seed) 23 | 24 | # create logging folder to store training weights and losses 25 | if not os.path.exists('logging'): 26 | os.makedirs('logging') 27 | 28 | # define model, optimiser and scheduler 29 | device = torch.device("cuda:{}".format(opt.gpu) if torch.cuda.is_available() else "cpu") 30 | model = MTLVGG16(num_tasks=1).to(device) 31 | train_tasks = {'class_{}'.format(opt.subset_id): 5} 32 | 33 | total_epoch = 200 34 | optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4, momentum=0.9) 35 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch) 36 | 37 | # define dataset 38 | trans_train = transforms.Compose([ 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 43 | ]) 44 | 45 | trans_test = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]), 48 | ]) 49 | 50 | train_set = CIFAR100MTL(root='dataset', train=True, transform=trans_train, subset_id=opt.subset_id) 51 | test_set = CIFAR100MTL(root='dataset', train=False, transform=trans_test, subset_id=opt.subset_id) 52 | 53 | batch_size = 32 54 | 55 | train_loader = torch.utils.data.DataLoader( 56 | dataset=train_set, 57 | batch_size=batch_size, 58 | shuffle=True, 59 | num_workers=4 60 | ) 61 | 62 | test_loader = torch.utils.data.DataLoader( 63 | dataset=test_set, 64 | batch_size=batch_size, 65 | shuffle=False, 66 | num_workers=4 67 | ) 68 | 69 | 70 | # Train and evaluate multi-task network 71 | print('Training Task: CIFAR-100 - {} in Single Task Learning Mode with VGG-16'.format(train_set.subset_class.title())) 72 | 73 | train_batch = len(train_loader) 74 | test_batch = len(test_loader) 75 | train_metric = TaskMetric(train_tasks, train_tasks, batch_size, total_epoch, 'cifar100') 76 | test_metric = TaskMetric(train_tasks, train_tasks, batch_size, total_epoch, 'cifar100') 77 | for index in range(total_epoch): 78 | 79 | # evaluating train data 80 | model.train() 81 | train_dataset = iter(train_loader) 82 | for k in range(train_batch): 83 | train_data, train_target = train_dataset.next() 84 | train_data = train_data.to(device) 85 | train_target = train_target.to(device) 86 | 87 | train_pred = model(train_data, 0) 88 | 89 | optimizer.zero_grad() 90 | train_loss = F.cross_entropy(train_pred, 
train_target) 91 | train_loss.backward() 92 | optimizer.step() 93 | 94 | train_metric.update_metric([train_pred], {'class_{}'.format(opt.subset_id): train_target}, [train_loss]) 95 | 96 | train_str = train_metric.compute_metric() 97 | train_metric.reset() 98 | 99 | # evaluating test data 100 | model.eval() 101 | with torch.no_grad(): 102 | test_dataset = iter(test_loader) 103 | for k in range(test_batch): 104 | test_data, test_target = test_dataset.next() 105 | test_data = test_data.to(device) 106 | test_target = test_target.to(device) 107 | 108 | test_pred = model(test_data, 0) 109 | test_loss = F.cross_entropy(test_pred, test_target) 110 | 111 | test_metric.update_metric([test_pred], {'class_{}'.format(opt.subset_id): test_target}, [test_loss]) 112 | 113 | test_str = test_metric.compute_metric() 114 | test_metric.reset() 115 | 116 | scheduler.step() 117 | 118 | print('Epoch {:04d} | TRAIN:{} || TEST:{} | Best: {} {:.4f}' 119 | .format(index, train_str, test_str, train_set.subset_class.title(), 120 | test_metric.get_best_performance('class_{}'.format(opt.subset_id)))) 121 | 122 | task_dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric} 123 | np.save('logging/stl_cifar_{}_{}.npy'.format(opt.subset_id, opt.seed), task_dict) 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /trainer_dense.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.optim as optim 4 | import torch.utils.data.sampler as sampler 5 | 6 | from auto_lambda import AutoLambda 7 | from create_network import * 8 | from create_dataset import * 9 | from utils import * 10 | 11 | parser = argparse.ArgumentParser(description='Multi-task/Auxiliary Learning: Dense Prediction Tasks') 12 | parser.add_argument('--mode', default='none', type=str) 13 | parser.add_argument('--port', default='none', type=str) 14 | 15 | parser.add_argument('--network', default='split', type=str, help='split, mtan') 16 | parser.add_argument('--weight', default='equal', type=str, help='weighting methods: equal, dwa, uncert, autol') 17 | parser.add_argument('--grad_method', default='none', type=str, help='graddrop, pcgrad, cagrad') 18 | parser.add_argument('--gpu', default=0, type=int, help='gpu ID') 19 | parser.add_argument('--with_noise', action='store_true', help='with noise prediction task') 20 | parser.add_argument('--autol_init', default=0.1, type=float, help='initialisation for auto-lambda') 21 | parser.add_argument('--autol_lr', default=1e-4, type=float, help='learning rate for auto-lambda') 22 | parser.add_argument('--task', default='all', type=str, help='primary tasks, use all for MTL setting') 23 | parser.add_argument('--dataset', default='nyuv2', type=str, help='nyuv2, cityscapes') 24 | parser.add_argument('--seed', default=0, type=int, help='random seed ID') 25 | 26 | opt = parser.parse_args() 27 | 28 | torch.manual_seed(opt.seed) 29 | np.random.seed(opt.seed) 30 | random.seed(opt.seed) 31 | 32 | # create logging folder to store training weights and losses 33 | if not os.path.exists('logging'): 34 | os.makedirs('logging') 35 | 36 | # define model, optimiser and scheduler 37 | device = torch.device("cuda:{}".format(opt.gpu) if torch.cuda.is_available() else "cpu") 38 | if opt.with_noise: 39 | train_tasks = create_task_flags('all', opt.dataset, with_noise=True) 40 | else: 41 | train_tasks = create_task_flags('all', opt.dataset, with_noise=False) 42 | 43 | pri_tasks = create_task_flags(opt.task, 
opt.dataset, with_noise=False) 44 | 45 | train_tasks_str = ''.join(task.title() + ' + ' for task in train_tasks.keys())[:-3] 46 | pri_tasks_str = ''.join(task.title() + ' + ' for task in pri_tasks.keys())[:-3] 47 | print('Dataset: {} | Training Task: {} | Primary Task: {} in Multi-task / Auxiliary Learning Mode with {}' 48 | .format(opt.dataset.title(), train_tasks_str, pri_tasks_str, opt.network.upper())) 49 | print('Applying Multi-task Methods: Weighting-based: {} + Gradient-based: {}' 50 | .format(opt.weight.title(), opt.grad_method.upper())) 51 | 52 | if opt.network == 'split': 53 | model = MTLDeepLabv3(train_tasks).to(device) 54 | elif opt.network == 'mtan': 55 | model = MTANDeepLabv3(train_tasks).to(device) 56 | 57 | total_epoch = 200 58 | 59 | # choose task weighting here 60 | if opt.weight == 'uncert': 61 | logsigma = torch.tensor([-0.7] * len(train_tasks), requires_grad=True, device=device) 62 | params = list(model.parameters()) + [logsigma] 63 | logsigma_ls = np.zeros([total_epoch, len(train_tasks)], dtype=np.float32) 64 | 65 | if opt.weight in ['dwa', 'equal']: 66 | T = 2.0 # temperature used in dwa 67 | lambda_weight = np.ones([total_epoch, len(train_tasks)]) 68 | params = model.parameters() 69 | 70 | if opt.weight == 'autol': 71 | params = model.parameters() 72 | autol = AutoLambda(model, device, train_tasks, pri_tasks, opt.autol_init) 73 | meta_weight_ls = np.zeros([total_epoch, len(train_tasks)], dtype=np.float32) 74 | meta_optimizer = optim.Adam([autol.meta_weights], lr=opt.autol_lr) 75 | 76 | optimizer = optim.SGD(params, lr=0.1, weight_decay=1e-4, momentum=0.9) 77 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch) 78 | 79 | # define dataset 80 | if opt.dataset == 'nyuv2': 81 | dataset_path = 'dataset/nyuv2' 82 | train_set = NYUv2(root=dataset_path, train=True, augmentation=True) 83 | test_set = NYUv2(root=dataset_path, train=False) 84 | batch_size = 4 85 | 86 | elif opt.dataset == 'cityscapes': 87 | dataset_path = 'dataset/cityscapes' 88 | train_set = CityScapes(root=dataset_path, train=True, augmentation=True) 89 | test_set = CityScapes(root=dataset_path, train=False) 90 | batch_size = 4 91 | 92 | train_loader = torch.utils.data.DataLoader( 93 | dataset=train_set, 94 | batch_size=batch_size, 95 | shuffle=True, 96 | num_workers=4 97 | ) 98 | 99 | # a copy of train_loader with different data order, used for Auto-Lambda meta-update 100 | if opt.weight == 'autol': 101 | val_loader = torch.utils.data.DataLoader( 102 | dataset=train_set, 103 | batch_size=batch_size, 104 | shuffle=True, 105 | num_workers=4 106 | ) 107 | 108 | test_loader = torch.utils.data.DataLoader( 109 | dataset=test_set, 110 | batch_size=batch_size, 111 | shuffle=False 112 | ) 113 | 114 | # apply gradient methods 115 | if opt.grad_method != 'none': 116 | rng = np.random.default_rng() 117 | grad_dims = [] 118 | for mm in model.shared_modules(): 119 | for param in mm.parameters(): 120 | grad_dims.append(param.data.numel()) 121 | grads = torch.Tensor(sum(grad_dims), len(train_tasks)).to(device) 122 | 123 | 124 | # Train and evaluate multi-task network 125 | train_batch = len(train_loader) 126 | test_batch = len(test_loader) 127 | train_metric = TaskMetric(train_tasks, pri_tasks, batch_size, total_epoch, opt.dataset) 128 | test_metric = TaskMetric(train_tasks, pri_tasks, batch_size, total_epoch, opt.dataset, include_mtl=True) 129 | for index in range(total_epoch): 130 | 131 | # apply Dynamic Weight Average 132 | if opt.weight == 'dwa': 133 | if index == 0 or index == 1: 134 | 
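            # DWA weights each task by the ratio of its loss over the previous two
            # epochs (softened by temperature T); no history exists for the first
            # two epochs, so all tasks fall back to equal weights of 1.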
lambda_weight[index, :] = 1.0 135 | else: 136 | w = [] 137 | for i, t in enumerate(train_tasks): 138 | w += [train_metric.metric[t][index - 1, 0] / train_metric.metric[t][index - 2, 0]] 139 | w = torch.softmax(torch.tensor(w) / T, dim=0) 140 | lambda_weight[index] = len(train_tasks) * w.numpy() 141 | 142 | # iteration for all batches 143 | model.train() 144 | train_dataset = iter(train_loader) 145 | if opt.weight == 'autol': 146 | val_dataset = iter(val_loader) 147 | 148 | for k in range(train_batch): 149 | train_data, train_target = train_dataset.next() 150 | train_data = train_data.to(device) 151 | train_target = {task_id: train_target[task_id].to(device) for task_id in train_tasks.keys()} 152 | 153 | # update meta-weights with Auto-Lambda 154 | if opt.weight == 'autol': 155 | val_data, val_target = val_dataset.next() 156 | val_data = val_data.to(device) 157 | val_target = {task_id: val_target[task_id].to(device) for task_id in train_tasks.keys()} 158 | 159 | meta_optimizer.zero_grad() 160 | autol.unrolled_backward(train_data, train_target, val_data, val_target, 161 | scheduler.get_last_lr()[0], optimizer) 162 | meta_optimizer.step() 163 | 164 | # update multi-task network parameters with task weights 165 | optimizer.zero_grad() 166 | train_pred = model(train_data) 167 | train_loss = [compute_loss(train_pred[i], train_target[task_id], task_id) for i, task_id in enumerate(train_tasks)] 168 | 169 | train_loss_tmp = [0] * len(train_tasks) 170 | 171 | if opt.weight in ['equal', 'dwa']: 172 | train_loss_tmp = [w * train_loss[i] for i, w in enumerate(lambda_weight[index])] 173 | 174 | if opt.weight == 'uncert': 175 | train_loss_tmp = [1 / (2 * torch.exp(w)) * train_loss[i] + w / 2 for i, w in enumerate(logsigma)] 176 | 177 | if opt.weight == 'autol': 178 | train_loss_tmp = [w * train_loss[i] for i, w in enumerate(autol.meta_weights)] 179 | 180 | loss = sum(train_loss_tmp) 181 | 182 | if opt.grad_method == 'none': 183 | loss.backward() 184 | optimizer.step() 185 | 186 | # gradient-based methods applied here: 187 | elif opt.grad_method == "graddrop": 188 | for i in range(len(train_tasks)): 189 | train_loss_tmp[i].backward(retain_graph=True) 190 | grad2vec(model, grads, grad_dims, i) 191 | model.zero_grad_shared_modules() 192 | g = graddrop(grads) 193 | overwrite_grad(model, g, grad_dims, len(train_tasks)) 194 | optimizer.step() 195 | 196 | elif opt.grad_method == "pcgrad": 197 | for i in range(len(train_tasks)): 198 | train_loss_tmp[i].backward(retain_graph=True) 199 | grad2vec(model, grads, grad_dims, i) 200 | model.zero_grad_shared_modules() 201 | g = pcgrad(grads, rng, len(train_tasks)) 202 | overwrite_grad(model, g, grad_dims, len(train_tasks)) 203 | optimizer.step() 204 | 205 | elif opt.grad_method == "cagrad": 206 | for i in range(len(train_tasks)): 207 | train_loss_tmp[i].backward(retain_graph=True) 208 | grad2vec(model, grads, grad_dims, i) 209 | model.zero_grad_shared_modules() 210 | g = cagrad(grads, len(train_tasks), 0.4, rescale=1) 211 | overwrite_grad(model, g, grad_dims, len(train_tasks)) 212 | optimizer.step() 213 | 214 | train_metric.update_metric(train_pred, train_target, train_loss) 215 | 216 | train_str = train_metric.compute_metric() 217 | train_metric.reset() 218 | 219 | # evaluating test data 220 | model.eval() 221 | with torch.no_grad(): 222 | test_dataset = iter(test_loader) 223 | for k in range(test_batch): 224 | test_data, test_target = test_dataset.next() 225 | test_data = test_data.to(device) 226 | test_target = {task_id: test_target[task_id].to(device) for task_id in 
train_tasks.keys()} 227 | 228 | test_pred = model(test_data) 229 | test_loss = [compute_loss(test_pred[i], test_target[task_id], task_id) for i, task_id in 230 | enumerate(train_tasks)] 231 | 232 | test_metric.update_metric(test_pred, test_target, test_loss) 233 | 234 | test_str = test_metric.compute_metric() 235 | test_metric.reset() 236 | 237 | scheduler.step() 238 | 239 | print('Epoch {:04d} | TRAIN:{} || TEST:{} | Best: {} {:.4f}' 240 | .format(index, train_str, test_str, opt.task.title(), test_metric.get_best_performance(opt.task))) 241 | 242 | if opt.weight == 'autol': 243 | meta_weight_ls[index] = autol.meta_weights.detach().cpu() 244 | dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric, 245 | 'weight': meta_weight_ls} 246 | 247 | print(get_weight_str(meta_weight_ls[index], train_tasks)) 248 | 249 | if opt.weight in ['dwa', 'equal']: 250 | dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric, 251 | 'weight': lambda_weight} 252 | 253 | print(get_weight_str(lambda_weight[index], train_tasks)) 254 | 255 | if opt.weight == 'uncert': 256 | logsigma_ls[index] = logsigma.detach().cpu() 257 | dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric, 258 | 'weight': logsigma_ls} 259 | 260 | print(get_weight_str(1 / (2 * np.exp(logsigma_ls[index])), train_tasks)) 261 | 262 | np.save('logging/mtl_dense_{}_{}_{}_{}_{}_{}_.npy' 263 | .format(opt.network, opt.dataset, opt.task, opt.weight, opt.grad_method, opt.seed), dict) 264 | -------------------------------------------------------------------------------- /trainer_dense_single.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.optim as optim 4 | import torch.utils.data.sampler as sampler 5 | 6 | from create_network import * 7 | from create_dataset import * 8 | from utils import * 9 | 10 | parser = argparse.ArgumentParser(description='Single-task Learning: Dense Prediction Tasks') 11 | parser.add_argument('--mode', default='none', type=str) 12 | parser.add_argument('--port', default='none', type=str) 13 | 14 | parser.add_argument('--gpu', default=0, type=int, help='gpu ID') 15 | parser.add_argument('--network', default='split', type=str, help='split, mtan') 16 | parser.add_argument('--dataset', default='nyuv2', type=str, help='nyuv2, cityscapes') 17 | parser.add_argument('--task', default='seg', type=str, help='choose task for single task learning') 18 | parser.add_argument('--seed', default=0, type=int, help='gpu ID') 19 | 20 | opt = parser.parse_args() 21 | 22 | torch.manual_seed(opt.seed) 23 | np.random.seed(opt.seed) 24 | random.seed(opt.seed) 25 | 26 | # create logging folder to store training weights and losses 27 | if not os.path.exists('logging'): 28 | os.makedirs('logging') 29 | 30 | # define model, optimiser and scheduler 31 | device = torch.device("cuda:{}".format(opt.gpu) if torch.cuda.is_available() else "cpu") 32 | train_tasks = create_task_flags(opt.task, opt.dataset) 33 | 34 | print('Training Task: {} - {} in Single Task Learning Mode with {}' 35 | .format(opt.dataset.title(), opt.task.title(), opt.network.upper())) 36 | 37 | if opt.network == 'split': 38 | model = MTLDeepLabv3(train_tasks).to(device) 39 | elif opt.network == 'mtan': 40 | model = MTANDeepLabv3(train_tasks).to(device) 41 | 42 | 43 | total_epoch = 200 44 | optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9) 45 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch) 46 | 47 | # 
define dataset 48 | if opt.dataset == 'nyuv2': 49 | dataset_path = 'dataset/nyuv2' 50 | train_set = NYUv2(root=dataset_path, train=True, augmentation=True) 51 | test_set = NYUv2(root=dataset_path, train=False) 52 | batch_size = 4 53 | 54 | elif opt.dataset == 'cityscapes': 55 | dataset_path = 'dataset/cityscapes' 56 | train_set = CityScapes(root=dataset_path, train=True, augmentation=True) 57 | test_set = CityScapes(root=dataset_path, train=False) 58 | batch_size = 4 59 | 60 | train_loader = torch.utils.data.DataLoader( 61 | dataset=train_set, 62 | batch_size=batch_size, 63 | shuffle=True, 64 | num_workers=4 65 | ) 66 | 67 | test_loader = torch.utils.data.DataLoader( 68 | dataset=test_set, 69 | batch_size=batch_size, 70 | shuffle=False, 71 | num_workers=4 72 | ) 73 | 74 | 75 | # Train and evaluate multi-task network 76 | train_batch = len(train_loader) 77 | test_batch = len(test_loader) 78 | train_metric = TaskMetric(train_tasks, train_tasks, batch_size, total_epoch, opt.dataset) 79 | test_metric = TaskMetric(train_tasks, train_tasks, batch_size, total_epoch, opt.dataset) 80 | for index in range(total_epoch): 81 | 82 | # evaluating train data 83 | model.train() 84 | train_dataset = iter(train_loader) 85 | for k in range(train_batch): 86 | train_data, train_target = train_dataset.next() 87 | train_data = train_data.to(device) 88 | train_target = {task_id: train_target[task_id].to(device) for task_id in train_tasks.keys()} 89 | 90 | train_pred = model(train_data) 91 | optimizer.zero_grad() 92 | 93 | train_loss = [compute_loss(train_pred[i], train_target[task_id], task_id) for i, task_id in enumerate(train_tasks)] 94 | train_loss[0].backward() 95 | optimizer.step() 96 | 97 | train_metric.update_metric(train_pred, train_target, train_loss) 98 | 99 | train_str = train_metric.compute_metric() 100 | train_metric.reset() 101 | 102 | # evaluating test data 103 | model.eval() 104 | with torch.no_grad(): 105 | test_dataset = iter(test_loader) 106 | for k in range(test_batch): 107 | test_data, test_target = test_dataset.next() 108 | test_data = test_data.to(device) 109 | test_target = {task_id: test_target[task_id].to(device) for task_id in train_tasks.keys()} 110 | 111 | test_pred = model(test_data) 112 | test_loss = [compute_loss(test_pred[i], test_target[task_id], task_id) for i, task_id in enumerate(train_tasks)] 113 | 114 | test_metric.update_metric(test_pred, test_target, test_loss) 115 | 116 | test_str = test_metric.compute_metric() 117 | test_metric.reset() 118 | 119 | scheduler.step() 120 | 121 | print('Epoch {:04d} | TRAIN:{} || TEST:{} | Best: {} {:.4f}' 122 | .format(index, train_str, test_str, opt.task.title(), test_metric.get_best_performance(opt.task))) 123 | 124 | task_dict = {'train_loss': train_metric.metric, 'test_loss': test_metric.metric} 125 | np.save('logging/stl_{}_{}_{}_{}.npy'.format(opt.network, opt.dataset, opt.task, opt.seed), task_dict) 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from scipy.optimize import minimize 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | """ 10 | Define task metrics, loss functions and model trainer here. 11 | """ 12 | 13 | 14 | class ConfMatrix(object): 15 | """ 16 | For mIoU and other pixel-level classification tasks. 
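    Accumulates a confusion matrix of (ground-truth, predicted) pixel counts;
    per-class IoU is diag / (row sum + column sum - diag), and mIoU is its mean.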
17 | """ 18 | def __init__(self, num_classes): 19 | self.num_classes = num_classes 20 | self.mat = None 21 | 22 | def reset(self): 23 | self.mat = None 24 | 25 | def update(self, pred, target): 26 | n = self.num_classes 27 | if self.mat is None: 28 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 29 | with torch.no_grad(): 30 | k = (target >= 0) & (target < n) 31 | inds = n * target[k].to(torch.int64) + pred[k] 32 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) 33 | 34 | def get_metrics(self): 35 | h = self.mat.float() 36 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 37 | return torch.mean(iu).item() 38 | 39 | 40 | def create_task_flags(task, dataset, with_noise=False): 41 | """ 42 | Record task and its prediction dimension. 43 | Noise prediction is only applied in auxiliary learning. 44 | """ 45 | nyu_tasks = {'seg': 13, 'depth': 1, 'normal': 3} 46 | cityscapes_tasks = {'seg': 19, 'part_seg': 10, 'disp': 1} 47 | 48 | tasks = {} 49 | if task != 'all': 50 | if dataset == 'nyuv2': 51 | tasks[task] = nyu_tasks[task] 52 | elif dataset == 'cityscapes': 53 | tasks[task] = cityscapes_tasks[task] 54 | else: 55 | if dataset == 'nyuv2': 56 | tasks = nyu_tasks 57 | elif dataset == 'cityscapes': 58 | tasks = cityscapes_tasks 59 | 60 | if with_noise: 61 | tasks['noise'] = 1 62 | return tasks 63 | 64 | 65 | def get_weight_str(weight, tasks): 66 | """ 67 | Record task weighting. 68 | """ 69 | weight_str = 'Task Weighting | ' 70 | for i, task_id in enumerate(tasks): 71 | weight_str += '{} {:.04f} '.format(task_id.title(), weight[i]) 72 | return weight_str 73 | 74 | 75 | def get_weight_str_ranked(weight, tasks, rank_num): 76 | """ 77 | Record top-k ranked task weighting. 78 | """ 79 | rank_idx = np.argsort(weight) 80 | 81 | if type(tasks) == dict: 82 | tasks = list(tasks.keys()) 83 | 84 | top_str = 'Top {}: '.format(rank_num) 85 | bot_str = 'Bottom {}: '.format(rank_num) 86 | for i in range(rank_num): 87 | top_str += '{} {:.02f} '.format(tasks[rank_idx[-i-1]].title(), weight[rank_idx[-i-1]]) 88 | bot_str += '{} {:.02f} '.format(tasks[rank_idx[i]].title(), weight[rank_idx[i]]) 89 | 90 | return 'Task Weighting | {}| {}'.format(top_str, bot_str) 91 | 92 | 93 | def compute_loss(pred, gt, task_id): 94 | """ 95 | Compute task-specific loss. 
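    Segmentation and classification tasks use cross-entropy with ignore_index=-1.
    Regression tasks (depth, disparity, normals, noise) use an L1 loss averaged
    over valid pixels only, where a pixel is invalid if its ground truth sums to
    0 across channels (-1 for disparity).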
96 | """ 97 | if task_id in ['seg', 'part_seg'] or 'class' in task_id: 98 | # Cross Entropy Loss with Ignored Index (values are -1) 99 | loss = F.cross_entropy(pred, gt, ignore_index=-1) 100 | 101 | if task_id in ['normal', 'depth', 'disp', 'noise']: 102 | # L1 Loss with Ignored Region (values are 0 or -1) 103 | invalid_idx = -1 if task_id == 'disp' else 0 104 | valid_mask = (torch.sum(gt, dim=1, keepdim=True) != invalid_idx).to(pred.device) 105 | loss = torch.sum(F.l1_loss(pred, gt, reduction='none').masked_select(valid_mask)) \ 106 | / torch.nonzero(valid_mask, as_tuple=False).size(0) 107 | return loss 108 | 109 | 110 | class TaskMetric: 111 | def __init__(self, train_tasks, pri_tasks, batch_size, epochs, dataset, include_mtl=False): 112 | self.train_tasks = train_tasks 113 | self.pri_tasks = pri_tasks 114 | self.batch_size = batch_size 115 | self.dataset = dataset 116 | self.include_mtl = include_mtl 117 | self.metric = {key: np.zeros([epochs, 2]) for key in train_tasks.keys()} # record loss & task-specific metric 118 | self.data_counter = 0 119 | self.epoch_counter = 0 120 | self.conf_mtx = {} 121 | 122 | if include_mtl: # include multi-task performance (relative averaged task improvement) 123 | self.metric['all'] = np.zeros(epochs) 124 | for task in self.train_tasks: 125 | if task in ['seg', 'part_seg']: 126 | self.conf_mtx[task] = ConfMatrix(self.train_tasks[task]) 127 | 128 | def reset(self): 129 | """ 130 | Reset data counter and confusion matrices. 131 | """ 132 | self.epoch_counter += 1 133 | self.data_counter = 0 134 | 135 | if len(self.conf_mtx) > 0: 136 | for i in self.conf_mtx: 137 | self.conf_mtx[i].reset() 138 | 139 | def update_metric(self, task_pred, task_gt, task_loss): 140 | """ 141 | Update batch-wise metric for each task. 142 | :param task_pred: [TASK_PRED1, TASK_PRED2, ...] 143 | :param task_gt: {'TASK_ID1': TASK_GT1, 'TASK_ID2': TASK_GT2, ...} 144 | :param task_loss: [TASK_LOSS1, TASK_LOSS2, ...] 145 | """ 146 | curr_bs = task_pred[0].shape[0] 147 | r = self.data_counter / (self.data_counter + curr_bs / self.batch_size) 148 | e = self.epoch_counter 149 | self.data_counter += 1 150 | 151 | with torch.no_grad(): 152 | for loss, pred, (task_id, gt) in zip(task_loss, task_pred, task_gt.items()): 153 | self.metric[task_id][e, 0] = r * self.metric[task_id][e, 0] + (1 - r) * loss.item() 154 | 155 | if task_id in ['seg', 'part_seg']: 156 | # update confusion matrix (metric will be computed directly in the Confusion Matrix) 157 | self.conf_mtx[task_id].update(pred.argmax(1).flatten(), gt.flatten()) 158 | 159 | if 'class' in task_id: 160 | # Accuracy for image classification tasks 161 | pred_label = pred.data.max(1)[1] 162 | acc = pred_label.eq(gt).sum().item() / pred_label.shape[0] 163 | self.metric[task_id][e, 1] = r * self.metric[task_id][e, 1] + (1 - r) * acc 164 | 165 | if task_id in ['depth', 'disp', 'noise']: 166 | # Abs. Err. 167 | invalid_idx = -1 if task_id == 'disp' else 0 168 | valid_mask = (torch.sum(gt, dim=1, keepdim=True) != invalid_idx).to(pred.device) 169 | abs_err = torch.mean(torch.abs(pred - gt).masked_select(valid_mask)).item() 170 | self.metric[task_id][e, 1] = r * self.metric[task_id][e, 1] + (1 - r) * abs_err 171 | 172 | if task_id in ['normal']: 173 | # Mean Degree Err. 
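                    # Angular error: arccos of the per-pixel dot product between predicted
                    # and ground-truth normals (clamped to [-1, 1] for numerical safety),
                    # converted to degrees and averaged over pixels with a non-zero
                    # ground-truth normal.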
174 | valid_mask = (torch.sum(gt, dim=1) != 0).to(pred.device) 175 | degree_error = torch.acos(torch.clamp(torch.sum(pred * gt, dim=1).masked_select(valid_mask), -1, 1)) 176 | mean_error = torch.mean(torch.rad2deg(degree_error)).item() 177 | self.metric[task_id][e, 1] = r * self.metric[task_id][e, 1] + (1 - r) * mean_error 178 | 179 | def compute_metric(self, only_pri=False): 180 | metric_str = '' 181 | e = self.epoch_counter 182 | tasks = self.pri_tasks if only_pri else self.train_tasks # only print primary tasks performance in evaluation 183 | 184 | for task_id in tasks: 185 | if task_id in ['seg', 'part_seg']: # mIoU for segmentation 186 | self.metric[task_id][e, 1] = self.conf_mtx[task_id].get_metrics() 187 | 188 | metric_str += ' {} {:.4f} {:.4f}'\ 189 | .format(task_id.capitalize(), self.metric[task_id][e, 0], self.metric[task_id][e, 1]) 190 | 191 | if self.include_mtl: 192 | # Pre-computed single task learning performance using trainer_dense_single.py 193 | if self.dataset == 'nyuv2': 194 | stl = {'seg': 0.4337, 'depth': 0.5224, 'normal': 22.40} 195 | elif self.dataset == 'cityscapes': 196 | stl = {'seg': 0.5620, 'part_seg': 0.5274, 'disp': 0.84} 197 | elif self.dataset == 'cifar100': 198 | stl = {'class_0': 0.6865, 'class_1': 0.8100, 'class_2': 0.8234, 'class_3': 0.8371, 'class_4': 0.8910, 199 | 'class_5': 0.8872, 'class_6': 0.8475, 'class_7': 0.8588, 'class_8': 0.8707, 'class_9': 0.9015, 200 | 'class_10': 0.8976, 'class_11': 0.8488, 'class_12': 0.9033, 'class_13': 0.8441, 'class_14': 0.5537, 201 | 'class_15': 0.7584, 'class_16': 0.7279, 'class_17': 0.7537, 'class_18': 0.9148, 'class_19': 0.9469} 202 | 203 | delta_mtl = 0 204 | for task_id in self.train_tasks: 205 | if task_id in ['seg', 'part_seg'] or 'class' in task_id: # higher better 206 | delta_mtl += (self.metric[task_id][e, 1] - stl[task_id]) / stl[task_id] 207 | elif task_id in ['depth', 'normal', 'disp']: 208 | delta_mtl -= (self.metric[task_id][e, 1] - stl[task_id]) / stl[task_id] 209 | 210 | self.metric['all'][e] = delta_mtl / len(stl) 211 | metric_str += ' | All {:.4f}'.format(self.metric['all'][e]) 212 | return metric_str 213 | 214 | def get_best_performance(self, task): 215 | e = self.epoch_counter 216 | if task in ['seg', 'part_seg'] or 'class' in task: # higher better 217 | return max(self.metric[task][:e, 1]) 218 | if task in ['depth', 'normal', 'disp']: # lower better 219 | return min(self.metric[task][:e, 1]) 220 | if task in ['all']: # higher better 221 | return max(self.metric[task][:e]) 222 | 223 | 224 | """ 225 | Define Gradient-based frameworks here. 226 | Based on https://github.com/Cranial-XIX/CAGrad/blob/main/cityscapes/utils.py 227 | """ 228 | 229 | 230 | def graddrop(grads): 231 | P = 0.5 * (1. 
+ grads.sum(1) / (grads.abs().sum(1) + 1e-8)) 232 | U = torch.rand_like(grads[:, 0]) 233 | M = P.gt(U).view(-1, 1) * grads.gt(0) + P.lt(U).view(-1, 1) * grads.lt(0) 234 | g = (grads * M.float()).mean(1) 235 | return g 236 | 237 | 238 | def pcgrad(grads, rng, num_tasks): 239 | grad_vec = grads.t() 240 | 241 | shuffled_task_indices = np.zeros((num_tasks, num_tasks - 1), dtype=int) 242 | for i in range(num_tasks): 243 | task_indices = np.arange(num_tasks) 244 | task_indices[i] = task_indices[-1] 245 | shuffled_task_indices[i] = task_indices[:-1] 246 | rng.shuffle(shuffled_task_indices[i]) 247 | shuffled_task_indices = shuffled_task_indices.T 248 | 249 | normalized_grad_vec = grad_vec / (grad_vec.norm(dim=1, keepdim=True) + 1e-8) # num_tasks x dim 250 | modified_grad_vec = deepcopy(grad_vec) 251 | for task_indices in shuffled_task_indices: 252 | normalized_shuffled_grad = normalized_grad_vec[task_indices] # num_tasks x dim 253 | dot = (modified_grad_vec * normalized_shuffled_grad).sum(dim=1, keepdim=True) # num_tasks x dim 254 | modified_grad_vec -= torch.clamp_max(dot, 0) * normalized_shuffled_grad 255 | g = modified_grad_vec.mean(dim=0) 256 | return g 257 | 258 | 259 | def cagrad(grads, num_tasks, alpha=0.5, rescale=1): 260 | GG = grads.t().mm(grads).cpu() # [num_tasks, num_tasks] 261 | g0_norm = (GG.mean() + 1e-8).sqrt() # norm of the average gradient 262 | 263 | x_start = np.ones(num_tasks) / num_tasks 264 | bnds = tuple((0, 1) for x in x_start) 265 | cons = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)}) 266 | A = GG.numpy() 267 | b = x_start.copy() 268 | c = (alpha * g0_norm + 1e-8).item() 269 | 270 | def objfn(x): 271 | return (x.reshape(1, num_tasks).dot(A).dot(b.reshape(num_tasks, 1)) + c * np.sqrt( 272 | x.reshape(1, num_tasks).dot(A).dot(x.reshape(num_tasks, 1)) + 1e-8)).sum() 273 | 274 | res = minimize(objfn, x_start, bounds=bnds, constraints=cons) 275 | w_cpu = res.x 276 | ww = torch.Tensor(w_cpu).to(grads.device) 277 | gw = (grads * ww.view(1, -1)).sum(1) 278 | gw_norm = gw.norm() 279 | lmbda = c / (gw_norm + 1e-8) 280 | g = grads.mean(1) + lmbda * gw 281 | if rescale == 0: 282 | return g 283 | elif rescale == 1: 284 | return g / (1 + alpha ** 2) 285 | else: 286 | return g / (1 + alpha) 287 | 288 | 289 | def grad2vec(m, grads, grad_dims, task): 290 | # store the gradients 291 | grads[:, task].fill_(0.0) 292 | cnt = 0 293 | for mm in m.shared_modules(): 294 | for p in mm.parameters(): 295 | grad = p.grad 296 | if grad is not None: 297 | grad_cur = grad.data.detach().clone() 298 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 299 | en = sum(grad_dims[:cnt + 1]) 300 | grads[beg:en, task].copy_(grad_cur.data.view(-1)) 301 | cnt += 1 302 | 303 | 304 | def overwrite_grad(m, newgrad, grad_dims, num_tasks): 305 | newgrad = newgrad * num_tasks # to match the sum loss 306 | cnt = 0 307 | for mm in m.shared_modules(): 308 | for param in mm.parameters(): 309 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 310 | en = sum(grad_dims[:cnt + 1]) 311 | this_grad = newgrad[beg: en].contiguous().view(param.data.size()) 312 | param.grad = this_grad.data.clone() 313 | cnt += 1 314 | --------------------------------------------------------------------------------
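# --- Editor's aside: illustrative usage of the gradient-surgery utilities above ---
# A minimal sketch mirroring trainer_dense.py; `model`, `task_losses`, `optimizer`,
# `num_tasks` and `device` are assumed to exist, and the model must expose
# shared_modules() / zero_grad_shared_modules() as used in trainer_dense.py.
#
# grad_dims = [p.data.numel() for mm in model.shared_modules() for p in mm.parameters()]
# grads = torch.zeros(sum(grad_dims), num_tasks, device=device)
# for i, loss_i in enumerate(task_losses):
#     loss_i.backward(retain_graph=True)       # per-task backward pass
#     grad2vec(model, grads, grad_dims, i)     # flatten shared-module grads into column i
#     model.zero_grad_shared_modules()         # clear shared grads before the next task
# g = pcgrad(grads, np.random.default_rng(), num_tasks)   # or graddrop(grads) / cagrad(grads, num_tasks)
# overwrite_grad(model, g, grad_dims, num_tasks)           # write the combined gradient back
# optimizer.step()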