├── dataset
│   └── cityscapes
│       └── link.txt
├── common.py
├── README.md
├── utils.py
├── weight_methods.py
├── data.py
├── adatask.py
├── main_cityscapes.py
└── models.py
/dataset/cityscapes/link.txt:
--------------------------------------------------------------------------------
1 | https://www.dropbox.com/sh/gaw6vh6qusoyms6/AADwWi0Tp3E3M4B2xzeGlsEna?dl=0
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import random
4 | from collections import defaultdict
5 | from pathlib import Path
6 | import numpy as np
7 | import torch
8 | 
9 | def str_to_list(string):
10 |     return [float(s) for s in string.split(",")]
11 | 
12 | def str_or_float(value):
13 |     try:
14 |         return float(value)
15 |     except ValueError:
16 |         return value
17 | 
18 | def str2bool(v):
19 |     if isinstance(v, bool):
20 |         return v
21 |     if v.lower() in ("yes", "true", "t", "y", "1"):
22 |         return True
23 |     elif v.lower() in ("no", "false", "f", "n", "0"):
24 |         return False
25 |     else:
26 |         raise argparse.ArgumentTypeError("Boolean value expected.")
27 | 
28 | common_parser = argparse.ArgumentParser(add_help=False)
29 | common_parser.add_argument("--data-path", type=Path, help="path to data")
30 | common_parser.add_argument("--log-path", type=Path, help="path to log")
31 | common_parser.add_argument("--n-epochs", type=int, default=200)
32 | common_parser.add_argument("--n-task", type=int, default=2)
33 | common_parser.add_argument("--batch-size", type=int, default=120, help="batch size")
34 | common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
35 | common_parser.add_argument("--method-params-lr", type=float, default=0.025, help="lr for weight-method parameters (e.g., uncertainty weighting); if None, falls back to args.lr",)
36 | common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
37 | common_parser.add_argument("--seed", type=int, default=42, help="seed value")
38 | 
39 | def count_parameters(model):
40 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
41 | 
42 | def set_logger():
43 |     logging.basicConfig(
44 |         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
45 |         level=logging.INFO,)
46 | 
47 | def set_seed(seed):
48 |     """Seed all RNGs (numpy, random, torch) for reproducibility.
49 |     :param seed: seed value
50 |     :return: None
51 |     """
52 |     np.random.seed(seed)
53 |     random.seed(seed)
54 | 
55 |     torch.manual_seed(seed)
56 |     if torch.cuda.is_available():
57 |         torch.cuda.manual_seed(seed)
58 |         torch.cuda.manual_seed_all(seed)
59 | 
60 |     torch.backends.cudnn.enabled = True
61 |     torch.backends.cudnn.benchmark = False
62 |     torch.backends.cudnn.deterministic = True
63 | 
64 | def get_device(no_cuda=False, gpus="0"):
65 |     return torch.device(
66 |         f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu"
67 |     )
68 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaTask
2 | [AdaTask: A Task-Aware Adaptive Learning Rate Approach to Multi-Task Learning. AAAI, 2023.](https://arxiv.org/abs/2211.15055)
3 | 
4 | ## Abstract
5 | > Multi-task learning (MTL) models have demonstrated impressive results in computer vision, natural language processing, and recommender systems. Even though many approaches have been proposed, how well these approaches balance different tasks on each parameter still remains unclear. In this paper, we propose to measure the task dominance degree of a parameter by the total updates of each task on this parameter. Specifically, we compute the total updates by the exponentially decaying Average of the squared Updates (AU) on a parameter from the corresponding task. Based on this novel metric, we observe that many parameters in existing MTL methods, especially those in the higher shared layers, are still dominated by one or several tasks. The dominance of AU is mainly due to the dominance of accumulative gradients from one or several tasks. Motivated by this, we propose a Task-wise Adaptive learning rate approach, AdaTask in short, to separate the accumulative gradients and hence the learning rate of each task for each parameter in adaptive learning rate approaches (e.g., AdaGrad, RMSProp, and Adam). Comprehensive experiments on computer vision and recommender system MTL datasets demonstrate that AdaTask significantly improves the performance of dominated tasks, resulting in SOTA average task-wise performance. Analysis on both synthetic and real-world datasets shows that AdaTask balances parameters in every shared layer well.
6 | 
7 | 
8 | 
9 | ## Citation
10 | If you find our paper or this resource helpful, please consider citing:
11 | 
12 | ```
13 | @article{AdaTask_AAAI2023,
14 |   title={AdaTask: A Task-aware Adaptive Learning Rate Approach to Multi-task Learning},
15 |   author={Yang, Enneng and Pan, Junwei and Wang, Ximei and Yu, Haibin and Shen, Li and Chen, Xihua and Xiao, Lei and Jiang, Jie and Guo, Guibing},
16 |   journal={Proceedings of the AAAI Conference on Artificial Intelligence},
17 |   volume={37},
18 |   number={9},
19 |   pages={10745--10753},
20 |   year={2023}
21 | }
22 | 
23 | ```
24 | 
25 | 
26 | 
27 | ## Dataset
28 | - Download the [CityScapes](https://www.dropbox.com/sh/gaw6vh6qusoyms6/AADwWi0Tp3E3M4B2xzeGlsEna?dl=0) dataset and put it in the `dataset` directory.
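29 | 
30 | ## Method Sketch
31 | 
32 | The rule implemented in `adatask.py` keeps separate Adam moments per task on every shared parameter and sums the per-task steps, while task-specific parameters get a standard Adam update on the weighted sum of task losses. Roughly, for one shared parameter `p` at step `s` (the symbols `g_t`, `m_t`, `v_t` are illustrative, not names from the code):
33 | 
34 | ```
35 | for each task t with gradient g_t on p:
36 |     m_t = beta1 * m_t + (1 - beta1) * g_t        # per-task first moment
37 |     v_t = beta2 * v_t + (1 - beta2) * g_t**2     # per-task second moment
38 |     p  -= (lr / (1 - beta1**s)) * m_t / (sqrt(v_t / (1 - beta2**s)) + eps)
39 | ```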
40 | 
41 | 
42 | ## Train and Evaluate Method
43 | 
44 | ```
45 | python3 main_cityscapes.py --optimizer=adam
46 | ```
47 | 
48 | ```
49 | python3 main_cityscapes.py --optimizer=adam_with_adatask
50 | ```
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | import logging
5 | 
6 | def create_log_dir(path='./log'):
7 |     path += '/'
8 |     if not os.path.exists(path):
9 |         os.makedirs(path)
10 | 
11 |     logger = logging.getLogger(path)
12 |     logger.setLevel(logging.DEBUG)
13 |     fh = logging.FileHandler(path + 'log.txt')
14 |     fh.setLevel(logging.DEBUG)
15 |     ch = logging.StreamHandler()
16 |     ch.setLevel(logging.DEBUG)
17 |     logger.addHandler(fh)
18 |     logger.addHandler(ch)
19 |     return logger
20 | 
21 | class ConfMatrix(object):
22 |     def __init__(self, num_classes):
23 |         self.num_classes = num_classes
24 |         self.mat = None
25 | 
26 |     def update(self, pred, target):
27 |         n = self.num_classes
28 |         if self.mat is None:
29 |             self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
30 |         with torch.no_grad():
31 |             k = (target >= 0) & (target < n)
32 |             inds = n * target[k].to(torch.int64) + pred[k]
33 |             self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
34 | 
35 |     def get_metrics(self):
36 |         h = self.mat.float()
37 |         acc = torch.diag(h).sum() / h.sum()
38 |         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
39 |         return torch.mean(iu).cpu().numpy(), acc.cpu().numpy()
40 | 
41 | def depth_error(x_pred, x_output):
42 |     device = x_pred.device
43 |     binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device)
44 |     x_pred_true = x_pred.masked_select(binary_mask)
45 |     x_output_true = x_output.masked_select(binary_mask)
46 |     abs_err = torch.abs(x_pred_true - x_output_true)
47 |     rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true
48 |     return (
49 |         torch.sum(abs_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
50 |     ).item(), (
51 |         torch.sum(rel_err) / torch.nonzero(binary_mask, as_tuple=False).size(0)
52 |     ).item()
53 | 
54 | def normal_error(x_pred, x_output):
55 |     binary_mask = torch.sum(x_output, dim=1) != 0
56 |     error = (
57 |         torch.acos(
58 |             torch.clamp(
59 |                 torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1
60 |             )
61 |         )
62 |         .detach()
63 |         .cpu()
64 |         .numpy()
65 |     )
66 |     error = np.degrees(error)
67 |     return (
68 |         np.mean(error),
69 |         np.median(error),
70 |         np.mean(error < 11.25),
71 |         np.mean(error < 22.5),
72 |         np.mean(error < 30),
73 |     )
74 | 
75 | # for calculating \Delta_m
76 | def delta_fn_cityscapes(a):
77 |     delta_stats = [
78 |         "mean iou",
79 |         "pix acc",
80 |         "abs err",
81 |         "rel err",
82 |     ]
83 |     BASE = np.array(
84 |         [0.7401, 0.9316, 0.0125, 27.77]
85 |     )  # base results from CAGrad (single task / independent)
86 |     SIGN = np.array([1, 1, 0, 0])
87 |     KK = np.ones(4) * -1
88 |     # SIGN == 1 flags higher-is-better metrics, which KK ** SIGN negates, so a negative \Delta_m means improvement over BASE
89 | 
90 |     return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0  # * 100 for percentage
--------------------------------------------------------------------------------
/weight_methods.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Dict, List, Tuple, Union
3 | import torch
4 | 
5 | class WeightMethod:
6 |     def __init__(self, n_tasks: int, device: torch.device):
7 |         super().__init__()
8 |         self.n_tasks = n_tasks
9 |         self.device = device
10 | 
11 |     @abstractmethod
12 |     def get_weighted_loss(
13 |         self,
14 |         losses: torch.Tensor,
15 |         shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
16 |         task_specific_parameters: Union[
17 |             List[torch.nn.parameter.Parameter], torch.Tensor
18 |         ],
19 |         last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
20 |         representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
21 |         **kwargs,
22 |     ):
23 |         pass
24 | 
25 |     def backward(
26 |         self,
27 |         losses: torch.Tensor,
28 |         shared_parameters: Union[
29 |             List[torch.nn.parameter.Parameter], torch.Tensor
30 |         ] = None,
31 |         task_specific_parameters: Union[
32 |             List[torch.nn.parameter.Parameter], torch.Tensor
33 |         ] = None,
34 |         last_shared_parameters: Union[
35 |             List[torch.nn.parameter.Parameter], torch.Tensor
36 |         ] = None,
37 |         representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
38 |         **kwargs,
39 |     ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
40 | 
41 |         loss, extra_outputs = self.get_weighted_loss(
42 |             losses=losses,
43 |             shared_parameters=shared_parameters,
44 |             task_specific_parameters=task_specific_parameters,
45 |             last_shared_parameters=last_shared_parameters,
46 |             representation=representation,
47 |             **kwargs,
48 |         )
49 |         loss.backward()
50 |         return loss, extra_outputs
51 | 
52 |     def __call__(
53 |         self,
54 |         losses: torch.Tensor,
55 |         shared_parameters: Union[
56 |             List[torch.nn.parameter.Parameter], torch.Tensor
57 |         ] = None,
58 |         task_specific_parameters: Union[
59 |             List[torch.nn.parameter.Parameter], torch.Tensor
60 |         ] = None,
61 |         **kwargs,
62 |     ):
63 |         return self.backward(
64 |             losses=losses,
65 |             shared_parameters=shared_parameters,
66 |             task_specific_parameters=task_specific_parameters,
67 |             **kwargs,
68 |         )
69 | 
70 |     def parameters(self) -> List[torch.Tensor]:
71 |         """return learnable parameters"""
72 |         return []
73 | 
74 | class EqualWeight(WeightMethod):
75 |     def __init__(
76 |         self,
77 |         n_tasks: int,
78 |         device: torch.device,
79 |         task_weights: Union[List[float], torch.Tensor] = None,
80 |     ):
81 |         super().__init__(n_tasks, device=device)
82 |         if task_weights is None:
83 |             task_weights = torch.ones((n_tasks,))
84 |         if not isinstance(task_weights, torch.Tensor):
85 |             task_weights = torch.tensor(task_weights)
86 |         assert len(task_weights) == n_tasks
87 |         self.task_weights = task_weights.to(device)
88 | 
89 |     def get_weighted_loss(self, losses, **kwargs):
90 |         loss = torch.sum(losses * self.task_weights)
91 |         return loss, dict(weights=self.task_weights)
92 | 
93 | class WeightMethods:
94 |     def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
95 |         """
96 |         :param method: name of the weighting method; must be a key of METHODS
97 |         """
98 |         assert method in METHODS, f"unknown method {method}."
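99 |         # look up and instantiate the requested weighting scheme from the METHODS registry below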
100 |         self.method = METHODS[method](n_tasks=n_tasks, device=device, **kwargs)
101 | 
102 |     def get_weighted_loss(self, losses, **kwargs):
103 |         return self.method.get_weighted_loss(losses, **kwargs)
104 | 
105 |     def backward(
106 |         self, losses, **kwargs
107 |     ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
108 |         return self.method.backward(losses, **kwargs)
109 | 
110 |     def __call__(self, losses, **kwargs):
111 |         return self.backward(losses, **kwargs)
112 | 
113 |     def parameters(self):
114 |         return self.method.parameters()
115 | 
116 | 
117 | METHODS = dict(
118 |     equalweight=EqualWeight
119 | )
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data.dataset import Dataset
8 | 
9 | class RandomScaleCrop(object):
10 |     """
11 |     Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
12 |     """
13 | 
14 |     def __init__(self, scale=[1.0, 1.2, 1.5]):
15 |         self.scale = scale
16 | 
17 |     def __call__(self, img, label, depth, normal):
18 |         height, width = img.shape[-2:]
19 |         sc = self.scale[random.randint(0, len(self.scale) - 1)]
20 |         h, w = int(height / sc), int(width / sc)
21 |         i = random.randint(0, height - h)
22 |         j = random.randint(0, width - w)
23 |         img_ = F.interpolate(
24 |             img[None, :, i : i + h, j : j + w],
25 |             size=(height, width),
26 |             mode="bilinear",
27 |             align_corners=True,
28 |         ).squeeze(0)
29 |         label_ = (
30 |             F.interpolate(
31 |                 label[None, None, i : i + h, j : j + w],
32 |                 size=(height, width),
33 |                 mode="nearest",
34 |             )
35 |             .squeeze(0)
36 |             .squeeze(0)
37 |         )
38 |         depth_ = F.interpolate(
39 |             depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
40 |         ).squeeze(0)
41 |         normal_ = F.interpolate(
42 |             normal[None, :, i : i + h, j : j + w],
43 |             size=(height, width),
44 |             mode="bilinear",
45 |             align_corners=True,
46 |         ).squeeze(0)
47 |         return img_, label_, depth_ / sc, normal_
48 | 
49 | class RandomScaleCropCityScapes(object):
50 |     """
51 |     Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
52 |     """
53 |     def __init__(self, scale=[1.0, 1.2, 1.5]):
54 |         self.scale = scale
55 | 
56 |     def __call__(self, img, label, depth):
57 |         height, width = img.shape[-2:]
58 |         sc = self.scale[random.randint(0, len(self.scale) - 1)]
59 |         h, w = int(height / sc), int(width / sc)
60 |         i = random.randint(0, height - h)
61 |         j = random.randint(0, width - w)
62 |         img_ = F.interpolate(img[None, :, i:i + h, j:j + w], size=(height, width), mode='bilinear', align_corners=True).squeeze(0)
63 |         label_ = F.interpolate(label[None, None, i:i + h, j:j + w], size=(height, width), mode='nearest').squeeze(0).squeeze(0)
64 |         depth_ = F.interpolate(depth[None, :, i:i + h, j:j + w], size=(height, width), mode='nearest').squeeze(0)
65 |         return img_, label_, depth_ / sc
66 | 
67 | class CityScapes(Dataset):
68 |     """
69 |     We could further improve the performance with the data augmentation defined for NYUv2 in:
70 |     [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
71 |     [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation
72 |     [3] Mti-net: Multiscale task interaction networks for multi-task learning
73 | 
74 |     1. Random scale with a ratio selected from 1.0, 1.2, and 1.5.
75 |     2. Random horizontal flip.
76 | 
77 |     Please note that all baselines and MTAN did NOT apply data augmentation in the original paper.
78 |     """
79 |     def __init__(self, root, train=True, augmentation=False):
80 |         self.train = train
81 |         self.root = os.path.expanduser(root)
82 |         self.augmentation = augmentation
83 | 
84 |         # read the data file
85 |         if train:
86 |             self.data_path = self.root + '/train'
87 |         else:
88 |             self.data_path = self.root + '/val'
89 | 
90 |         # calculate data length
91 |         self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.npy'))
92 | 
93 |     def __getitem__(self, index):
94 |         # load data from the pre-processed npy files
95 |         image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0))
96 |         semantic = torch.from_numpy(np.load(self.data_path + '/label_7/{:d}.npy'.format(index)))
97 |         depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0))
98 | 
99 |         # apply data augmentation if required
100 |         if self.augmentation:
101 |             image, semantic, depth = RandomScaleCropCityScapes()(image, semantic, depth)
102 |             if torch.rand(1) < 0.5:
103 |                 image = torch.flip(image, dims=[2])
104 |                 semantic = torch.flip(semantic, dims=[1])
105 |                 depth = torch.flip(depth, dims=[2])
106 | 
107 |         return image.float(), semantic.float(), depth.float()
108 | 
109 |     def __len__(self):
110 |         return self.data_len
--------------------------------------------------------------------------------
/adatask.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 | from typing import List, Union
5 | 
6 | class Adam_with_AdaTask(Optimizer):
7 |     r"""
8 |     Implements Adam with the AdaTask task-wise adaptive learning rate (AAAI 2023).
9 |     """
10 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, args=None, device='cpu', n_tasks=2, task_weight=[1, 1]):
11 |         if not 0.0 <= lr:
12 |             raise ValueError("Invalid learning rate: {}".format(lr))
13 |         if not 0.0 <= eps:
14 |             raise ValueError("Invalid epsilon value: {}".format(eps))
15 |         if not 0.0 <= betas[0] < 1.0:
16 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
17 |         if not 0.0 <= betas[1] < 1.0:
18 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
19 |         if not 0.0 <= weight_decay:
20 |             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
21 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
22 |         super(Adam_with_AdaTask, self).__init__(params, defaults)
23 | 
24 |         self.n_tasks = n_tasks
25 |         self.device = device
26 |         self.betas = betas
27 |         self.eps = eps
28 |         self.task_weight = torch.Tensor(task_weight).to(device)
29 | 
30 |     def zero_grad_modules(self, modules_parameters):
31 |         for p in modules_parameters:
32 |             if p.grad is not None:
33 |                 p.grad.detach_()
34 |                 p.grad.zero_()
35 | 
36 |     def backward_and_step(self,
37 |                           losses: torch.Tensor,
38 |                           shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
39 |                           task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
40 |                           last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None):
41 | 
42 |         shared_grads = []
43 |         if shared_parameters is not None:
44 |             # back-propagate each task's weighted loss separately to collect per-task gradients on the shared parameters
45 |             for i in range(len(losses)):
46 |                 self.zero_grad_modules(shared_parameters)
47 |                 (self.task_weight[i] * losses[i]).backward(retain_graph=True)
48 |                 grad = [p.grad.detach().clone() if (p.requires_grad and p.grad is not None) else None for p in shared_parameters]
49 |                 shared_grads.append(grad)
50 | 
51 |         task_specific_grads = None
52 |         if task_specific_parameters is not None:
53 |             self.zero_grad_modules(task_specific_parameters)
54 |             # task-specific parameters use the gradient of the weighted sum of all task losses
55 |             (self.task_weight * losses).sum().backward()
56 |             task_specific_grads = [p.grad.detach().clone() if (p.requires_grad and p.grad is not None) else None for p in task_specific_parameters]
57 | 
58 |         return self.step(shared_parameters, task_specific_parameters, shared_grads, task_specific_grads)
59 | 
60 |     @torch.no_grad()
61 |     def step(self, shared_parameters, task_specific_parameters, shared_grads, task_specific_grads):
62 |         # lr (assumes a single param group; the last group's lr is used)
63 |         for group in self.param_groups:
64 |             step_lr = group['lr']
65 | 
66 |         # shared params: separate Adam state per task, per-task steps summed on the parameter
67 |         if shared_parameters is not None:
68 |             for pi in range(len(shared_parameters)):
69 |                 p = shared_parameters[pi]
70 |                 state = self.state[p]
71 |                 # State initialization
72 |                 if len(state) == 0:
73 |                     state['step'] = 0
74 |                     for t in range(self.n_tasks):
75 |                         # Exponential moving average of gradient values
76 |                         state['exp_avg_' + str(t)] = torch.zeros_like(p, memory_format=torch.preserve_format)
77 |                         # Exponential moving average of squared gradient values
78 |                         state['exp_avg_sq_' + str(t)] = torch.zeros_like(p, memory_format=torch.preserve_format)
79 | 
80 |                 state['step'] += 1
81 |                 beta1, beta2 = self.betas
82 |                 bias_correction1 = 1 - beta1 ** state['step']
83 |                 bias_correction2 = 1 - beta2 ** state['step']
84 | 
85 |                 for t in range(self.n_tasks):
86 |                     grad = shared_grads[t][pi]
87 |                     if grad is None:
88 |                         continue
89 |                     exp_avg = state['exp_avg_' + str(t)]
90 |                     exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
91 |                     exp_avg_sq = state['exp_avg_sq_' + str(t)]
92 |                     exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
93 |                     denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
94 |                     step_size = step_lr / bias_correction1
95 |                     p.addcdiv_(exp_avg, denom, value=-step_size)
96 | 
97 |         # task-specific params: standard Adam update
98 |         if task_specific_parameters is not None:
99 |             for pi in range(len(task_specific_parameters)):
100 |                 p = task_specific_parameters[pi]
101 |                 state = self.state[p]
102 |                 # State initialization
103 |                 if len(state) == 0:
104 |                     state['step'] = 0
105 |                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
106 |                     state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
107 | 
108 |                 state['step'] += 1
109 |                 beta1, beta2 = self.betas
110 |                 bias_correction1 = 1 - beta1 ** state['step']
111 |                 bias_correction2 = 1 - beta2 ** state['step']
112 | 
113 |                 grad = task_specific_grads[pi]
114 |                 if grad is None:
115 |                     continue
116 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
117 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
118 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
119 |                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
120 |                 step_size = step_lr / bias_correction1
121 |                 p.addcdiv_(exp_avg, denom, value=-step_size)
122 | 
123 |         return None
--------------------------------------------------------------------------------
/main_cityscapes.py:
--------------------------------------------------------------------------------
1 | """Train and evaluate SegNet-MTAN on CityScapes with Adam or Adam_with_AdaTask."""
2 | from argparse import ArgumentParser
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 | import time
8 | from data import CityScapes
9 | from models import SegNetMtan
10 | 
11 | from utils import ConfMatrix, delta_fn_cityscapes, depth_error
12 | from common import (
13 |     common_parser,
14 |     get_device,
15 |     set_logger,
16 |     set_seed,
17 |     str2bool,
18 | )
19 | from weight_methods import WeightMethods
20 | 
21 | from utils import create_log_dir
22 | from pathlib import Path
23 | 
24 | from adatask import Adam_with_AdaTask
25 | from torch.optim import Adam
26 | set_logger()
27 | 
28 | def calc_loss(x_pred, x_output, task_type):
29 |     device = x_pred.device
30 | 
31 |     # binary mask to ignore undefined pixel space
32 |     binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device)
33 | 
34 |     if task_type == "semantic":
35 |         # semantic loss: pixel-wise cross-entropy (predictions are log-probabilities)
36 |         loss = F.nll_loss(x_pred, x_output, ignore_index=-1)
37 | 
38 |     if task_type == "depth":
39 |         # depth loss: l1 norm over defined pixels
40 |         loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero(
41 |             binary_mask, as_tuple=False
42 |         ).size(0)
43 | 
44 |     return loss
45 | 
46 | 
47 | def main(path, lr, bs, device):
48 |     # ----
49 |     # Nets
50 |     # ----
51 | 
52 |     model = SegNetMtan()
53 |     model = model.to(device)
54 | 
55 |     # weight method
56 |     weight_method = WeightMethods(args.method, n_tasks=2, device=device)
57 | 
58 |     # optimizer
59 |     if args.optimizer == 'adam_with_adatask':
60 |         optimizer = Adam_with_AdaTask([dict(params=model.parameters(), lr=lr)], n_tasks=2, args=args, device=device)
61 |     elif args.optimizer == 'adam':
62 |         optimizer = Adam([dict(params=model.parameters(), lr=lr)])
63 | 
64 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
65 | 
66 |     train_set = CityScapes(root=path.as_posix(), train=True, augmentation=args.apply_augmentation)
67 |     test_set = CityScapes(root=path.as_posix(), train=False)
68 | 
69 |     train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True)
70 |     test_loader = DataLoader(dataset=test_set, batch_size=bs, shuffle=False)
71 | 
72 |     # log the training strategy
73 |     log_str = ("Applying data augmentation." if args.apply_augmentation else "Standard training strategy without data augmentation.")
74 |     logger.info(log_str)
75 | 
76 |     epochs = args.n_epochs
77 |     train_batch = len(train_loader)
78 |     test_batch = len(test_loader)
79 |     avg_cost = np.zeros([epochs, 12], dtype=np.float32)
80 |     custom_step = -1
81 |     conf_mat = ConfMatrix(model.segnet.class_nb)
82 |     logger.info('---train begin---')
83 |     for epoch in range(epochs):
84 |         cost = np.zeros(12, dtype=np.float32)
85 |         conf_mat = ConfMatrix(model.segnet.class_nb)  # fresh confusion matrix for this epoch's train metrics
86 |         for j, batch in enumerate(train_loader):
87 |             custom_step += 1
88 | 
89 |             model.train()
90 |             optimizer.zero_grad()
91 | 
92 |             train_data, train_label, train_depth = batch
93 |             train_data, train_label = train_data.to(device), train_label.long().to(device)
94 |             train_depth = train_depth.to(device)
95 | 
96 |             train_pred, features = model(train_data, return_representation=True)
97 | 
98 |             losses = torch.stack(
99 |                 (
100 |                     calc_loss(train_pred[0], train_label, "semantic"),
101 |                     calc_loss(train_pred[1], train_depth, "depth"),
102 |                 )
103 |             )
104 | 
105 |             if args.optimizer == 'adam_with_adatask':
106 |                 optimizer.backward_and_step(
107 |                     losses=losses,
108 |                     shared_parameters=list(model.shared_parameters()),
109 |                     task_specific_parameters=list(model.task_specific_parameters()),
110 |                     last_shared_parameters=list(model.last_shared_parameters()),
111 |                 )
112 | 
113 |             elif args.optimizer == 'adam':
114 |                 weight_method.backward(
115 |                     losses=losses,
116 |                     shared_parameters=list(model.shared_parameters()),
117 |                     task_specific_parameters=list(model.task_specific_parameters()),
118 |                     last_shared_parameters=list(model.last_shared_parameters()),
119 |                     representation=features,
120 |                 )
121 |                 # Adam_with_AdaTask steps inside backward_and_step, so only step here
122 |                 optimizer.step()
123 | 
124 |             # accumulate label prediction for every pixel in training images
125 |             conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
126 | 
127 |             cost[0] = losses[0].item()
128 |             cost[3] = losses[1].item()
129 |             cost[4], cost[5] = depth_error(train_pred[1], train_depth)
130 |             avg_cost[epoch, :6] += cost[:6] / train_batch
131 | 
132 |             if j % 100 == 0:
133 |                 print(
134 |                     f"[{epoch+1} {j+1}/{train_batch}] semantic loss: {losses[0].item():.16f}, "
135 |                     f"depth loss: {losses[1].item():.16f}, "
136 |                 )
137 | 
138 |         # scheduler
139 |         scheduler.step()
140 |         # compute mIoU and acc
141 |         avg_cost[epoch, 1:3] = conf_mat.get_metrics()
142 | 
143 |         # todo: move evaluate to function?
144 |         # evaluating test data
145 |         model.eval()
146 |         conf_mat = ConfMatrix(model.segnet.class_nb)
147 |         with torch.no_grad():  # operations inside don't track history
148 |             test_dataset = iter(test_loader)
149 |             for k in range(test_batch):
150 |                 test_data, test_label, test_depth = next(test_dataset)
151 |                 test_data, test_label = test_data.to(device), test_label.long().to(
152 |                     device
153 |                 )
154 |                 test_depth = test_depth.to(device)
155 | 
156 |                 test_pred = model(test_data)
157 |                 test_loss = torch.stack(
158 |                     (
159 |                         calc_loss(test_pred[0], test_label, "semantic"),
160 |                         calc_loss(test_pred[1], test_depth, "depth"),
161 |                     )
162 |                 )
163 | 
164 |                 conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
165 | 
166 |                 cost[6] = test_loss[0].item()
167 |                 cost[9] = test_loss[1].item()
168 |                 cost[10], cost[11] = depth_error(test_pred[1], test_depth)
169 |                 avg_cost[epoch, 6:] += cost[6:] / test_batch
170 | 
171 |         # compute mIoU and acc
172 |         avg_cost[epoch, 7:9] = conf_mat.get_metrics()
173 | 
174 |         # print results
175 |         logger.info("LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR (test)")
176 |         logger.info(
177 |             f"Epoch: {epoch:04d} | TRAIN: {avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]:.4f} {avg_cost[epoch, 2]:.4f} "
178 |             f"| {avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} "
179 |             f"|| TEST: {avg_cost[epoch, 6]:.4f} {avg_cost[epoch, 7]:.4f} {avg_cost[epoch, 8]:.4f} "
180 |             f"| {avg_cost[epoch, 9]:.4f} {avg_cost[epoch, 10]:.4f} {avg_cost[epoch, 11]:.4f} "
181 |         )
182 | 
183 | 
184 | if __name__ == "__main__":
185 |     parser = ArgumentParser("cityscapes", parents=[common_parser])
186 |     parser.set_defaults(
187 |         data_path=Path('./dataset/cityscapes'),
188 |         log_path='./log/',
189 |         batch_size=8,
190 |         n_task=2,
191 |         lr=1e-4,
192 |         n_epochs=200,
193 |     )
194 |     parser.add_argument(
195 |         "--method",
196 |         type=str,
197 |         default="equalweight",
198 |         choices=["equalweight"],
199 |         help="loss weighting method",
200 |     )
201 |     parser.add_argument(
202 |         "--optimizer",
203 |         type=str,
204 |         default="adam_with_adatask",
205 |         choices=["adam", "adam_with_adatask"],
206 |         help="optimizer type",
207 |     )
208 |     parser.add_argument(
209 |         "--apply-augmentation", type=str2bool, default=True, help="data augmentations"
210 |     )
211 |     args = parser.parse_args()
212 | 
213 |     # set seed
214 |     set_seed(args.seed)
215 |     logger = create_log_dir(str(Path(args.log_path)) + '/' + 'model_mtan' + '_optimizer_' + args.optimizer + '_' + time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())))
216 |     logger.info(str(args))
217 | 
218 |     device = get_device(gpus=args.gpu)
219 |     main(path=args.data_path, lr=args.lr, bs=args.batch_size, device=device)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | class _SegNet(nn.Module):
8 |     """SegNet MTAN"""
9 | 
10 |     def __init__(self):
11 |         super(_SegNet, self).__init__()
12 |         # initialise network parameters
13 |         filter = [64, 128, 256, 512, 512]
14 |         self.class_nb = 7
15 | 
16 |         # define encoder decoder layers
17 |         self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
18 |         self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
19 |         for i in range(4):
20 |             self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
21 | 
self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]])) 22 | 23 | # define convolution layer 24 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 25 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 26 | for i in range(4): 27 | if i == 0: 28 | self.conv_block_enc.append(self.conv_layer([filter[i + 1], filter[i + 1]])) 29 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]])) 30 | else: 31 | self.conv_block_enc.append(nn.Sequential(self.conv_layer([filter[i + 1], filter[i + 1]]), 32 | self.conv_layer([filter[i + 1], filter[i + 1]]))) 33 | self.conv_block_dec.append(nn.Sequential(self.conv_layer([filter[i], filter[i]]), 34 | self.conv_layer([filter[i], filter[i]]))) 35 | 36 | # define task attention layers 37 | self.encoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])]) 38 | self.decoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])]) 39 | self.encoder_block_att = nn.ModuleList([self.conv_layer([filter[0], filter[1]])]) 40 | self.decoder_block_att = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 41 | 42 | for j in range(2): 43 | if j < 1: 44 | self.encoder_att.append(nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])) 45 | self.decoder_att.append(nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])) 46 | for i in range(4): 47 | self.encoder_att[j].append(self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]])) 48 | self.decoder_att[j].append(self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]])) 49 | 50 | for i in range(4): 51 | if i < 3: 52 | self.encoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 2]])) 53 | self.decoder_block_att.append(self.conv_layer([filter[i + 1], filter[i]])) 54 | else: 55 | self.encoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 1]])) 56 | self.decoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 1]])) 57 | 58 | self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True) 59 | self.pred_task2 = self.conv_layer([filter[0], 1], pred=True) 60 | 61 | # define pooling and unpooling functions 62 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 63 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2) 64 | 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.xavier_normal_(m.weight) 68 | nn.init.constant_(m.bias, 0) 69 | elif isinstance(m, nn.BatchNorm2d): 70 | nn.init.constant_(m.weight, 1) 71 | nn.init.constant_(m.bias, 0) 72 | elif isinstance(m, nn.Linear): 73 | nn.init.xavier_normal_(m.weight) 74 | nn.init.constant_(m.bias, 0) 75 | 76 | def shared_modules(self): 77 | return [ 78 | self.encoder_block, 79 | self.decoder_block, 80 | self.conv_block_enc, 81 | self.conv_block_dec, 82 | # self.encoder_att, self.decoder_att, 83 | self.encoder_block_att, 84 | self.decoder_block_att, 85 | self.down_sampling, 86 | self.up_sampling, 87 | ] 88 | 89 | def zero_grad_shared_modules(self): 90 | for mm in self.shared_modules(): 91 | mm.zero_grad() 92 | 93 | def conv_layer(self, channel, pred=False): 94 | if not pred: 95 | conv_block = nn.Sequential( 96 | nn.Conv2d( 97 | in_channels=channel[0], 98 | out_channels=channel[1], 99 | kernel_size=3, 100 | padding=1, 101 | ), 102 | nn.BatchNorm2d(num_features=channel[1]), 103 | nn.ReLU(inplace=True), 104 | ) 105 | else: 106 | conv_block = nn.Sequential( 107 | nn.Conv2d( 108 | 
in_channels=channel[0], 109 | out_channels=channel[0], 110 | kernel_size=3, 111 | padding=1, 112 | ), 113 | nn.Conv2d( 114 | in_channels=channel[0], 115 | out_channels=channel[1], 116 | kernel_size=1, 117 | padding=0, 118 | ), 119 | ) 120 | return conv_block 121 | 122 | def att_layer(self, channel): 123 | att_block = nn.Sequential( 124 | nn.Conv2d( 125 | in_channels=channel[0], 126 | out_channels=channel[1], 127 | kernel_size=1, 128 | padding=0, 129 | ), 130 | nn.BatchNorm2d(channel[1]), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d( 133 | in_channels=channel[1], 134 | out_channels=channel[2], 135 | kernel_size=1, 136 | padding=0, 137 | ), 138 | nn.BatchNorm2d(channel[2]), 139 | nn.Sigmoid(), 140 | ) 141 | return att_block 142 | 143 | def forward(self, x): 144 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5)) 145 | for i in range(5): 146 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2)) 147 | 148 | # define attention list for tasks 149 | atten_encoder, atten_decoder = ([0] * 2 for _ in range(2)) 150 | for i in range(2): 151 | atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2)) 152 | for i in range(2): 153 | for j in range(5): 154 | atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2)) 155 | 156 | # define global shared network 157 | for i in range(5): 158 | if i == 0: 159 | g_encoder[i][0] = self.encoder_block[i](x) 160 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 161 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 162 | else: 163 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1]) 164 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 165 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 166 | 167 | for i in range(5): 168 | if i == 0: 169 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1]) 170 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 171 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 172 | else: 173 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1]) 174 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 175 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 176 | 177 | # define task dependent attention module 178 | for i in range(2): 179 | for j in range(5): 180 | if j == 0: 181 | atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0]) 182 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1] 183 | atten_encoder[i][j][2] = self.encoder_block_att[j](atten_encoder[i][j][1]) 184 | atten_encoder[i][j][2] = F.max_pool2d(atten_encoder[i][j][2], kernel_size=2, stride=2) 185 | else: 186 | atten_encoder[i][j][0] = self.encoder_att[i][j](torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1)) 187 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1] 188 | atten_encoder[i][j][2] = self.encoder_block_att[j](atten_encoder[i][j][1]) 189 | atten_encoder[i][j][2] = F.max_pool2d(atten_encoder[i][j][2], kernel_size=2, stride=2) 190 | 191 | for j in range(5): 192 | if j == 0: 193 | atten_decoder[i][j][0] = F.interpolate(atten_encoder[i][-1][-1], scale_factor=2, mode='bilinear', align_corners=True) 194 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](atten_decoder[i][j][0]) 195 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)) 196 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1] 197 | else: 198 | atten_decoder[i][j][0] = 
F.interpolate(atten_decoder[i][j - 1][2], scale_factor=2, mode='bilinear', align_corners=True) 199 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](atten_decoder[i][j][0]) 200 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1)) 201 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1] 202 | # define task prediction layers 203 | t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1) 204 | t2_pred = self.pred_task2(atten_decoder[1][-1][-1]) 205 | 206 | return ( 207 | [t1_pred, t2_pred], 208 | ( 209 | atten_decoder[0][-1][-1], 210 | atten_decoder[1][-1][-1], 211 | # atten_decoder[2][-1][-1], 212 | ), 213 | ) 214 | 215 | class SegNetMtan(nn.Module): 216 | def __init__(self): 217 | super().__init__() 218 | self.segnet = _SegNet() 219 | 220 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 221 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n and 'coder_att' not in n) 222 | 223 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]: 224 | return (p for n, p in self.segnet.named_parameters() if "pred" in n or 'coder_att' in n) 225 | 226 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 227 | """Parameters of the last shared layer. 228 | Returns 229 | ------- 230 | """ 231 | return self.segnet.conv_block_dec[-5].parameters() 232 | 233 | def forward(self, x, return_representation=False): 234 | if return_representation: 235 | return self.segnet(x) 236 | else: 237 | pred, rep = self.segnet(x) 238 | return pred 239 | 240 | class SegNetSplit(nn.Module): 241 | def __init__(self, model_type="standard"): 242 | super(SegNetSplit, self).__init__() 243 | # initialise network parameters 244 | assert model_type in ["standard", "wide", "deep"] 245 | self.model_type = model_type 246 | if self.model_type == "wide": 247 | filter = [64, 128, 256, 512, 1024] 248 | else: 249 | filter = [64, 128, 256, 512, 512] 250 | 251 | self.class_nb = 7 252 | 253 | # define encoder decoder layers 254 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])]) 255 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 256 | for i in range(4): 257 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]])) 258 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]])) 259 | 260 | # define convolution layer 261 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 262 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 263 | for i in range(4): 264 | if i == 0: 265 | self.conv_block_enc.append(self.conv_layer([filter[i + 1], filter[i + 1]])) 266 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]])) 267 | else: 268 | self.conv_block_enc.append( 269 | nn.Sequential( 270 | self.conv_layer([filter[i + 1], filter[i + 1]]), 271 | self.conv_layer([filter[i + 1], filter[i + 1]]), 272 | ) 273 | ) 274 | self.conv_block_dec.append( 275 | nn.Sequential( 276 | self.conv_layer([filter[i], filter[i]]), 277 | self.conv_layer([filter[i], filter[i]]), 278 | ) 279 | ) 280 | 281 | # define task specific layers 282 | self.pred_task1 = nn.Sequential( 283 | nn.Conv2d(in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1), 284 | nn.Conv2d(in_channels=filter[0], out_channels=self.class_nb, kernel_size=1, padding=0,), 285 | ) 286 | self.pred_task2 = nn.Sequential( 287 | nn.Conv2d(in_channels=filter[0], out_channels=filter[0], 
kernel_size=3, padding=1), 288 | nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0), 289 | ) 290 | 291 | # define pooling and unpooling functions 292 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 293 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2) 294 | 295 | for m in self.modules(): 296 | if isinstance(m, nn.Conv2d): 297 | nn.init.xavier_normal_(m.weight) 298 | nn.init.constant_(m.bias, 0) 299 | elif isinstance(m, nn.BatchNorm2d): 300 | nn.init.constant_(m.weight, 1) 301 | nn.init.constant_(m.bias, 0) 302 | elif isinstance(m, nn.Linear): 303 | nn.init.xavier_normal_(m.weight) 304 | nn.init.constant_(m.bias, 0) 305 | 306 | # define convolutional block 307 | def conv_layer(self, channel): 308 | if self.model_type == "deep": 309 | conv_block = nn.Sequential( 310 | nn.Conv2d( 311 | in_channels=channel[0], 312 | out_channels=channel[1], 313 | kernel_size=3, 314 | padding=1, 315 | ), 316 | nn.BatchNorm2d(num_features=channel[1]), 317 | nn.ReLU(inplace=True), 318 | nn.Conv2d( 319 | in_channels=channel[1], 320 | out_channels=channel[1], 321 | kernel_size=3, 322 | padding=1, 323 | ), 324 | nn.BatchNorm2d(num_features=channel[1]), 325 | nn.ReLU(inplace=True), 326 | ) 327 | else: 328 | conv_block = nn.Sequential( 329 | nn.Conv2d( 330 | in_channels=channel[0], 331 | out_channels=channel[1], 332 | kernel_size=3, 333 | padding=1, 334 | ), 335 | nn.BatchNorm2d(num_features=channel[1]), 336 | nn.ReLU(inplace=True), 337 | ) 338 | return conv_block 339 | 340 | def forward(self, x): 341 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ( 342 | [0] * 5 for _ in range(5) 343 | ) 344 | for i in range(5): 345 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2)) 346 | 347 | # global shared encoder-decoder network 348 | for i in range(5): 349 | if i == 0: 350 | g_encoder[i][0] = self.encoder_block[i](x) 351 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 352 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 353 | else: 354 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1]) 355 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 356 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 357 | 358 | for i in range(5): 359 | if i == 0: 360 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1]) 361 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 362 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 363 | else: 364 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1]) 365 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 366 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 367 | 368 | # define task prediction layers 369 | t1_pred = F.log_softmax(self.pred_task1(g_decoder[i][1]), dim=1) 370 | t2_pred = self.pred_task2(g_decoder[i][1]) 371 | 372 | return ([t1_pred, t2_pred], (g_decoder[i][1]),) 373 | # NOTE: last element is representation 374 | --------------------------------------------------------------------------------
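Usage note: as a quick reference, here is a minimal, self-contained sketch of driving `Adam_with_AdaTask` on a toy two-task model. The trunk/head model and random data below are illustrative only (not part of this repo); the optimizer API is the one defined in `adatask.py`, and `main_cityscapes.py` wires it up the same way with `SegNetMtan.shared_parameters()` and `task_specific_parameters()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from adatask import Adam_with_AdaTask

# Toy hard-parameter-sharing setup: one shared trunk, one linear head per task.
shared = nn.Linear(10, 16)
heads = nn.ModuleList([nn.Linear(16, 1) for _ in range(2)])
params = list(shared.parameters()) + list(heads.parameters())

opt = Adam_with_AdaTask([dict(params=params, lr=1e-3)],
                        n_tasks=2, device="cpu", task_weight=[1, 1])

x = torch.randn(32, 10)  # toy inputs
y = torch.randn(32, 2)   # toy per-task regression targets

rep = torch.relu(shared(x))  # shared representation feeding both heads
losses = torch.stack(
    [F.mse_loss(heads[t](rep).squeeze(-1), y[:, t]) for t in range(2)]
)

# backward_and_step back-propagates each task separately through the shared
# trunk and applies the task-separated Adam update; no extra opt.step() call.
opt.backward_and_step(
    losses=losses,
    shared_parameters=list(shared.parameters()),
    task_specific_parameters=list(heads.parameters()),
)
```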