├── experiments ├── __init__.py ├── nyuv2 │ ├── __init__.py │ ├── .DS_Store │ ├── README.md │ ├── utils.py │ ├── data.py │ ├── trainer.py │ └── models.py ├── quantum_chemistry │ ├── __init__.py │ ├── README.md │ ├── models.py │ ├── utils.py │ └── trainer.py ├── .DS_Store └── utils.py ├── .DS_Store ├── .idea ├── .gitignore ├── vcs.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml └── IGB4MTL.iml ├── requirements.txt ├── methods ├── __init__.py ├── weight_method.py ├── min_norm_solvers.py ├── SAC_Agent.py ├── loss_weight_methods.py └── gradient_weight_methods.py ├── LICENSE └── README.md /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/nyuv2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YanqiDai/IGB4MTL/HEAD/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /experiments/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YanqiDai/IGB4MTL/HEAD/experiments/.DS_Store -------------------------------------------------------------------------------- /experiments/nyuv2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YanqiDai/IGB4MTL/HEAD/experiments/nyuv2/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.2.1 2 | numpy>=1.18.2 3 | torch>=1.4.0 4 | torchvision>=0.8.0 5 | cvxpy 6 | tqdm>=4.45.0 7 | pandas 8 | scikit-learn 9 | seaborn 10 | plotly 11 | scipy==1.10.1 -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/README.md: -------------------------------------------------------------------------------- 1 | # QM9 Experiment 2 | 3 | Modification of the code in [Nash-MTL](https://github.com/AvivNavon/nash-mtl). 4 | 5 | ## Dataset 6 | 7 | The dataset will be downloaded automatically from torch_geometric and saved in `./dataset`. 
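For reference, here is a minimal sketch of how `trainer.py` builds and splits the dataset (run from the repository root; the transforms come from `experiments/quantum_chemistry/utils.py` and the split sizes follow the trainer defaults):

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import QM9

from experiments.quantum_chemistry.utils import Complete, MyTransform, target_idx

# Select the 11 regression targets used in the experiment, complete the graph,
# and add pairwise distances as edge features.
transform = T.Compose([MyTransform(target_idx), Complete(), T.Distance(norm=False)])
dataset = QM9("./dataset", transform=transform).shuffle()

# Same split as trainer.py: 10k test, 10k val, rest train.
test_dataset = dataset[:10000]
val_dataset = dataset[10000:20000]
train_dataset = dataset[20000:]
```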
8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /experiments/nyuv2/README.md: -------------------------------------------------------------------------------- 1 | # NYUv2 Experiment 2 | 3 | Modification of the code in [Nash-MTL](https://github.com/AvivNavon/nash-mtl). 4 | 5 | ## Dataset 6 | 7 | The dataset is available at [this link](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0). Put the downloaded files in `./dataset` so that the folder structure is `./dataset/train` and `./dataset/val`. -------------------------------------------------------------------------------- /.idea/IGB4MTL.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from methods.loss_weight_methods import ( 2 | LOSS_METHODS, 3 | STL, 4 | LinearScalarization, 5 | Uncertainty, 6 | UncertaintyLog, 7 | ScaleInvariantLinearScalarization, 8 | RLW, 9 | RLWLog, 10 | DynamicWeightAverage, 11 | DynamicWeightAverageLog, 12 | ImprovableGapBalancing_v1, 13 | ImprovableGapBalancing_v2, 14 | ) 15 | from methods.gradient_weight_methods import ( 16 | GRADIENT_METHODS, 17 | PCGrad, 18 | MGDA, 19 | CAGrad, 20 | NashMTL, 21 | IMTLG, 22 | ) 23 | from methods.SAC_Agent import SAC_Agent, RandomBuffer 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yanqi Dai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /experiments/quantum_chemistry/models.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Iterator 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import GRU, Linear, ReLU, Sequential 7 | from torch_geometric.nn import DimeNet, NNConv, Set2Set, radius_graph 8 | 9 | 10 | class Net(torch.nn.Module): 11 | def __init__(self, n_tasks, num_features=11, dim=64): 12 | super().__init__() 13 | self.n_tasks = n_tasks 14 | self.dim = dim 15 | self.lin0 = torch.nn.Linear(num_features, dim) 16 | 17 | nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim)) 18 | self.conv = NNConv(dim, dim, nn, aggr="mean") 19 | self.gru = GRU(dim, dim) 20 | 21 | self.set2set = Set2Set(dim, processing_steps=3) 22 | self.lin1 = torch.nn.Linear(2 * dim, dim) 23 | 24 | self._init_task_heads() 25 | 26 | def _init_task_heads(self): 27 | for i in range(self.n_tasks): 28 | setattr(self, f"head_{i}", torch.nn.Linear(self.dim, 1)) 29 | self.task_specific = torch.nn.ModuleList( 30 | [getattr(self, f"head_{i}") for i in range(self.n_tasks)] 31 | ) 32 | 33 | def forward(self, data, return_representation=False): 34 | out = F.relu(self.lin0(data.x)) 35 | h = out.unsqueeze(0) 36 | 37 | for i in range(3): 38 | m = F.relu(self.conv(out, data.edge_index, data.edge_attr)) 39 | out, h = self.gru(m.unsqueeze(0), h) 40 | out = out.squeeze(0) 41 | 42 | out = self.set2set(out, data.batch) 43 | features = F.relu(self.lin1(out)) 44 | logits = torch.cat( 45 | [getattr(self, f"head_{i}")(features) for i in range(self.n_tasks)], dim=1 46 | ) 47 | if return_representation: 48 | return logits, features 49 | return logits 50 | 51 | def shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]: 52 | return chain( 53 | self.lin0.parameters(), 54 | self.conv.parameters(), 55 | self.gru.parameters(), 56 | self.set2set.parameters(), 57 | self.lin1.parameters(), 58 | ) 59 | 60 | def task_specific_parameters(self) -> Iterator[torch.nn.parameter.Parameter]: 61 | return self.task_specific.parameters() 62 | 63 | def last_shared_parameters(self) -> Iterator[torch.nn.parameter.Parameter]: 64 | return self.lin1.parameters() 65 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.utils import remove_self_loops 4 | 5 | 6 | class MyTransform(object): 7 | def __init__(self, target: list = None): 8 | if target is None: 9 | target = torch.tensor([0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11]) # removing 4 10 | else: 11 | target = torch.tensor(target) 12 | self.target = target 13 | 14 | def __call__(self, data): 15 | # Specify target. 
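# (data.y holds all torch_geometric QM9 regression targets; only the columns listed in self.target are kept.
# The default target list below uses indices 12-15, the atomization energies, in place of U0, U, H and G.)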
16 | data.y = data.y[:, self.target] 17 | return data 18 | 19 | 20 | class Complete(object): 21 | def __call__(self, data): 22 | device = data.edge_index.device 23 | 24 | row = torch.arange(data.num_nodes, dtype=torch.long, device=device) 25 | col = torch.arange(data.num_nodes, dtype=torch.long, device=device) 26 | 27 | row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1) 28 | col = col.repeat(data.num_nodes) 29 | edge_index = torch.stack([row, col], dim=0) 30 | 31 | edge_attr = None 32 | if data.edge_attr is not None: 33 | idx = data.edge_index[0] * data.num_nodes + data.edge_index[1] 34 | size = list(data.edge_attr.size()) 35 | size[0] = data.num_nodes * data.num_nodes 36 | edge_attr = data.edge_attr.new_zeros(size) 37 | edge_attr[idx] = data.edge_attr 38 | 39 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 40 | data.edge_attr = edge_attr 41 | data.edge_index = edge_index 42 | 43 | return data 44 | 45 | 46 | qm9_target_dict = { 47 | 0: "mu", 48 | 1: "alpha", 49 | 2: "homo", 50 | 3: "lumo", 51 | 5: "r2", 52 | 6: "zpve", 53 | 7: "U0", 54 | 8: "U", 55 | 9: "H", 56 | 10: "G", 57 | 11: "Cv", 58 | } 59 | 60 | # for \Delta_m calculations 61 | # ------------------------- 62 | # DimeNet uses the atomization energy for targets U0, U, H, and G. 63 | target_idx = [0, 1, 2, 3, 5, 6, 12, 13, 14, 15, 11] 64 | 65 | # Report meV instead of eV. 66 | multiply_indx = [2, 3, 5, 6, 7, 8, 9] 67 | 68 | n_tasks = len(target_idx) 69 | 70 | # stl results 71 | BASE = np.array( 72 | [ 73 | 0.0671, 74 | 0.1814, 75 | 60.576, 76 | 53.915, 77 | 0.5027, 78 | 4.539, 79 | 58.838, 80 | 64.244, 81 | 63.852, 82 | 66.223, 83 | 0.07212, 84 | ] 85 | ) 86 | 87 | SIGN = np.array([0] * n_tasks) 88 | KK = np.ones(n_tasks) * -1 89 | 90 | 91 | def delta_fn(a): 92 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # *100 for percentage 93 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 51 | -------------------------------------------------------------------------------- /experiments/nyuv2/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ConfMatrix(object): 6 | def __init__(self, num_classes): 7 | self.num_classes = num_classes 8 | self.mat = None 9 | 10 | def update(self, pred, target): 11 | n = self.num_classes 12 | if self.mat is None: 13 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 14 | with torch.no_grad(): 15 | k = (target >= 0) & (target < n) 16 | inds = n * target[k].to(torch.int64) + pred[k] 17 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) 18 | 19 | def get_metrics(self): 20 | h = self.mat.float() 21 | acc = torch.diag(h).sum() / h.sum() 22 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 23 | return torch.mean(iu).cpu().numpy(), acc.cpu().numpy() 24 | 25 | 26 | def depth_error(x_pred, x_output): 27 | device = x_pred.device 28 | binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device) 29 | x_pred_true = x_pred.masked_select(binary_mask) 30 | x_output_true = x_output.masked_select(binary_mask) 31 | abs_err = torch.abs(x_pred_true - x_output_true) 32 | rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true 33 | return ( 34 | torch.sum(abs_err) / torch.nonzero(binary_mask, as_tuple=False).size(0) 35 | ).item(), ( 36 | torch.sum(rel_err) / 
torch.nonzero(binary_mask, as_tuple=False).size(0) 37 | ).item() 38 | 39 | 40 | def normal_error(x_pred, x_output): 41 | binary_mask = torch.sum(x_output, dim=1) != 0 42 | error = ( 43 | torch.acos( 44 | torch.clamp( 45 | torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1 46 | ) 47 | ) 48 | .detach() 49 | .cpu() 50 | .numpy() 51 | ) 52 | error = np.degrees(error) 53 | return ( 54 | np.mean(error), 55 | np.median(error), 56 | np.mean(error < 11.25), 57 | np.mean(error < 22.5), 58 | np.mean(error < 30), 59 | ) 60 | 61 | 62 | # for calculating \Delta_m 63 | delta_stats = [ 64 | "mean iou", 65 | "pix acc", 66 | "abs err", 67 | "rel err", 68 | "mean", 69 | "median", 70 | "<11.25", 71 | "<22.5", 72 | "<30", 73 | ] 74 | BASE = np.array( 75 | [0.4116, 0.657, 0.6074, 0.24, 24.49, 18.24, 0.3192, 0.5916, 0.7056] 76 | ) # base results of STL 77 | SIGN = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1]) 78 | KK = np.ones(9) * -1 79 | 80 | 81 | def delta_fn(a): 82 | return (KK ** SIGN * (a - BASE) / BASE).mean() * 100.0 # * 100 for percentage 83 | 84 | 85 | M_NUM = np.array([[0, 2], [2, 4], [4, 9]]) 86 | 87 | 88 | def stl_eval_mean(a, main_task): 89 | a = KK ** SIGN * a 90 | eval_values = np.array([a[bound[0]: bound[1]].mean() for bound in M_NUM]) 91 | return eval_values[main_task] 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IGB4MTL 2 | 3 | Official implementation of _"Improvable Gap Balancing for Multi-Task Learning"_, which has been accepted to UAI 2023. 4 | 5 | ## Setup environment 6 | 7 | ```bash 8 | conda create -n igb4mtl python=3.8.13 9 | conda activate igb4mtl 10 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch 11 | ``` 12 | 13 | Install the repo: 14 | 15 | ```bash 16 | git clone https://github.com/YanqiDai/IGB4MTL.git 17 | cd IGB4MTL 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ## Run experiment 22 | 23 | Follow the instructions in each experiment's README file for more information regarding, e.g., datasets. 24 | 25 | We support our IGB methods and other existing MTL methods with a unified API. To run experiments: 26 | 27 | ```bash 28 | cd experiments/<experiment name> 29 | python trainer.py --loss_method=<loss balancing method> --gradient_method=<gradient balancing method> 30 | ``` 31 | 32 | Here, 33 | - `<experiment name>` is one of `[quantum_chemistry, nyuv2]`. 34 | - `<loss balancing method>` is one of `igbv1`, `igbv2` and the following loss balancing MTL methods. 35 | - `<gradient balancing method>` is one of the following gradient balancing MTL methods. 36 | - Both `<loss balancing method>` and `<gradient balancing method>` are optional: 37 | - only using `<loss balancing method>` runs a loss balancing method; 38 | - only using `<gradient balancing method>` runs a gradient balancing method; 39 | - using neither runs the Equal Weighting (EW) baseline; 40 | - using both runs a combined MTL method with both loss balancing and gradient balancing. 41 | 42 | ## MTL methods 43 | 44 | We support the loss balancing and gradient balancing methods listed in the two tables below.
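For example, the following commands are a sketch of the three usage modes (run inside `experiments/nyuv2/` or `experiments/quantum_chemistry/`; the code names come from the tables that follow):

```bash
# loss balancing only (IGBv1)
python trainer.py --loss_method=igbv1

# gradient balancing only (CAGrad)
python trainer.py --gradient_method=cagrad

# combined: loss balancing with IGBv2 + gradient balancing with MGDA
python trainer.py --loss_method=igbv2 --gradient_method=mgda
```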
45 | 46 | | Loss Balancing Method (code name) | Paper (notes) | 47 | |:-------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------:| 48 | | Equal Weighting (`ls`) | - (linear scalarization) | 49 | | Random Loss Weighting (`rlw`) | [A Closer Look at Loss Weighting in Multi-Task Learning](https://arxiv.org/pdf/2111.10603.pdf) | 50 | | Dynamic Weight Average (`dwa`) | [End-to-End Multi-Task Learning with Attention](https://arxiv.org/abs/1803.10704) | 51 | | Uncertainty Weighting (`uw`) | [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics](https://arxiv.org/pdf/1705.07115v3.pdf) | 52 | | Improvable Gap Balancing v1 (`igbv1`) | - (our first IGB method) | 53 | | Improvable Gap Balancing v2 (`igbv2`) | - (our second IGB method) | 54 | 55 | 56 | | Gradient Balancing Method (code name) | Paper (notes) | 57 | |:-------------------------------------:|:------------------------------------------------------------------------------------------------:| 58 | | MGDA (`mgda`) | [Multi-Task Learning as Multi-Objective Optimization](https://arxiv.org/abs/1810.04650) | 59 | | PCGrad (`pcgrad`) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/abs/2001.06782) | 60 | | CAGrad (`cagrad`) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048.pdf) | 61 | | IMTL-G (`imtl`) | [Towards Impartial Multi-task Learning](https://openreview.net/forum?id=IMPnRXEWpvr) | 62 | | Nash-MTL (`nashmtl`) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017v1.pdf) | 63 | -------------------------------------------------------------------------------- /experiments/nyuv2/data.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data.dataset import Dataset 9 | 10 | """ 11 | Source: https://github.com/Cranial-XIX/CAGrad/blob/main/nyuv2/create_dataset.py 12 | """ 13 | 14 | 15 | class RandomScaleCrop(object): 16 | """ 17 | Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
18 | """ 19 | def __init__(self, scale=None): 20 | if scale is None: 21 | scale = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5] 22 | self.scale = scale 23 | 24 | def __call__(self, img, label, depth, normal): 25 | height, width = img.shape[-2:] 26 | sc = self.scale[random.randint(0, len(self.scale) - 1)] 27 | h, w = int(height / sc), int(width / sc) 28 | i = random.randint(0, height - h) 29 | j = random.randint(0, width - w) 30 | img_ = F.interpolate( 31 | img[None, :, i: i + h, j: j + w], 32 | size=(height, width), 33 | mode="bilinear", 34 | align_corners=True, 35 | ).squeeze(0) 36 | label_ = ( 37 | F.interpolate( 38 | label[None, None, i: i + h, j: j + w], 39 | size=(height, width), 40 | mode="nearest", 41 | ) 42 | .squeeze(0) 43 | .squeeze(0) 44 | ) 45 | depth_ = F.interpolate( 46 | depth[None, :, i: i + h, j: j + w], size=(height, width), mode="nearest" 47 | ).squeeze(0) 48 | normal_ = F.interpolate( 49 | normal[None, :, i: i + h, j: j + w], 50 | size=(height, width), 51 | mode="bilinear", 52 | align_corners=True, 53 | ).squeeze(0) 54 | return img_, label_, depth_ / sc, normal_ 55 | 56 | 57 | class NYUv2(Dataset): 58 | def __init__(self, root, mode="train", augmentation=False): 59 | self.mode = mode 60 | self.root = os.path.expanduser(root) 61 | self.augmentation = augmentation 62 | 63 | # read the data file 64 | if mode == "train": 65 | self.data_path = root + "/train" 66 | elif mode == "val": 67 | self.data_path = root + "/val" 68 | else: 69 | self.data_path = root + "/test" 70 | 71 | # get data_files and calculate data length 72 | self.data_files = fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy") 73 | self.data_len = len(self.data_files) 74 | 75 | def __getitem__(self, index): 76 | # load data from the pre-processed npy files 77 | image = torch.from_numpy( 78 | np.moveaxis( 79 | np.load(self.data_path + "/image/{}".format(self.data_files[index])), -1, 0 80 | ) 81 | ) 82 | semantic = torch.from_numpy( 83 | np.load(self.data_path + "/label/{}".format(self.data_files[index])) 84 | ) 85 | depth = torch.from_numpy( 86 | np.moveaxis( 87 | np.load(self.data_path + "/depth/{}".format(self.data_files[index])), -1, 0 88 | ) 89 | ) 90 | normal = torch.from_numpy( 91 | np.moveaxis( 92 | np.load(self.data_path + "/normal/{}".format(self.data_files[index])), -1, 0 93 | ) 94 | ) 95 | 96 | # apply data augmentation if required 97 | if self.augmentation: 98 | image, semantic, depth, normal = RandomScaleCrop()( 99 | image, semantic, depth, normal 100 | ) 101 | if torch.rand(1) < 0.5: 102 | image = torch.flip(image, dims=[2]) 103 | semantic = torch.flip(semantic, dims=[1]) 104 | depth = torch.flip(depth, dims=[2]) 105 | normal = torch.flip(normal, dims=[2]) 106 | normal[0, :, :] = -normal[0, :, :] 107 | 108 | return image.float(), semantic.float(), depth.float(), normal.float() 109 | 110 | def __len__(self): 111 | return self.data_len 112 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import random 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from methods import LOSS_METHODS, GRADIENT_METHODS 11 | 12 | 13 | def str_to_list(string): 14 | return [float(s) for s in string.split(",")] 15 | 16 | 17 | def str_or_float(value): 18 | try: 19 | return float(value) 20 | except: 21 | return value 22 | 23 | 24 | def str2bool(v): 25 | if isinstance(v, bool): 26 
| return v 27 | if v.lower() in ("yes", "true", "t", "y", "1"): 28 | return True 29 | elif v.lower() in ("no", "false", "f", "n", "0"): 30 | return False 31 | else: 32 | raise argparse.ArgumentTypeError("Boolean value expected.") 33 | 34 | 35 | common_parser = argparse.ArgumentParser(add_help=False) 36 | common_parser.add_argument("--data-path", type=Path, help="path to data") 37 | common_parser.add_argument("--n-epochs", type=int, default=500) 38 | common_parser.add_argument("--batch-size", type=int, default=2, help="batch size") 39 | common_parser.add_argument( 40 | "--loss_method", 41 | type=str, 42 | choices=list(LOSS_METHODS.keys()), 43 | default="ls", 44 | help="MTL loss weight method" 45 | ) 46 | common_parser.add_argument( 47 | "--gradient_method", 48 | type=str, 49 | choices=list(GRADIENT_METHODS.keys()), 50 | default="ls", 51 | help="MTL gradient weight method" 52 | ) 53 | common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 54 | common_parser.add_argument( 55 | "--method-params-lr", 56 | type=float, 57 | default=0.025, 58 | help="lr for weight method params. If None, set to args.lr. For uncertainty weighting", 59 | ) 60 | common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID") 61 | common_parser.add_argument("--seed", type=int, default=42, help="seed value") 62 | # NashMTL 63 | common_parser.add_argument( 64 | "--nashmtl-optim-niter", type=int, default=20, help="number of CCCP iterations" 65 | ) 66 | common_parser.add_argument( 67 | "--update-weights-every", 68 | type=int, 69 | default=1, 70 | help="update task weights every x iterations.", 71 | ) 72 | # stl 73 | common_parser.add_argument( 74 | "--main-task", 75 | type=int, 76 | default=0, 77 | help="main task for stl. Ignored if method != stl", 78 | ) 79 | # cagrad 80 | common_parser.add_argument("--c", type=float, default=0.4, help="c for CAGrad alg.") 81 | # dwa 82 | common_parser.add_argument( 83 | "--dwa-temp", 84 | type=float, 85 | default=2.0, 86 | help="Temperature hyper-parameter for DWA. 
Default to 2 like in the original paper.", 87 | ) 88 | 89 | # igbv1 and igbv2 90 | common_parser.add_argument( 91 | "--base_epoch", 92 | type=int, 93 | default=1, 94 | help="Set which epoch's average losses as base_losses for fw or fwlog", 95 | ) 96 | 97 | # igbv2 98 | common_parser.add_argument( 99 | "--sac_lr", 100 | type=float, 101 | default=3e-4, 102 | help="learning rate of sac in igbv2", 103 | ) 104 | common_parser.add_argument( 105 | "--buffer_size", 106 | type=float, 107 | default=1e4, 108 | help="max replay buffer size in igbv2", 109 | ) 110 | 111 | 112 | def count_parameters(model): 113 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 114 | 115 | 116 | def set_logger(): 117 | logging.basicConfig( 118 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 119 | level=logging.INFO, 120 | ) 121 | 122 | 123 | def set_seed(seed): 124 | """for reproducibility 125 | :param seed: 126 | :return: 127 | """ 128 | np.random.seed(seed) 129 | random.seed(seed) 130 | 131 | torch.manual_seed(seed) 132 | torch.cuda.manual_seed(seed) 133 | torch.cuda.manual_seed_all(seed) 134 | 135 | torch.backends.cudnn.enabled = False 136 | torch.backends.cudnn.benchmark = False 137 | torch.backends.cudnn.deterministic = True 138 | 139 | 140 | def get_device(no_cuda=False, gpus="0"): 141 | return torch.device( 142 | f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu" 143 | ) 144 | 145 | 146 | def extract_weight_method_parameters_from_args(args): 147 | weight_methods_parameters = defaultdict(dict) 148 | weight_methods_parameters.update( 149 | dict( 150 | nashmtl=dict( 151 | update_weights_every=args.update_weights_every, 152 | optim_niter=args.nashmtl_optim_niter, 153 | ), 154 | stl=dict(main_task=args.main_task), 155 | cagrad=dict(c=args.c), 156 | dwa=dict(temp=args.dwa_temp), 157 | igbv2=dict(sac_lr=args.sac_lr, buffer_size=int(args.buffer_size)), 158 | ) 159 | ) 160 | return weight_methods_parameters 161 | -------------------------------------------------------------------------------- /methods/weight_method.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from abc import abstractmethod 4 | from typing import Dict, List, Tuple, Union 5 | 6 | import cvxpy as cp 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from scipy.optimize import minimize 11 | 12 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers 13 | 14 | 15 | class WeightMethod: 16 | def __init__(self, n_tasks: int, device: torch.device): 17 | super().__init__() 18 | self.n_tasks = n_tasks 19 | self.device = device 20 | 21 | @abstractmethod 22 | def get_weighted_loss( 23 | self, 24 | losses: torch.Tensor, 25 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 26 | task_specific_parameters: Union[ 27 | List[torch.nn.parameter.Parameter], torch.Tensor 28 | ], 29 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 30 | representation: Union[torch.nn.parameter.Parameter, torch.Tensor], 31 | **kwargs, 32 | ): 33 | pass 34 | 35 | @abstractmethod 36 | def get_weighted_losses( 37 | self, 38 | losses: torch.Tensor, 39 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 40 | task_specific_parameters: Union[ 41 | List[torch.nn.parameter.Parameter], torch.Tensor 42 | ], 43 | last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor], 44 | representation: 
Union[torch.nn.parameter.Parameter, torch.Tensor], 45 | **kwargs, 46 | ): 47 | pass 48 | 49 | def backward( 50 | self, 51 | losses: torch.Tensor, 52 | shared_parameters: Union[ 53 | List[torch.nn.parameter.Parameter], torch.Tensor 54 | ] = None, 55 | task_specific_parameters: Union[ 56 | List[torch.nn.parameter.Parameter], torch.Tensor 57 | ] = None, 58 | last_shared_parameters: Union[ 59 | List[torch.nn.parameter.Parameter], torch.Tensor 60 | ] = None, 61 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 62 | **kwargs, 63 | ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]: 64 | """ 65 | 66 | Parameters 67 | ---------- 68 | losses : 69 | shared_parameters : 70 | task_specific_parameters : 71 | last_shared_parameters : parameters of last shared layer/block 72 | representation : shared representation 73 | kwargs : 74 | 75 | Returns 76 | ------- 77 | Loss, extra outputs 78 | """ 79 | loss, extra_outputs = self.get_weighted_loss( 80 | losses=losses, 81 | shared_parameters=shared_parameters, 82 | task_specific_parameters=task_specific_parameters, 83 | last_shared_parameters=last_shared_parameters, 84 | representation=representation, 85 | **kwargs, 86 | ) 87 | loss.backward() 88 | return loss, extra_outputs 89 | 90 | def __call__( 91 | self, 92 | losses: torch.Tensor, 93 | shared_parameters: Union[ 94 | List[torch.nn.parameter.Parameter], torch.Tensor 95 | ] = None, 96 | task_specific_parameters: Union[ 97 | List[torch.nn.parameter.Parameter], torch.Tensor 98 | ] = None, 99 | **kwargs, 100 | ): 101 | return self.backward( 102 | losses=losses, 103 | shared_parameters=shared_parameters, 104 | task_specific_parameters=task_specific_parameters, 105 | **kwargs, 106 | ) 107 | 108 | def parameters(self) -> List[torch.Tensor]: 109 | """return learnable parameters""" 110 | return [] 111 | 112 | 113 | class LinearScalarization(WeightMethod): 114 | """Linear scalarization baseline L = sum_j w_j * l_j where l_j is the loss for task j and w_h""" 115 | 116 | def __init__( 117 | self, 118 | n_tasks: int, 119 | device: torch.device, 120 | task_weights: Union[List[float], torch.Tensor] = None, 121 | ): 122 | super().__init__(n_tasks, device=device) 123 | if task_weights is None: 124 | task_weights = torch.ones((n_tasks,)) 125 | if not isinstance(task_weights, torch.Tensor): 126 | task_weights = torch.tensor(task_weights) 127 | assert len(task_weights) == n_tasks 128 | self.task_weights = task_weights.to(device) 129 | 130 | def get_weighted_loss(self, losses, **kwargs): 131 | loss = torch.sum(losses * self.task_weights) 132 | return loss, dict(weights=self.task_weights) 133 | 134 | def get_weighted_losses(self, losses, **kwargs): 135 | losses = losses * self.task_weights 136 | return losses, dict(weights=self.task_weights) 137 | -------------------------------------------------------------------------------- /methods/min_norm_solvers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | # This code is from 6 | # Multi-Task Learning as Multi-Objective Optimization 7 | # Ozan Sener, Vladlen Koltun 8 | # Neural Information Processing Systems (NeurIPS) 2018 9 | # https://github.com/intel-isl/MultiObjectiveOptimization 10 | class MinNormSolver: 11 | MAX_ITER = 250 12 | STOP_CRIT = 1e-5 13 | 14 | @staticmethod 15 | def _min_norm_element_from2(v1v1, v1v2, v2v2): 16 | """ 17 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2 18 | d is the distance (objective) optimzed 19 | v1v1 = 20 | v1v2 = 21 
| v2v2 = 22 | """ 23 | if v1v2 >= v1v1: 24 | # Case: Fig 1, third column 25 | gamma = 0.999 26 | cost = v1v1 27 | return gamma, cost 28 | if v1v2 >= v2v2: 29 | # Case: Fig 1, first column 30 | gamma = 0.001 31 | cost = v2v2 32 | return gamma, cost 33 | # Case: Fig 1, second column 34 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2)) 35 | cost = v2v2 + gamma * (v1v2 - v2v2) 36 | return gamma, cost 37 | 38 | @staticmethod 39 | def _min_norm_2d(vecs, dps): 40 | """ 41 | Find the minimum norm solution as combination of two points 42 | This is correct only in 2D 43 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j 44 | """ 45 | dmin = 1e8 46 | for i in range(len(vecs)): 47 | for j in range(i + 1, len(vecs)): 48 | if (i, j) not in dps: 49 | dps[(i, j)] = 0.0 50 | for k in range(len(vecs[i])): 51 | dps[(i, j)] += torch.dot( 52 | vecs[i][k], vecs[j][k] 53 | ).item() # torch.dot(vecs[i][k], vecs[j][k]).dataset[0] 54 | dps[(j, i)] = dps[(i, j)] 55 | if (i, i) not in dps: 56 | dps[(i, i)] = 0.0 57 | for k in range(len(vecs[i])): 58 | dps[(i, i)] += torch.dot( 59 | vecs[i][k], vecs[i][k] 60 | ).item() # torch.dot(vecs[i][k], vecs[i][k]).dataset[0] 61 | if (j, j) not in dps: 62 | dps[(j, j)] = 0.0 63 | for k in range(len(vecs[i])): 64 | dps[(j, j)] += torch.dot( 65 | vecs[j][k], vecs[j][k] 66 | ).item() # torch.dot(vecs[j][k], vecs[j][k]).dataset[0] 67 | c, d = MinNormSolver._min_norm_element_from2( 68 | dps[(i, i)], dps[(i, j)], dps[(j, j)] 69 | ) 70 | if d < dmin: 71 | dmin = d 72 | sol = [(i, j), c, d] 73 | return sol, dps 74 | 75 | @staticmethod 76 | def _projection2simplex(y): 77 | """ 78 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i 79 | """ 80 | m = len(y) 81 | sorted_y = np.flip(np.sort(y), axis=0) 82 | tmpsum = 0.0 83 | tmax_f = (np.sum(y) - 1.0) / m 84 | for i in range(m - 1): 85 | tmpsum += sorted_y[i] 86 | tmax = (tmpsum - 1) / (i + 1.0) 87 | if tmax > sorted_y[i + 1]: 88 | tmax_f = tmax 89 | break 90 | return np.maximum(y - tmax_f, np.zeros(y.shape)) 91 | 92 | @staticmethod 93 | def _next_point(cur_val, grad, n): 94 | proj_grad = grad - (np.sum(grad) / n) 95 | tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0] 96 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0]) 97 | 98 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7) 99 | t = 1 100 | if len(tm1[tm1 > 1e-7]) > 0: 101 | t = np.min(tm1[tm1 > 1e-7]) 102 | if len(tm2[tm2 > 1e-7]) > 0: 103 | t = min(t, np.min(tm2[tm2 > 1e-7])) 104 | 105 | next_point = proj_grad * t + cur_val 106 | next_point = MinNormSolver._projection2simplex(next_point) 107 | return next_point 108 | 109 | @staticmethod 110 | def find_min_norm_element(vecs): 111 | """ 112 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull 113 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 
114 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j}) 115 | Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence 116 | """ 117 | # Solution lying at the combination of two points 118 | dps = {} 119 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps) 120 | 121 | n = len(vecs) 122 | sol_vec = np.zeros(n) 123 | sol_vec[init_sol[0][0]] = init_sol[1] 124 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 125 | 126 | if n < 3: 127 | # This is optimal for n=2, so return the solution 128 | return sol_vec, init_sol[2] 129 | 130 | iter_count = 0 131 | 132 | grad_mat = np.zeros((n, n)) 133 | for i in range(n): 134 | for j in range(n): 135 | grad_mat[i, j] = dps[(i, j)] 136 | 137 | while iter_count < MinNormSolver.MAX_ITER: 138 | grad_dir = -1.0 * np.dot(grad_mat, sol_vec) 139 | new_point = MinNormSolver._next_point(sol_vec, grad_dir, n) 140 | # Re-compute the inner products for line search 141 | v1v1 = 0.0 142 | v1v2 = 0.0 143 | v2v2 = 0.0 144 | for i in range(n): 145 | for j in range(n): 146 | v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)] 147 | v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)] 148 | v2v2 += new_point[i] * new_point[j] * dps[(i, j)] 149 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2) 150 | new_sol_vec = nc * sol_vec + (1 - nc) * new_point 151 | change = new_sol_vec - sol_vec 152 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT: 153 | return sol_vec, nd 154 | sol_vec = new_sol_vec 155 | 156 | @staticmethod 157 | def find_min_norm_element_FW(vecs): 158 | """ 159 | Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull 160 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 
161 | It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j}) 162 | Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence 163 | """ 164 | # Solution lying at the combination of two points 165 | dps = {} 166 | init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps) 167 | 168 | n = len(vecs) 169 | sol_vec = np.zeros(n) 170 | sol_vec[init_sol[0][0]] = init_sol[1] 171 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 172 | 173 | if n < 3: 174 | # This is optimal for n=2, so return the solution 175 | return sol_vec, init_sol[2] 176 | 177 | iter_count = 0 178 | 179 | grad_mat = np.zeros((n, n)) 180 | for i in range(n): 181 | for j in range(n): 182 | grad_mat[i, j] = dps[(i, j)] 183 | 184 | while iter_count < MinNormSolver.MAX_ITER: 185 | t_iter = np.argmin(np.dot(grad_mat, sol_vec)) 186 | 187 | v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec)) 188 | v1v2 = np.dot(sol_vec, grad_mat[:, t_iter]) 189 | v2v2 = grad_mat[t_iter, t_iter] 190 | 191 | nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2) 192 | new_sol_vec = nc * sol_vec 193 | new_sol_vec[t_iter] += 1 - nc 194 | 195 | change = new_sol_vec - sol_vec 196 | if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT: 197 | return sol_vec, nd 198 | sol_vec = new_sol_vec 199 | 200 | 201 | def gradient_normalizers(grads, losses, normalization_type): 202 | gn = {} 203 | if normalization_type == "norm": 204 | for t in grads: 205 | gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]])) 206 | elif normalization_type == "loss": 207 | for t in grads: 208 | gn[t] = losses[t] 209 | elif normalization_type == "loss+": 210 | for t in grads: 211 | gn[t] = losses[t] * np.sqrt( 212 | np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]) 213 | ) 214 | elif normalization_type == "none": 215 | for t in grads: 216 | gn[t] = 1.0 217 | else: 218 | print("ERROR: Invalid Normalization Type") 219 | return gn 220 | -------------------------------------------------------------------------------- /experiments/quantum_chemistry/trainer.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch_geometric.transforms as T 9 | from torch_geometric.datasets import QM9 10 | from torch_geometric.loader import DataLoader 11 | from tqdm import trange 12 | 13 | import sys 14 | sys.path.append("../..") 15 | from experiments.quantum_chemistry.models import Net 16 | from experiments.quantum_chemistry.utils import ( 17 | Complete, 18 | MyTransform, 19 | delta_fn, 20 | multiply_indx, 21 | ) 22 | from experiments.quantum_chemistry.utils import target_idx as targets 23 | from experiments.utils import ( 24 | common_parser, 25 | extract_weight_method_parameters_from_args, 26 | get_device, 27 | set_logger, 28 | set_seed, 29 | str2bool, 30 | ) 31 | 32 | from methods.loss_weight_methods import LossWeightMethods 33 | from methods.gradient_weight_methods import GradientWeightMethods 34 | 35 | set_logger() 36 | 37 | 38 | @torch.no_grad() 39 | def evaluate(model, loader, std, scale_target): 40 | model.eval() 41 | data_size = 0.0 42 | task_losses = 0.0 43 | for i, data in enumerate(loader): 44 | data = data.to(device) 45 | out = model(data) 46 | if scale_target: 47 | task_losses += F.l1_loss( 48 | out * std.to(device), data.y * std.to(device), reduction="none" 49 | ).sum( 50 | 0 51 | 
) # MAE 52 | else: 53 | task_losses += F.l1_loss(out, data.y, reduction="none").sum(0) # MAE 54 | data_size += len(data.y) 55 | 56 | model.train() 57 | 58 | avg_task_losses = task_losses / data_size 59 | 60 | # Report meV instead of eV. 61 | avg_task_losses = avg_task_losses.detach().cpu().numpy() 62 | avg_task_losses[multiply_indx] *= 1000 63 | 64 | delta_m = delta_fn(avg_task_losses) 65 | return dict( 66 | avg_loss=avg_task_losses.mean(), 67 | avg_task_losses=avg_task_losses, 68 | delta_m=delta_m, 69 | ) 70 | 71 | 72 | def main( 73 | data_path: str, 74 | batch_size: int, 75 | device: torch.device, 76 | lr: float, 77 | n_epochs: int, 78 | targets: list = None, 79 | scale_target: bool = True, 80 | main_task: int = None, 81 | ): 82 | timestr = time.strftime("%Y%m%d-%H%M%S") 83 | os.makedirs("./logs", exist_ok=True) 84 | log_file = f"./logs/{timestr}_{args.loss_method}_{args.gradient_method}_seed{args.seed}_log.txt" 85 | 86 | dim = 64 87 | model = Net(n_tasks=len(targets), num_features=11, dim=dim).to(device) 88 | 89 | transform = T.Compose([MyTransform(targets), Complete(), T.Distance(norm=False)]) 90 | dataset = QM9(data_path, transform=transform).shuffle() 91 | 92 | # Split datasets. 93 | test_dataset = dataset[:10000] 94 | val_dataset = dataset[10000:20000] 95 | train_dataset = dataset[20000:] 96 | 97 | std = None 98 | if scale_target: 99 | mean = train_dataset.data.y[:, targets].mean(dim=0, keepdim=True) 100 | std = train_dataset.data.y[:, targets].std(dim=0, keepdim=True) 101 | 102 | dataset.data.y[:, targets] = (dataset.data.y[:, targets] - mean) / std 103 | 104 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0) 105 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0) 106 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) 107 | 108 | loss_weight_methods_parameters = extract_weight_method_parameters_from_args(args) 109 | loss_weight_method = LossWeightMethods( 110 | args.loss_method, n_tasks=len(targets), device=device, **loss_weight_methods_parameters[args.loss_method] 111 | ) 112 | # gradient_weight method 113 | gradient_weight_methods_parameters = extract_weight_method_parameters_from_args(args) 114 | gradient_weight_method = GradientWeightMethods( 115 | args.gradient_method, n_tasks=len(targets), device=device, **gradient_weight_methods_parameters[args.gradient_method] 116 | ) 117 | 118 | optimizer = torch.optim.Adam( 119 | [ 120 | dict(params=model.parameters(), lr=lr), 121 | dict(params=loss_weight_method.parameters(), lr=args.method_params_lr), 122 | dict(params=gradient_weight_method.parameters(), lr=args.method_params_lr), 123 | 124 | ], 125 | ) 126 | 127 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 128 | optimizer, mode="min", factor=0.7, patience=5, min_lr=0.00001 129 | ) 130 | 131 | epoch_iterator = trange(n_epochs) 132 | train_batch = len(train_loader) 133 | 134 | best_val = np.inf 135 | best_val_delta = np.inf 136 | 137 | train_time_sum = 0.0 138 | 139 | # reward scale for IGBv2 140 | if args.loss_method == 'igbv2': 141 | loss_weight_method.method.train_batch = train_batch 142 | 143 | for epoch in epoch_iterator: 144 | lr = optimizer.param_groups[0]["lr"] 145 | avg_train_losses = torch.zeros(len(targets)).to(device) 146 | avg_loss_weights = torch.zeros(len(targets)).to(device) 147 | 148 | start_train_time = time.time() 149 | 150 | # reward scale for IGBv2 151 | if args.loss_method == 'igbv2': 152 | 
loss_weight_method.method.reward_scale = lr / optimizer.param_groups[0]['lr'] 153 | 154 | for j, data in enumerate(train_loader): 155 | model.train() 156 | 157 | data = data.to(device) 158 | optimizer.zero_grad() 159 | 160 | out, features = model(data, return_representation=True) 161 | 162 | losses = F.mse_loss(out, data.y, reduction="none").mean(0) 163 | # print(losses) 164 | avg_train_losses += losses.detach() / train_batch 165 | 166 | weighted_losses, loss_weights = loss_weight_method.get_weighted_losses( 167 | losses=losses, 168 | shared_parameters=list(model.shared_parameters()), 169 | task_specific_parameters=list(model.task_specific_parameters()), 170 | last_shared_parameters=list(model.last_shared_parameters()), 171 | representation=features, 172 | ) 173 | avg_loss_weights += loss_weights['weights'] / train_batch 174 | 175 | loss, gradient_weights = gradient_weight_method.backward( 176 | losses=weighted_losses, 177 | shared_parameters=list(model.shared_parameters()), 178 | task_specific_parameters=list(model.task_specific_parameters()), 179 | last_shared_parameters=list(model.last_shared_parameters()), 180 | representation=features, 181 | ) 182 | 183 | optimizer.step() 184 | 185 | epoch_iterator.set_description( 186 | f"[{epoch} {j + 1}/{train_batch}]" 187 | ) 188 | 189 | # base_losses for IGBv1 and IGBv2 190 | if 'igb' in args.loss_method and epoch == args.base_epoch: 191 | loss_weight_method.method.base_losses = avg_train_losses 192 | 193 | end_train_time = time.time() 194 | train_time_sum += end_train_time - start_train_time 195 | 196 | val_loss_dict = evaluate(model, val_loader, std=std, scale_target=scale_target) 197 | val_loss = val_loss_dict["avg_loss"] 198 | val_delta = val_loss_dict["delta_m"] 199 | 200 | results = f"Epoch: {epoch:04d}\n" \ 201 | f"AVERAGE LOSS WEIGHTS: " \ 202 | f"{avg_loss_weights[0]:.4f} {avg_loss_weights[1]:.4f} {avg_loss_weights[2]:.4f} " \ 203 | f"{avg_loss_weights[3]:.4f} {avg_loss_weights[4]:.4f} {avg_loss_weights[5]:.4f} " \ 204 | f"{avg_loss_weights[6]:.4f} {avg_loss_weights[7]:.4f} {avg_loss_weights[8]:.4f} " \ 205 | f"{avg_loss_weights[9]:.4f} {avg_loss_weights[10]:.4f}\n" \ 206 | f"TRAIN: {losses.mean().item():.3f}\n" \ 207 | f"VAL: {val_loss:.3f} {val_delta:.3f}\n" 208 | 209 | if args.loss_method == "stl": 210 | best_val_criteria = val_loss_dict["avg_task_losses"][main_task] <= best_val 211 | else: 212 | best_val_criteria = val_delta <= best_val_delta 213 | 214 | if best_val_criteria: 215 | best_val = val_loss 216 | best_val_delta = val_delta 217 | 218 | test_loss_dict = evaluate(model, test_loader, std=std, scale_target=scale_target) 219 | test_loss = test_loss_dict["avg_loss"] 220 | test_task_losses = test_loss_dict["avg_task_losses"] 221 | test_delta = test_loss_dict["delta_m"] 222 | test_result = f"TEST: {test_loss:.3f} {test_delta:.3f}\n" 223 | test_result += f"TEST LOSSES: " 224 | for i in range(len(targets)): 225 | test_result += f"{test_task_losses[i]:.3f} " 226 | test_result = test_result[:-1] + "\n" 227 | print(test_result, end='') 228 | results += test_result 229 | 230 | with open(log_file, mode="a") as log_f: 231 | log_f.write(results) 232 | 233 | scheduler.step( 234 | val_loss_dict["avg_task_losses"][main_task] 235 | if args.loss_method == "stl" 236 | else val_delta 237 | ) 238 | 239 | train_time_log = f"Training time: {int(train_time_sum)}s\n" 240 | print(train_time_log, end='') 241 | with open(log_file, mode="a") as log_f: 242 | log_f.write(train_time_log) 243 | 244 | 245 | if __name__ == "__main__": 246 | parser = 
ArgumentParser("QM9", parents=[common_parser]) 247 | parser.set_defaults( 248 | data_path="./dataset", 249 | lr=1e-3, 250 | n_epochs=300, 251 | batch_size=120, 252 | ) 253 | parser.add_argument("--scale-y", default=True, type=str2bool) 254 | args = parser.parse_args() 255 | 256 | # set seed 257 | set_seed(args.seed) 258 | 259 | device = get_device(gpus=args.gpu) 260 | main( 261 | data_path=args.data_path, 262 | batch_size=args.batch_size, 263 | device=device, 264 | lr=args.lr, 265 | n_epochs=args.n_epochs, 266 | targets=targets, 267 | scale_target=args.scale_y, 268 | main_task=args.main_task, 269 | ) 270 | -------------------------------------------------------------------------------- /methods/SAC_Agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.distributions import Normal 8 | 9 | 10 | """ 11 | Soft Actor Critic 12 | Modification of: https://github.com/XinJingHao/SAC-Continuous-Pytorch 13 | """ 14 | 15 | 16 | class RandomBuffer(object): 17 | def __init__(self, state_dim, action_dim, max_size=int(1e4), device='cuda'): 18 | self.max_size = max_size 19 | self.ptr = 0 20 | self.size = 0 21 | 22 | self.state = torch.zeros((max_size, state_dim)) 23 | self.action = torch.zeros((max_size, action_dim)) 24 | self.reward = torch.zeros((max_size, 1)) 25 | self.next_state = torch.zeros((max_size, state_dim)) 26 | 27 | self.device = device 28 | 29 | def add(self, state, action, reward, next_state): 30 | self.state[self.ptr] = state 31 | self.action[self.ptr] = action 32 | self.reward[self.ptr] = reward 33 | self.next_state[self.ptr] = next_state 34 | 35 | self.ptr = (self.ptr + 1) % self.max_size 36 | self.size = min(self.size + 1, self.max_size) 37 | 38 | def clean(self): 39 | self.size = 0 40 | 41 | def sample(self, batch_size): 42 | # ind = np.random.randint(0, self.size, size=batch_size) 43 | ind = torch.randint(0, self.size, size=(1, batch_size)).squeeze() 44 | with torch.no_grad(): 45 | return ( 46 | self.state[ind].to(self.device), 47 | self.action[ind].to(self.device), 48 | self.reward[ind].to(self.device), 49 | self.next_state[ind].to(self.device), 50 | ) 51 | 52 | # error, need to rewrite for tensor 53 | def save(self): 54 | """save the replay buffer if you want""" 55 | scaller = np.array([self.max_size, self.ptr, self.size], dtype=np.uint32) 56 | np.save("buffer/scaller.npy", scaller) 57 | np.save("buffer/state.npy", self.state) 58 | np.save("buffer/action.npy", self.action) 59 | np.save("buffer/reward.npy", self.reward) 60 | np.save("buffer/next_state.npy", self.next_state) 61 | 62 | # error, need to rewrite for tensor 63 | def load(self): 64 | scaller = np.load("buffer/scaller.npy") 65 | 66 | self.max_size = scaller[0] 67 | self.ptr = scaller[1] 68 | self.size = scaller[2] 69 | 70 | self.state = np.load("buffer/state.npy") 71 | self.action = np.load("buffer/action.npy") 72 | self.reward = np.load("buffer/reward.npy") 73 | self.next_state = np.load("buffer/next_state.npy") 74 | 75 | 76 | # Build net with for loop: multi-layer MLP with activation function 77 | def build_net(layer_shape, activation, output_activation): 78 | layers = [] 79 | for j in range(len(layer_shape) - 1): 80 | act = activation if j < len(layer_shape) - 2 else output_activation 81 | layers += [nn.Linear(layer_shape[j], layer_shape[j + 1]), act()] 82 | return nn.Sequential(*layers) 83 | 84 | 85 | class Actor(nn.Module): 86 | def 
__init__(self, state_dim, action_dim, hid_shape, h_acti=nn.ReLU, o_acti=nn.ReLU): 87 | super(Actor, self).__init__() 88 | 89 | layers = [state_dim] + list(hid_shape) 90 | self.a_net = build_net(layers, h_acti, o_acti) 91 | self.mu_layer = nn.Linear(layers[-1], action_dim) 92 | self.log_std_layer = nn.Linear(layers[-1], action_dim) 93 | 94 | self.LOG_STD_MAX = 2 95 | self.LOG_STD_MIN = -20 96 | 97 | self.action_dim = action_dim 98 | 99 | def forward(self, state, deterministic=False, with_logprob=True): 100 | """Network with Enforcing Action Bounds""" 101 | net_out = self.a_net(state) 102 | mu = self.mu_layer(net_out) 103 | log_std = self.log_std_layer(net_out) 104 | log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX) 105 | std = torch.exp(log_std) 106 | dist = Normal(mu, std) 107 | 108 | if deterministic: 109 | u = mu 110 | else: 111 | u = dist.rsample() # reparameterization trick of Gaussian 112 | # a = torch.tanh(u) # dai norm_action 113 | a = self.action_dim * F.softmax(u, dim=-1) 114 | 115 | if with_logprob: 116 | # get probability density of logp_pi_a from probability density of u, which is given by the original paper. 117 | # logp_pi_a = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=1, keepdim=True) 118 | 119 | # Derive from the above equation. No a, thus no tanh(h), thus less gradient vanish and more stable. 120 | logp_pi_a = dist.log_prob(u).sum(axis=1, keepdim=True) - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum( 121 | axis=1, keepdim=True) 122 | else: 123 | logp_pi_a = None 124 | 125 | return a, logp_pi_a 126 | 127 | 128 | class Q_Critic(nn.Module): 129 | def __init__(self, state_dim, action_dim, hid_shape): 130 | super(Q_Critic, self).__init__() 131 | layers = [state_dim + action_dim] + list(hid_shape) + [1] 132 | 133 | self.Q_1 = build_net(layers, nn.ReLU, nn.Identity) 134 | self.Q_2 = build_net(layers, nn.ReLU, nn.Identity) 135 | 136 | def forward(self, state, action): 137 | sa = torch.cat([state, action], 1) 138 | # print(sa.size()) 139 | q1 = self.Q_1(sa) 140 | q2 = self.Q_2(sa) 141 | return q1, q2 142 | 143 | 144 | class SAC_Agent(nn.Module): 145 | def __init__( 146 | self, 147 | state_dim, 148 | action_dim, 149 | gamma=0.99, 150 | hid_shape=(128, 128), 151 | a_lr=3e-4, 152 | c_lr=3e-4, 153 | batch_size=256, 154 | alpha=0.2, 155 | adaptive_alpha=True, 156 | device='cuda' 157 | ): 158 | 159 | super(SAC_Agent, self).__init__() 160 | self.actor = Actor(state_dim, action_dim, hid_shape).to(device) 161 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=a_lr) 162 | 163 | self.q_critic = Q_Critic(state_dim, action_dim, hid_shape).to(device) 164 | self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=c_lr) 165 | self.q_critic_target = copy.deepcopy(self.q_critic) 166 | # Freeze target networks with respect to optimizers (only update via polyak averaging) 167 | for p in self.q_critic_target.parameters(): 168 | p.requires_grad = False 169 | 170 | self.action_dim = action_dim 171 | self.gamma = gamma 172 | self.tau = 0.005 173 | self.batch_size = batch_size 174 | 175 | self.alpha = alpha 176 | self.adaptive_alpha = adaptive_alpha 177 | if adaptive_alpha: 178 | # Target Entropy = −dim(A) (e.g. 
, -6 for HalfCheetah-v2) as given in the paper 179 | self.target_entropy = torch.tensor(-action_dim, dtype=float, requires_grad=True, device=device) 180 | # We learn log_alpha instead of alpha to ensure exp(log_alpha)=alpha>0 181 | self.log_alpha = torch.tensor(np.log(alpha), dtype=float, requires_grad=True, device=device) 182 | self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=c_lr) 183 | 184 | self.device = device 185 | 186 | def select_action(self, state, deterministic, with_logprob=False): 187 | # only used when interact with the env 188 | with torch.no_grad(): 189 | # state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 190 | a, _ = self.actor(state, deterministic, with_logprob) 191 | # return a.cpu().numpy().flatten() 192 | return a 193 | 194 | # def train(self, replay_buffer): 195 | def train(self, replay_buffer, k): 196 | s, a, r, s_prime = replay_buffer.sample(int(self.batch_size * (k / 2))) 197 | # s, a, r, s_prime = replay_buffer.sample(self.batch_size) 198 | 199 | # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------ # 200 | with torch.no_grad(): 201 | a_prime, log_pi_a_prime = self.actor(s_prime) 202 | target_Q1, target_Q2 = self.q_critic_target(s_prime, a_prime) 203 | target_Q = torch.min(target_Q1, target_Q2) 204 | target_Q = r + self.gamma * (target_Q - self.alpha * log_pi_a_prime) 205 | 206 | # Get current Q estimates 207 | current_Q1, current_Q2 = self.q_critic(s, a) 208 | 209 | q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 210 | self.q_critic_optimizer.zero_grad() 211 | q_loss.backward() 212 | self.q_critic_optimizer.step() 213 | 214 | # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------# 215 | # Freeze Q-networks so you don't waste computational effort 216 | # computing gradients for them during the policy learning step. 217 | for params in self.q_critic.parameters(): 218 | params.requires_grad = False 219 | 220 | a, log_pi_a = self.actor(s) 221 | current_Q1, current_Q2 = self.q_critic(s, a) 222 | Q = torch.min(current_Q1, current_Q2) 223 | 224 | a_loss = (self.alpha * log_pi_a - Q).mean() 225 | self.actor_optimizer.zero_grad() 226 | a_loss.backward() 227 | self.actor_optimizer.step() 228 | 229 | for params in self.q_critic.parameters(): 230 | params.requires_grad = True 231 | # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------# 232 | if self.adaptive_alpha: 233 | # we optimize log_alpha instead of aplha, which is aimed to force alpha = exp(log_alpha)> 0 234 | # if we optimize aplpha directly, alpha might be < 0, which will lead to minimun entropy. 
235 | alpha_loss = -(self.log_alpha * (log_pi_a + self.target_entropy).detach()).mean() 236 | self.alpha_optim.zero_grad() 237 | alpha_loss.backward() 238 | self.alpha_optim.step() 239 | self.alpha = self.log_alpha.exp() 240 | 241 | # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------# 242 | for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()): 243 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 244 | 245 | def save(self, episode): 246 | torch.save(self.actor.state_dict(), "./rl_model/sac_actor{}.pth".format(episode)) 247 | torch.save(self.q_critic.state_dict(), "./rl_model/sac_q_critic{}.pth".format(episode)) 248 | 249 | def load(self, episode): 250 | self.actor.load_state_dict(torch.load("./rl_model/sac_actor{}.pth".format(episode))) 251 | self.q_critic.load_state_dict(torch.load("./rl_model/sac_q_critic{}.pth".format(episode))) -------------------------------------------------------------------------------- /experiments/nyuv2/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from argparse import ArgumentParser 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from tqdm import trange 11 | 12 | import sys 13 | sys.path.append("../..") 14 | from experiments.nyuv2.data import NYUv2 15 | from experiments.nyuv2.models import SegNet, SegNetMtan 16 | from experiments.nyuv2.utils import ConfMatrix, delta_fn, depth_error, normal_error, stl_eval_mean 17 | from experiments.utils import ( 18 | common_parser, 19 | extract_weight_method_parameters_from_args, 20 | get_device, 21 | set_logger, 22 | set_seed, 23 | str2bool, 24 | ) 25 | 26 | from methods.loss_weight_methods import LossWeightMethods 27 | from methods.gradient_weight_methods import GradientWeightMethods 28 | 29 | set_logger() 30 | 31 | 32 | def calc_loss(x_pred, x_output, task_type): 33 | device = x_pred.device 34 | 35 | # binary mark to mask out undefined pixel space 36 | binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device) 37 | 38 | if task_type == "semantic": 39 | # semantic loss: depth-wise cross entropy 40 | loss = F.nll_loss(x_pred, x_output, ignore_index=-1) 41 | 42 | if task_type == "depth": 43 | # depth loss: l1 norm 44 | loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero( 45 | binary_mask, as_tuple=False 46 | ).size(0) 47 | 48 | if task_type == "normal": 49 | # normal loss: dot product 50 | loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero( 51 | binary_mask, as_tuple=False 52 | ).size(0) 53 | 54 | return loss 55 | 56 | 57 | def main(path, lr, bs, device): 58 | timestr = time.strftime("%Y%m%d-%H%M%S") 59 | os.makedirs("./logs", exist_ok=True) 60 | log_file = f"./logs/{timestr}_{args.loss_method}_{args.gradient_method}_seed{args.seed}_log.txt" 61 | 62 | # Nets 63 | model = dict(segnet=SegNet(), mtan=SegNetMtan())[args.model] 64 | model = model.to(device) 65 | 66 | # dataset and dataloaders 67 | log_str = ( 68 | "Applying data augmentation on NYUv2." 69 | if args.apply_augmentation 70 | else "Standard training strategy without data augmentation." 
71 | ) 72 | logging.info(log_str) 73 | 74 | nyuv2_train_set = NYUv2(root=path.as_posix(), mode="train", augmentation=args.apply_augmentation) 75 | nyuv2_val_set = NYUv2(root=path.as_posix(), mode="val") 76 | nyuv2_test_set = NYUv2(root=path.as_posix(), mode="test") 77 | 78 | train_loader = DataLoader(dataset=nyuv2_train_set, batch_size=bs, shuffle=True) 79 | val_loader = DataLoader(dataset=nyuv2_val_set, batch_size=bs, shuffle=False) 80 | test_loader = DataLoader(dataset=nyuv2_test_set, batch_size=bs, shuffle=False) 81 | 82 | # loss_weight method 83 | loss_weight_methods_parameters = extract_weight_method_parameters_from_args(args) 84 | loss_weight_method = LossWeightMethods( 85 | args.loss_method, n_tasks=3, device=device, **loss_weight_methods_parameters[args.loss_method] 86 | ) 87 | 88 | # gradient_weight method 89 | gradient_weight_methods_parameters = extract_weight_method_parameters_from_args(args) 90 | gradient_weight_method = GradientWeightMethods( 91 | args.gradient_method, n_tasks=3, device=device, **gradient_weight_methods_parameters[args.gradient_method] 92 | ) 93 | 94 | # optimizer 95 | optimizer = torch.optim.Adam( 96 | [ 97 | dict(params=model.parameters(), lr=lr), 98 | dict(params=loss_weight_method.parameters(), lr=args.method_params_lr), 99 | dict(params=gradient_weight_method.parameters(), lr=args.method_params_lr), 100 | ], 101 | ) 102 | 103 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 104 | 105 | epochs = args.n_epochs 106 | epoch_iter = trange(epochs) 107 | train_batch = len(train_loader) 108 | val_batch = len(val_loader) 109 | test_batch = len(test_loader) 110 | avg_cost = np.zeros([epochs, 24], dtype=np.float32) 111 | 112 | # best model to test 113 | best_epoch = None 114 | best_eval = 0 115 | 116 | # print result head 117 | print( 118 | f"LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR " 119 | f"| NORMAL_LOSS MEAN MED <11.25 <22.5 <30 | ∆m" 120 | ) 121 | 122 | train_time_sum = 0.0 123 | 124 | # train batch for IGBv2 125 | if args.loss_method == 'igbv2': 126 | loss_weight_method.method.train_batch = train_batch 127 | 128 | for epoch in epoch_iter: 129 | cost = np.zeros(24, dtype=np.float32) 130 | conf_mat = ConfMatrix(model.segnet.class_nb) 131 | avg_loss_weights = torch.zeros(3).to(device) 132 | 133 | start_train_time = time.time() 134 | 135 | # reward scale for IGBv2 136 | if args.loss_method == 'igbv2': 137 | loss_weight_method.method.reward_scale = lr / optimizer.param_groups[0]['lr'] 138 | 139 | for j, batch in enumerate(train_loader): 140 | model.train() 141 | optimizer.zero_grad() 142 | 143 | train_data, train_label, train_depth, train_normal = batch 144 | train_data, train_label = train_data.to(device), train_label.long().to(device) 145 | train_depth, train_normal = train_depth.to(device), train_normal.to(device) 146 | 147 | train_pred, features = model(train_data, return_representation=True) 148 | 149 | losses = torch.stack((calc_loss(train_pred[0], train_label, "semantic"), 150 | calc_loss(train_pred[1], train_depth, "depth"), 151 | calc_loss(train_pred[2], train_normal, "normal"))) 152 | 153 | weighted_losses, loss_weights = loss_weight_method.get_weighted_losses( 154 | losses=losses, 155 | shared_parameters=list(model.shared_parameters()), 156 | task_specific_parameters=list(model.task_specific_parameters()), 157 | last_shared_parameters=list(model.last_shared_parameters()), 158 | representation=features, 159 | ) 160 | avg_loss_weights += loss_weights['weights'] / train_batch 161 | 162 | 
loss, gradient_weights = gradient_weight_method.backward( 163 | losses=weighted_losses, 164 | shared_parameters=list(model.shared_parameters()), 165 | task_specific_parameters=list(model.task_specific_parameters()), 166 | last_shared_parameters=list(model.last_shared_parameters()), 167 | representation=features, 168 | ) 169 | 170 | optimizer.step() 171 | 172 | # accumulate label prediction for every pixel in training images 173 | conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten()) 174 | 175 | cost[0] = losses[0].item() 176 | cost[3] = losses[1].item() 177 | cost[4], cost[5] = depth_error(train_pred[1], train_depth) 178 | cost[6] = losses[2].item() 179 | cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal) 180 | avg_cost[epoch, :12] += cost[:12] / train_batch 181 | 182 | epoch_iter.set_description( 183 | f"[{epoch} {j + 1}/{train_batch}] losses: {losses[0].item():.3f} " 184 | f"{losses[1].item():.3f} {losses[2].item():.3f} " 185 | f"weights: {loss_weights['weights'][0].item():.3f} " 186 | f"{loss_weights['weights'][1].item():.3f} {loss_weights['weights'][2].item():.3f}" 187 | ) 188 | 189 | # scheduler 190 | scheduler.step() 191 | # compute mIoU and acc 192 | avg_cost[epoch, 1:3] = conf_mat.get_metrics() 193 | 194 | # base_losses for IGBv1 and IGBv2 195 | if 'igb' in args.loss_method and epoch == args.base_epoch: 196 | base_losses = torch.Tensor(avg_cost[epoch, [0, 3, 6]]).to(device) 197 | loss_weight_method.method.base_losses = base_losses 198 | 199 | end_train_time = time.time() 200 | train_time_sum += end_train_time - start_train_time 201 | 202 | # todo: move evaluate to function? 203 | # evaluating test data 204 | model.eval() 205 | conf_mat = ConfMatrix(model.segnet.class_nb) 206 | with torch.no_grad(): # operations inside don't track history 207 | for j, batch in enumerate(val_loader): 208 | val_data, val_label, val_depth, val_normal = batch 209 | val_data, val_label = val_data.to(device), val_label.long().to(device) 210 | val_depth, val_normal = val_depth.to(device), val_normal.to(device) 211 | 212 | val_pred = model(val_data) 213 | val_loss = torch.stack( 214 | ( 215 | calc_loss(val_pred[0], val_label, "semantic"), 216 | calc_loss(val_pred[1], val_depth, "depth"), 217 | calc_loss(val_pred[2], val_normal, "normal"), 218 | ) 219 | ) 220 | 221 | conf_mat.update(val_pred[0].argmax(1).flatten(), val_label.flatten()) 222 | 223 | cost[12] = val_loss[0].item() 224 | cost[15] = val_loss[1].item() 225 | cost[16], cost[17] = depth_error(val_pred[1], val_depth) 226 | cost[18] = val_loss[2].item() 227 | cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(val_pred[2], val_normal) 228 | avg_cost[epoch, 12:] += cost[12:] / val_batch 229 | 230 | # compute mIoU and acc 231 | avg_cost[epoch, 13:15] = conf_mat.get_metrics() 232 | 233 | # Val Delta_m 234 | val_delta_m = delta_fn( 235 | avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]] 236 | ) 237 | 238 | if args.loss_method != "stl": 239 | eval_value = val_delta_m 240 | else: 241 | eval_value = stl_eval_mean(avg_cost[epoch, [13, 14, 16, 17, 19, 20, 21, 22, 23]], args.main_task) 242 | 243 | results = f"Epoch: {epoch:04d}\n" \ 244 | f"AVERAGE LOSS WEIGHTS: " \ 245 | f"{avg_loss_weights[0]:.4f} {avg_loss_weights[1]:.4f} {avg_loss_weights[2]:.4f}\n" \ 246 | f"TRAIN: " \ 247 | f"{avg_cost[epoch, 0]:.4f} {avg_cost[epoch, 1]:.4f} {avg_cost[epoch, 2]:.4f} | " \ 248 | f"{avg_cost[epoch, 3]:.4f} {avg_cost[epoch, 4]:.4f} {avg_cost[epoch, 5]:.4f} | " \ 249 | f"{avg_cost[epoch, 6]:.4f} 
{avg_cost[epoch, 7]:.2f} {avg_cost[epoch, 8]:.2f} " \ 250 | f"{avg_cost[epoch, 9]:.4f} {avg_cost[epoch, 10]:.4f} {avg_cost[epoch, 11]:.4f}\n" \ 251 | f"VAL: " \ 252 | f"{avg_cost[epoch, 12]:.4f} {avg_cost[epoch, 13]:.4f} {avg_cost[epoch, 14]:.4f} | " \ 253 | f"{avg_cost[epoch, 15]:.4f} {avg_cost[epoch, 16]:.4f} {avg_cost[epoch, 17]:.4f} | " \ 254 | f"{avg_cost[epoch, 18]:.4f} {avg_cost[epoch, 19]:.2f} {avg_cost[epoch, 20]:.2f} " \ 255 | f"{avg_cost[epoch, 21]:.4f} {avg_cost[epoch, 22]:.4f} {avg_cost[epoch, 23]:.4f} | " \ 256 | f"{val_delta_m:.3f}\n" 257 | 258 | if best_epoch is None or eval_value < best_eval: 259 | best_epoch = epoch 260 | best_eval = eval_value 261 | 262 | # test 263 | test_cost = np.zeros(12, dtype=np.float32) 264 | test_avg_cost = np.zeros(12, dtype=np.float32) 265 | conf_mat = ConfMatrix(model.segnet.class_nb) 266 | with torch.no_grad(): 267 | for j, batch in enumerate(test_loader): 268 | test_data, test_label, test_depth, test_normal = batch 269 | test_data, test_label = test_data.to(device), test_label.long().to(device) 270 | test_depth, test_normal = test_depth.to(device), test_normal.to(device) 271 | 272 | test_pred = model(test_data) 273 | test_loss = torch.stack( 274 | ( 275 | calc_loss(test_pred[0], test_label, "semantic"), 276 | calc_loss(test_pred[1], test_depth, "depth"), 277 | calc_loss(test_pred[2], test_normal, "normal"), 278 | ) 279 | ) 280 | 281 | conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten()) 282 | 283 | test_cost[0] = test_loss[0].item() 284 | test_cost[3] = test_loss[1].item() 285 | test_cost[4], test_cost[5] = depth_error(test_pred[1], test_depth) 286 | test_cost[6] = test_loss[2].item() 287 | test_cost[7], test_cost[8], test_cost[9], test_cost[10], test_cost[11] = normal_error( 288 | test_pred[2], test_normal 289 | ) 290 | test_avg_cost += test_cost / test_batch 291 | 292 | # compute mIoU and acc 293 | test_avg_cost[1:3] = conf_mat.get_metrics() 294 | 295 | # Test Delta_m 296 | test_delta_m = delta_fn( 297 | test_avg_cost[[1, 2, 4, 5, 7, 8, 9, 10, 11]] 298 | ) 299 | test_result = f"TEST: {test_avg_cost[0]:.4f} {test_avg_cost[1]:.4f} {test_avg_cost[2]:.4f} | " \ 300 | f"{test_avg_cost[3]:.4f} {test_avg_cost[4]:.4f} {test_avg_cost[5]:.4f} | " \ 301 | f"{test_avg_cost[6]:.4f} {test_avg_cost[7]:.2f} {test_avg_cost[8]:.2f} " \ 302 | f"{test_avg_cost[9]:.4f} {test_avg_cost[10]:.4f} {test_avg_cost[11]:.4f} | " \ 303 | f"{test_delta_m:.3f}\n" 304 | results += test_result 305 | # print test result 306 | print(test_result, end='') 307 | with open(log_file, mode="a") as log_f: 308 | log_f.write(results) 309 | 310 | train_time_log = f"Training time: {int(train_time_sum)}s\n" 311 | print(train_time_log, end='') 312 | with open(log_file, mode="a") as log_f: 313 | log_f.write(train_time_log) 314 | 315 | 316 | if __name__ == "__main__": 317 | parser = ArgumentParser("NYUv2", parents=[common_parser]) 318 | parser.set_defaults( 319 | data_path="./dataset", 320 | lr=1e-4, 321 | n_epochs=500, 322 | batch_size=2, 323 | ) 324 | parser.add_argument( 325 | "--model", 326 | type=str, 327 | default="segnet", 328 | choices=["segnet", "mtan"], 329 | help="model type", 330 | ) 331 | parser.add_argument( 332 | "--apply-augmentation", 333 | type=str2bool, 334 | default=True, 335 | help="data augmentations" 336 | ) 337 | args = parser.parse_args() 338 | 339 | # set seed 340 | set_seed(args.seed) 341 | 342 | device = get_device(gpus=args.gpu) 343 | main(path=args.data_path, lr=args.lr, bs=args.batch_size, device=device) 344 | 
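
The trainer above composes two independent weighting stages on every batch: a loss-weighting method (methods/loss_weight_methods.py) first rescales the raw task losses, then a gradient-weighting method (methods/gradient_weight_methods.py) combines the rescaled losses, runs backward, and hands the gradients to the optimizer. The following minimal sketch of that pipeline is not part of the repository: ToyMTLModel and the random data are hypothetical stand-ins for SegNet/SegNetMtan and the NYUv2 losses, the snippet assumes the repository root is on PYTHONPATH with requirements.txt installed, and "si" and "pcgrad" are chosen only because both implementations appear in full in the files below.

import torch
import torch.nn as nn
import torch.nn.functional as F

from methods.loss_weight_methods import LossWeightMethods
from methods.gradient_weight_methods import GradientWeightMethods


class ToyMTLModel(nn.Module):
    """Hypothetical two-task model exposing the parameter groups trainer.py expects."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # shared part
        self.heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(2)])  # task-specific heads

    def forward(self, x):
        rep = torch.relu(self.backbone(x))
        return [head(rep) for head in self.heads], rep

    def shared_parameters(self):
        return self.backbone.parameters()

    def task_specific_parameters(self):
        return self.heads.parameters()

    def last_shared_parameters(self):
        return self.backbone.parameters()


device = torch.device("cpu")
model = ToyMTLModel().to(device)
loss_weight_method = LossWeightMethods("si", n_tasks=2, device=device)
gradient_weight_method = GradientWeightMethods("pcgrad", n_tasks=2, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(4, 8), torch.randn(4, 1)
preds, rep = model(x)
losses = torch.stack([F.mse_loss(pred, y) for pred in preds])

# Stage 1: loss weighting (si / uw / dwa / igbv1 / igbv2 ... act here).
weighted_losses, info = loss_weight_method.get_weighted_losses(
    losses=losses,
    shared_parameters=list(model.shared_parameters()),
    task_specific_parameters=list(model.task_specific_parameters()),
    last_shared_parameters=list(model.last_shared_parameters()),
    representation=rep,
)

# Stage 2: gradient weighting (pcgrad / mgda / cagrad / nashmtl / imtl ... act here) + backward.
optimizer.zero_grad()
gradient_weight_method.backward(
    losses=weighted_losses,
    shared_parameters=list(model.shared_parameters()),
    task_specific_parameters=list(model.task_specific_parameters()),
    last_shared_parameters=list(model.last_shared_parameters()),
    representation=rep,
)
optimizer.step()
print("loss weights:", info["weights"])

Any other key from LOSS_METHODS or GRADIENT_METHODS can be swapped in, provided the method-specific arguments and state are supplied the way trainer.py does (e.g. main_task for STL, or base_losses and train_batch for the IGB variants).
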
-------------------------------------------------------------------------------- /methods/loss_weight_methods.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from methods.weight_method import WeightMethod, LinearScalarization 8 | from methods.SAC_Agent import SAC_Agent, RandomBuffer 9 | 10 | 11 | class ScaleInvariantLinearScalarization(WeightMethod): 12 | """Scale-invariant loss balancing paradigm""" 13 | 14 | def __init__( 15 | self, 16 | n_tasks: int, 17 | device: torch.device, 18 | task_weights: Union[List[float], torch.Tensor] = None, 19 | ): 20 | super().__init__(n_tasks, device=device) 21 | if task_weights is None: 22 | task_weights = torch.ones((n_tasks,)) 23 | if not isinstance(task_weights, torch.Tensor): 24 | task_weights = torch.tensor(task_weights) 25 | assert len(task_weights) == n_tasks 26 | self.task_weights = task_weights.to(device) 27 | 28 | def get_weighted_loss(self, losses, **kwargs): 29 | loss = torch.sum(torch.log(losses) * self.task_weights) 30 | return loss, dict(weights=self.task_weights) 31 | 32 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 33 | losses = torch.log(losses) * self.task_weights 34 | return losses, dict(weights=self.task_weights) 35 | 36 | 37 | class STL(WeightMethod): 38 | """Single task learning""" 39 | 40 | def __init__(self, n_tasks, device: torch.device, main_task): 41 | super().__init__(n_tasks, device=device) 42 | self.main_task = main_task 43 | self.weights = torch.zeros(n_tasks, device=device) 44 | self.weights[main_task] = 1.0 45 | 46 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 47 | assert len(losses) == self.n_tasks 48 | loss = losses[self.main_task] 49 | 50 | return loss, dict(weights=self.weights) 51 | 52 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 53 | losses = losses * self.weights 54 | return losses, dict(weights=self.weights) 55 | 56 | 57 | class Uncertainty(WeightMethod): 58 | """Implementation of `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics` 59 | Source: https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb 60 | """ 61 | 62 | def __init__(self, n_tasks, device: torch.device): 63 | super().__init__(n_tasks, device=device) 64 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True) 65 | 66 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 67 | loss = sum(losses / (2 * self.logsigma.exp()) + self.logsigma / 2) 68 | return loss, dict(weights=torch.exp(-self.logsigma)) # NOTE: not exactly task weights 69 | 70 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 71 | losses = losses / (2 * self.logsigma.exp()) + self.logsigma / 2 72 | return losses, dict(weights=torch.exp(-self.logsigma)) 73 | 74 | def parameters(self) -> List[torch.Tensor]: 75 | return [self.logsigma] 76 | 77 | 78 | class UncertaintyLog(WeightMethod): 79 | """UW + SI""" 80 | 81 | def __init__(self, n_tasks, device: torch.device): 82 | super().__init__(n_tasks, device=device) 83 | self.logsigma = torch.tensor([0.0] * n_tasks, device=device, requires_grad=True) 84 | 85 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 86 | loss = sum(torch.log(losses) / (2 * self.logsigma.exp()) + self.logsigma / 2) 87 | return loss, dict(weights=torch.exp(-self.logsigma)) # NOTE: not exactly task 
weights 88 | 89 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 90 | losses = torch.log(losses) / (2 * self.logsigma.exp()) + self.logsigma / 2 91 | return losses, dict(weights=torch.exp(-self.logsigma)) # NOTE: not exactly task weights 92 | 93 | def parameters(self) -> List[torch.Tensor]: 94 | return [self.logsigma] 95 | 96 | 97 | class RLW(WeightMethod): 98 | """Random loss weighting: https://arxiv.org/pdf/2111.10603.pdf""" 99 | 100 | def __init__(self, n_tasks, device: torch.device): 101 | super().__init__(n_tasks, device=device) 102 | 103 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 104 | assert len(losses) == self.n_tasks 105 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device) 106 | loss = torch.sum(losses * weight) 107 | 108 | return loss, dict(weights=weight) 109 | 110 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 111 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device) 112 | losses = losses * weight 113 | return losses, dict(weights=weight) 114 | 115 | 116 | class RLWLog(WeightMethod): 117 | """RLW + SI""" 118 | 119 | def __init__(self, n_tasks, device: torch.device): 120 | super().__init__(n_tasks, device=device) 121 | 122 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 123 | assert len(losses) == self.n_tasks 124 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device) 125 | loss = torch.sum(torch.log(losses) * weight) 126 | 127 | return loss, dict(weights=weight) 128 | 129 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 130 | weight = (self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1)).to(self.device) 131 | losses = torch.log(losses) * weight 132 | return losses, dict(weights=weight) 133 | 134 | 135 | class DynamicWeightAverage(WeightMethod): 136 | """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`. 
137 | Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242 138 | """ 139 | 140 | def __init__( 141 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0 142 | ): 143 | """ 144 | 145 | Parameters 146 | ---------- 147 | n_tasks : 148 | iteration_window : 'iteration' loss is averaged over the last 'iteration_window' losses 149 | temp : 150 | """ 151 | super().__init__(n_tasks, device=device) 152 | self.iteration_window = iteration_window 153 | self.temp = temp 154 | self.running_iterations = 0 155 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32) 156 | self.weights = np.ones(n_tasks, dtype=np.float32) 157 | 158 | def get_weighted_loss(self, losses, **kwargs): 159 | 160 | cost = losses.detach().cpu().numpy() 161 | 162 | # update costs - fifo 163 | self.costs[:-1, :] = self.costs[1:, :] 164 | self.costs[-1, :] = cost 165 | 166 | if self.running_iterations > self.iteration_window: 167 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0) 168 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum() 169 | 170 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device) 171 | loss = sum(task_weights * losses) 172 | 173 | self.running_iterations += 1 174 | 175 | return loss, dict(weights=task_weights) 176 | 177 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 178 | cost = losses.detach().cpu().numpy() 179 | 180 | # update costs - fifo 181 | self.costs[:-1, :] = self.costs[1:, :] 182 | self.costs[-1, :] = cost 183 | 184 | if self.running_iterations > self.iteration_window: 185 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0) 186 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum() 187 | 188 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device) 189 | losses = task_weights * losses 190 | 191 | self.running_iterations += 1 192 | 193 | return losses, dict(weights=task_weights) 194 | 195 | 196 | class DynamicWeightAverageLog(WeightMethod): 197 | """DWA + SI""" 198 | def __init__( 199 | self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0 200 | ): 201 | """ 202 | 203 | Parameters 204 | ---------- 205 | n_tasks : 206 | iteration_window : 'iteration' loss is averaged over the last 'iteration_window' losses 207 | temp : 208 | """ 209 | super().__init__(n_tasks, device=device) 210 | self.iteration_window = iteration_window 211 | self.temp = temp 212 | self.running_iterations = 0 213 | self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32) 214 | self.weights = np.ones(n_tasks, dtype=np.float32) 215 | 216 | def get_weighted_loss(self, losses, **kwargs): 217 | 218 | cost = losses.detach().cpu().numpy() 219 | 220 | # update costs - fifo 221 | self.costs[:-1, :] = self.costs[1:, :] 222 | self.costs[-1, :] = cost 223 | 224 | if self.running_iterations > self.iteration_window: 225 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0) 226 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum() 227 | 228 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device) 229 | loss = sum(task_weights * torch.log(losses)) 230 | 231 | self.running_iterations += 1 232 | 233 | return loss, dict(weights=task_weights) 234 | 235 | def get_weighted_losses(self, 
losses: torch.Tensor, **kwargs): 236 | cost = losses.detach().cpu().numpy() 237 | 238 | # update costs - fifo 239 | self.costs[:-1, :] = self.costs[1:, :] 240 | self.costs[-1, :] = cost 241 | 242 | if self.running_iterations > self.iteration_window: 243 | ws = self.costs[self.iteration_window:, :].mean(0) / self.costs[: self.iteration_window, :].mean(0) 244 | self.weights = (self.n_tasks * np.exp(ws / self.temp)) / (np.exp(ws / self.temp)).sum() 245 | 246 | task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(losses.device) 247 | losses = task_weights * torch.log(losses) 248 | 249 | self.running_iterations += 1 250 | 251 | return losses, dict(weights=task_weights) 252 | 253 | 254 | class ImprovableGapBalancing_v1(WeightMethod): 255 | def __init__(self, n_tasks, device: torch.device): 256 | super().__init__(n_tasks, device=device) 257 | self.base_losses = None 258 | self.weights = torch.ones(n_tasks).to(device) 259 | 260 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 261 | if self.base_losses is not None: 262 | self.weights = self.n_tasks * F.softmax(losses.detach() / self.base_losses, dim=-1).to(losses.device) 263 | loss = sum(self.weights * torch.log(losses)) 264 | return loss, dict(weights=self.weights) # NOTE: not exactly task weights 265 | 266 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 267 | if self.base_losses is not None: 268 | self.weights = self.n_tasks * F.softmax(losses.detach() / self.base_losses, dim=-1).to(losses.device) 269 | losses = self.weights * torch.log(losses) 270 | return losses, dict(weights=self.weights) # NOTE: not exactly task weights 271 | 272 | 273 | class ImprovableGapBalancing_v2(WeightMethod): 274 | def __init__(self, n_tasks, device: torch.device, sac_lr=3e-4, buffer_size=1e4): 275 | super().__init__(n_tasks, device=device) 276 | self.base_losses = None 277 | self.weights = torch.ones(n_tasks).to(device) 278 | 279 | self.sac_model = SAC_Agent(state_dim=n_tasks, action_dim=n_tasks, a_lr=sac_lr, c_lr=sac_lr, batch_size=256, device=device) 280 | self.replay_buffer = RandomBuffer(state_dim=n_tasks, action_dim=n_tasks, max_size=buffer_size, device=device) 281 | self.custom_step = 0 282 | self.bool_custom_step = 0 283 | self.batch_loss = torch.zeros([2, n_tasks]).to(device) 284 | self.batch_rl_weight = torch.zeros([2, n_tasks]).to(device) 285 | self.train_batch = None 286 | self.start_epoch = 5 287 | self.update_after = 3 288 | self.update_every = 50 289 | self.reward_scale = 1.0 290 | 291 | def get_weighted_loss(self, losses: torch.Tensor, **kwargs): 292 | self.batch_loss[self.bool_custom_step] = losses.detach() 293 | 294 | # write random buffer 295 | if self.base_losses is not None: 296 | loss_de = (self.batch_loss[(self.bool_custom_step - 1) % 2] - self.batch_loss[self.bool_custom_step]) 297 | loss_de = loss_de / self.base_losses 298 | reward = min(loss_de) 299 | reward *= self.reward_scale 300 | self.replay_buffer.add(self.batch_loss[(self.bool_custom_step - 1) % 2], 301 | self.batch_rl_weight[(self.bool_custom_step - 1) % 2], 302 | reward, 303 | self.batch_loss[self.bool_custom_step]) 304 | # train sac_model 305 | if self.custom_step >= self.update_after * self.train_batch and self.custom_step % self.update_every == 0: 306 | k = 1 + self.replay_buffer.size / self.replay_buffer.max_size 307 | for i in range(int(self.update_every * (k / 2))): 308 | self.sac_model.train(self.replay_buffer, k) 309 | # change weights 310 | if self.custom_step < self.start_epoch * self.train_batch: 311 | self.weights = 
self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1).to(self.device) 312 | else: 313 | self.weights = self.sac_model.select_action(self.batch_loss[self.bool_custom_step], 314 | deterministic=False, 315 | with_logprob=False) 316 | self.batch_rl_weight[self.bool_custom_step] = self.weights.detach() 317 | 318 | loss = sum(self.weights * torch.log(losses)) 319 | 320 | self.custom_step += 1 321 | self.bool_custom_step = (self.bool_custom_step + 1) % 2 322 | 323 | return loss, dict(weights=self.weights) # NOTE: not exactly task weights 324 | 325 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 326 | self.batch_loss[self.bool_custom_step] = losses.detach() 327 | 328 | # write random buffer 329 | if self.base_losses is not None: 330 | loss_de = (self.batch_loss[(self.bool_custom_step - 1) % 2] - self.batch_loss[self.bool_custom_step]) 331 | loss_de = loss_de / self.base_losses 332 | reward = min(loss_de) 333 | # reward = sum(loss_de) / self.n_tasks 334 | reward *= self.reward_scale 335 | self.replay_buffer.add(self.batch_loss[(self.bool_custom_step - 1) % 2], 336 | self.batch_rl_weight[(self.bool_custom_step - 1) % 2], 337 | reward, 338 | self.batch_loss[self.bool_custom_step]) 339 | # train sac_model 340 | if self.custom_step >= self.update_after * self.train_batch and self.custom_step % self.update_every == 0: 341 | k = 1 + self.replay_buffer.size / self.replay_buffer.max_size 342 | for i in range(int(self.update_every * (k / 2))): 343 | self.sac_model.train(self.replay_buffer, k) 344 | # give weights 345 | if self.custom_step < self.start_epoch * self.train_batch: 346 | self.weights = self.n_tasks * F.softmax(torch.randn(self.n_tasks), dim=-1).to(self.device) 347 | else: 348 | self.weights = self.sac_model.select_action(self.batch_loss[self.bool_custom_step], 349 | deterministic=False, 350 | with_logprob=False) 351 | self.batch_rl_weight[self.bool_custom_step] = self.weights.detach() 352 | 353 | losses = self.weights * torch.log(losses) 354 | 355 | self.custom_step += 1 356 | self.bool_custom_step = (self.bool_custom_step + 1) % 2 357 | 358 | return losses, dict(weights=self.weights) # NOTE: not exactly task weights 359 | 360 | 361 | class LossWeightMethods: 362 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs): 363 | """ 364 | :param method: 365 | """ 366 | assert method in list(LOSS_METHODS.keys()), f"unknown method {method}." 
367 | 368 | self.method = LOSS_METHODS[method](n_tasks=n_tasks, device=device, **kwargs) 369 | 370 | def get_weighted_loss(self, losses, **kwargs): 371 | return self.method.get_weighted_loss(losses, **kwargs) 372 | 373 | def get_weighted_losses(self, losses: torch.Tensor, **kwargs): 374 | return self.method.get_weighted_losses(losses, **kwargs) 375 | 376 | def backward( 377 | self, losses, **kwargs 378 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]: 379 | return self.method.backward(losses, **kwargs) 380 | 381 | def __ceil__(self, losses, **kwargs): 382 | return self.backward(losses, **kwargs) 383 | 384 | def parameters(self): 385 | return self.method.parameters() 386 | 387 | 388 | LOSS_METHODS = dict( 389 | ls=LinearScalarization, 390 | stl=STL, 391 | si=ScaleInvariantLinearScalarization, 392 | uw=Uncertainty, 393 | uwlog=UncertaintyLog, 394 | rlw=RLW, 395 | rlwlog=RLWLog, 396 | dwa=DynamicWeightAverage, 397 | dwalog=DynamicWeightAverageLog, 398 | igbv1=ImprovableGapBalancing_v1, 399 | igbv2=ImprovableGapBalancing_v2, 400 | ) 401 | -------------------------------------------------------------------------------- /methods/gradient_weight_methods.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from typing import Dict, List, Tuple, Union 4 | 5 | import cvxpy as cp 6 | import numpy as np 7 | import torch 8 | from scipy.optimize import minimize 9 | 10 | from methods.min_norm_solvers import MinNormSolver, gradient_normalizers 11 | from methods.weight_method import WeightMethod, LinearScalarization 12 | 13 | 14 | class NashMTL(WeightMethod): 15 | def __init__( 16 | self, 17 | n_tasks: int, 18 | device: torch.device, 19 | max_norm: float = 1.0, 20 | update_weights_every: int = 1, 21 | optim_niter=20, 22 | ): 23 | super(NashMTL, self).__init__(n_tasks=n_tasks, device=device) 24 | 25 | self.optim_niter = optim_niter 26 | self.update_weights_every = update_weights_every 27 | self.max_norm = max_norm 28 | 29 | self.prvs_alpha_param = None 30 | self.normalization_factor = np.ones((1,)) 31 | self.init_gtg = self.init_gtg = np.eye(self.n_tasks) 32 | self.step = 0.0 33 | self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) 34 | 35 | def _stop_criteria(self, gtg, alpha_t): 36 | return ((self.alpha_param.value is None) 37 | or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3) 38 | or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6)) 39 | 40 | def solve_optimization(self, gtg: np.array): 41 | self.G_param.value = gtg 42 | self.normalization_factor_param.value = self.normalization_factor 43 | 44 | alpha_t = self.prvs_alpha 45 | for _ in range(self.optim_niter): 46 | self.alpha_param.value = alpha_t 47 | self.prvs_alpha_param.value = alpha_t 48 | 49 | try: 50 | self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100) 51 | except: 52 | self.alpha_param.value = self.prvs_alpha_param.value 53 | 54 | if self._stop_criteria(gtg, alpha_t): 55 | break 56 | 57 | alpha_t = self.alpha_param.value 58 | 59 | if alpha_t is not None: 60 | self.prvs_alpha = alpha_t 61 | 62 | return self.prvs_alpha 63 | 64 | def _calc_phi_alpha_linearization(self): 65 | G_prvs_alpha = self.G_param @ self.prvs_alpha_param 66 | prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param 67 | phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param) 68 | return phi_alpha 69 | 70 | def _init_optim_problem(self): 71 | self.alpha_param = cp.Variable(shape=(self.n_tasks,), 
nonneg=True) 72 | self.prvs_alpha_param = cp.Parameter(shape=(self.n_tasks,), value=self.prvs_alpha) 73 | self.G_param = cp.Parameter(shape=(self.n_tasks, self.n_tasks), value=self.init_gtg) 74 | self.normalization_factor_param = cp.Parameter(shape=(1,), value=np.array([1.0])) 75 | 76 | self.phi_alpha = self._calc_phi_alpha_linearization() 77 | 78 | G_alpha = self.G_param @ self.alpha_param 79 | constraint = [] 80 | for i in range(self.n_tasks): 81 | constraint.append(-cp.log(self.alpha_param[i] * self.normalization_factor_param) - cp.log(G_alpha[i]) <= 0) 82 | obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param) 83 | self.prob = cp.Problem(obj, constraint) 84 | 85 | def get_weighted_loss( 86 | self, 87 | losses, 88 | shared_parameters, 89 | **kwargs, 90 | ): 91 | """ 92 | 93 | Parameters 94 | ---------- 95 | losses : 96 | shared_parameters : shared parameters 97 | kwargs : 98 | 99 | Returns 100 | ------- 101 | 102 | """ 103 | 104 | extra_outputs = dict() 105 | if self.step == 0: 106 | self._init_optim_problem() 107 | 108 | if self.step % self.update_weights_every == 0: 109 | self.step += 1 110 | 111 | grads = {} 112 | for i, loss in enumerate(losses): 113 | g = list(torch.autograd.grad(loss, shared_parameters, retain_graph=True)) 114 | grad = torch.cat([torch.flatten(grad) for grad in g]) 115 | grads[i] = grad 116 | 117 | G = torch.stack(tuple(v for v in grads.values())) 118 | GTG = torch.mm(G, G.t()) 119 | 120 | self.normalization_factor = (torch.norm(GTG).detach().cpu().numpy().reshape((1,))) 121 | GTG = GTG / self.normalization_factor.item() 122 | alpha = self.solve_optimization(GTG.cpu().detach().numpy()) 123 | alpha = torch.from_numpy(alpha) 124 | 125 | else: 126 | self.step += 1 127 | alpha = self.prvs_alpha 128 | 129 | weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))]) 130 | extra_outputs["weights"] = alpha 131 | return weighted_loss, extra_outputs 132 | 133 | def backward( 134 | self, 135 | losses: torch.Tensor, 136 | shared_parameters: Union[ 137 | List[torch.nn.parameter.Parameter], torch.Tensor 138 | ] = None, 139 | task_specific_parameters: Union[ 140 | List[torch.nn.parameter.Parameter], torch.Tensor 141 | ] = None, 142 | last_shared_parameters: Union[ 143 | List[torch.nn.parameter.Parameter], torch.Tensor 144 | ] = None, 145 | representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 146 | **kwargs, 147 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]: 148 | loss, extra_outputs = self.get_weighted_loss( 149 | losses=losses, 150 | shared_parameters=shared_parameters, 151 | **kwargs, 152 | ) 153 | loss.backward() 154 | 155 | # make sure the solution for shared params has norm <= self.eps 156 | if self.max_norm > 0: 157 | torch.nn.utils.clip_grad_norm_(shared_parameters, self.max_norm) 158 | 159 | return loss, extra_outputs 160 | 161 | 162 | class MGDA(WeightMethod): 163 | """Based on the official implementation of: Multi-Task Learning as Multi-Objective Optimization 164 | Ozan Sener, Vladlen Koltun 165 | Neural Information Processing Systems (NeurIPS) 2018 166 | https://github.com/intel-isl/MultiObjectiveOptimization 167 | 168 | """ 169 | 170 | def __init__( 171 | self, n_tasks, device: torch.device, params="shared", normalization="none" 172 | ): 173 | super().__init__(n_tasks, device=device) 174 | self.solver = MinNormSolver() 175 | assert params in ["shared", "last", "rep"] 176 | self.params = params 177 | assert normalization in ["norm", "loss", "loss+", "none"] 178 | self.normalization = 
normalization 179 | 180 | @staticmethod 181 | def _flattening(grad): 182 | return torch.cat(tuple(g.reshape(-1,) for i, g in enumerate(grad)), dim=0) 183 | 184 | def get_weighted_loss( 185 | self, 186 | losses, 187 | shared_parameters=None, 188 | last_shared_parameters=None, 189 | representation=None, 190 | **kwargs, 191 | ): 192 | """ 193 | 194 | Parameters 195 | ---------- 196 | losses : 197 | shared_parameters : 198 | last_shared_parameters : 199 | representation : 200 | kwargs : 201 | 202 | Returns 203 | ------- 204 | 205 | """ 206 | # Our code 207 | grads = {} 208 | params = dict(rep=representation, shared=shared_parameters, last=last_shared_parameters)[self.params] 209 | for i, loss in enumerate(losses): 210 | g = list(torch.autograd.grad(loss, params,retain_graph=True)) 211 | # Normalize all gradients, this is optional and not included in the paper. 212 | 213 | grads[i] = [torch.flatten(grad) for grad in g] 214 | 215 | gn = gradient_normalizers(grads, losses, self.normalization) 216 | for t in range(self.n_tasks): 217 | for gr_i in range(len(grads[t])): 218 | grads[t][gr_i] = grads[t][gr_i] / gn[t] 219 | 220 | sol, min_norm = self.solver.find_min_norm_element([grads[t] for t in range(len(grads))]) 221 | sol = sol * self.n_tasks # make sure it sums to self.n_tasks 222 | weighted_loss = sum([losses[i] * sol[i] for i in range(len(sol))]) 223 | 224 | return weighted_loss, dict(weights=torch.from_numpy(sol.astype(np.float32))) 225 | 226 | 227 | class PCGrad(WeightMethod): 228 | """Modification of: https://github.com/WeiChengTseng/Pytorch-PCGrad/blob/master/pcgrad.py 229 | 230 | @misc{Pytorch-PCGrad, 231 | author = {Wei-Cheng Tseng}, 232 | title = {WeiChengTseng/Pytorch-PCGrad}, 233 | url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git}, 234 | year = {2020} 235 | } 236 | 237 | """ 238 | 239 | def __init__(self, n_tasks: int, device: torch.device, reduction="sum"): 240 | super().__init__(n_tasks, device=device) 241 | assert reduction in ["mean", "sum"] 242 | self.reduction = reduction 243 | 244 | def get_weighted_loss( 245 | self, 246 | losses: torch.Tensor, 247 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 248 | task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 249 | **kwargs, 250 | ): 251 | raise NotImplementedError 252 | 253 | def _set_pc_grads(self, losses, shared_parameters, task_specific_parameters=None): 254 | # shared part 255 | shared_grads = [] 256 | for l in losses: 257 | shared_grads.append(torch.autograd.grad(l, shared_parameters, retain_graph=True)) 258 | 259 | if isinstance(shared_parameters, torch.Tensor): 260 | shared_parameters = [shared_parameters] 261 | non_conflict_shared_grads = self._project_conflicting(shared_grads) 262 | for p, g in zip(shared_parameters, non_conflict_shared_grads): 263 | p.grad = g 264 | 265 | # task specific part 266 | if task_specific_parameters is not None: 267 | task_specific_grads = torch.autograd.grad(losses.sum(), task_specific_parameters) 268 | if isinstance(task_specific_parameters, torch.Tensor): 269 | task_specific_parameters = [task_specific_parameters] 270 | for p, g in zip(task_specific_parameters, task_specific_grads): 271 | p.grad = g 272 | 273 | def _project_conflicting(self, grads: List[Tuple[torch.Tensor]]): 274 | pc_grad = copy.deepcopy(grads) 275 | for g_i in pc_grad: 276 | random.shuffle(grads) 277 | for g_j in grads: 278 | g_i_g_j = sum([torch.dot(torch.flatten(grad_i), torch.flatten(grad_j)) 279 | for grad_i, grad_j in zip(g_i, g_j)]) 280 
| if g_i_g_j < 0: 281 | g_j_norm_square = (torch.norm(torch.cat([torch.flatten(g) for g in g_j])) ** 2) 282 | for grad_i, grad_j in zip(g_i, g_j): 283 | grad_i -= g_i_g_j * grad_j / g_j_norm_square 284 | 285 | merged_grad = [sum(g) for g in zip(*pc_grad)] 286 | if self.reduction == "mean": 287 | merged_grad = [g / self.n_tasks for g in merged_grad] 288 | 289 | return merged_grad 290 | 291 | def backward( 292 | self, 293 | losses: torch.Tensor, 294 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 295 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 296 | task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 297 | **kwargs, 298 | ): 299 | self._set_pc_grads(losses, shared_parameters, task_specific_parameters) 300 | return None, {} # NOTE: to align with all other weight methods 301 | 302 | 303 | class CAGrad(WeightMethod): 304 | def __init__(self, n_tasks, device: torch.device, c=0.4): 305 | super().__init__(n_tasks, device=device) 306 | self.c = c 307 | 308 | def get_weighted_loss( 309 | self, 310 | losses, 311 | shared_parameters, 312 | **kwargs, 313 | ): 314 | """ 315 | Parameters 316 | ---------- 317 | losses : 318 | shared_parameters : shared parameters 319 | kwargs : 320 | Returns 321 | ------- 322 | """ 323 | # NOTE: we allow only shared params for now. Need to see paper for other options. 324 | grad_dims = [] 325 | for param in shared_parameters: 326 | grad_dims.append(param.data.numel()) 327 | grads = torch.Tensor(sum(grad_dims), self.n_tasks).to(self.device) 328 | 329 | for i in range(self.n_tasks): 330 | if i < self.n_tasks: 331 | losses[i].backward(retain_graph=True) 332 | else: 333 | losses[i].backward() 334 | self.grad2vec(shared_parameters, grads, grad_dims, i) 335 | # multi_task_model.zero_grad_shared_modules() 336 | for p in shared_parameters: 337 | p.grad = None 338 | 339 | g = self.cagrad(grads, alpha=self.c, rescale=1) 340 | self.overwrite_grad(shared_parameters, g, grad_dims) 341 | 342 | def cagrad(self, grads, alpha=0.5, rescale=1): 343 | GG = grads.t().mm(grads).cpu() # [num_tasks, num_tasks] 344 | g0_norm = (GG.mean() + 1e-8).sqrt() # norm of the average gradient 345 | 346 | x_start = np.ones(self.n_tasks) / self.n_tasks 347 | bnds = tuple((0, 1) for x in x_start) 348 | cons = {"type": "eq", "fun": lambda x: 1 - sum(x)} 349 | A = GG.numpy() 350 | b = x_start.copy() 351 | c = (alpha * g0_norm + 1e-8).item() 352 | 353 | def objfn(x): 354 | return (x.reshape(1, self.n_tasks).dot(A).dot(b.reshape(self.n_tasks, 1)) 355 | + c * np.sqrt(x.reshape(1, self.n_tasks).dot(A).dot(x.reshape(self.n_tasks, 1)) + 1e-8)).sum() 356 | 357 | res = minimize(objfn, x_start, bounds=bnds, constraints=cons) 358 | w_cpu = res.x 359 | ww = torch.Tensor(w_cpu).to(grads.device) 360 | gw = (grads * ww.view(1, -1)).sum(1) 361 | gw_norm = gw.norm() 362 | lmbda = c / (gw_norm + 1e-8) 363 | g = grads.mean(1) + lmbda * gw 364 | if rescale == 0: 365 | return g 366 | elif rescale == 1: 367 | return g / (1 + alpha ** 2) 368 | else: 369 | return g / (1 + alpha) 370 | 371 | @staticmethod 372 | def grad2vec(shared_params, grads, grad_dims, task): 373 | # store the gradients 374 | grads[:, task].fill_(0.0) 375 | cnt = 0 376 | # for mm in m.shared_modules(): 377 | # for p in mm.parameters(): 378 | 379 | for param in shared_params: 380 | grad = param.grad 381 | if grad is not None: 382 | grad_cur = grad.data.detach().clone() 383 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 384 | en = sum(grad_dims[: cnt + 1]) 
385 | grads[beg:en, task].copy_(grad_cur.data.view(-1)) 386 | cnt += 1 387 | 388 | def overwrite_grad(self, shared_parameters, newgrad, grad_dims): 389 | newgrad = newgrad * self.n_tasks # to match the sum loss 390 | cnt = 0 391 | 392 | # for mm in m.shared_modules(): 393 | # for param in mm.parameters(): 394 | for param in shared_parameters: 395 | beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) 396 | en = sum(grad_dims[: cnt + 1]) 397 | this_grad = newgrad[beg:en].contiguous().view(param.data.size()) 398 | param.grad = this_grad.data.clone() 399 | cnt += 1 400 | 401 | def backward( 402 | self, 403 | losses: torch.Tensor, 404 | parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 405 | shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 406 | task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, 407 | **kwargs, 408 | ): 409 | self.get_weighted_loss(losses, shared_parameters) 410 | return None, {} # NOTE: to align with all other weight methods 411 | 412 | 413 | class IMTLG(WeightMethod): 414 | """TOWARDS IMPARTIAL MULTI-TASK LEARNING: https://openreview.net/pdf?id=IMPnRXEWpvr""" 415 | 416 | def __init__(self, n_tasks, device: torch.device): 417 | super().__init__(n_tasks, device=device) 418 | 419 | def get_weighted_loss( 420 | self, 421 | losses, 422 | shared_parameters, 423 | **kwargs, 424 | ): 425 | grads = {} 426 | norm_grads = {} 427 | 428 | for i, loss in enumerate(losses): 429 | g = list(torch.autograd.grad(loss, shared_parameters, retain_graph=True)) 430 | grad = torch.cat([torch.flatten(grad) for grad in g]) 431 | norm_term = torch.norm(grad) 432 | 433 | grads[i] = grad 434 | norm_grads[i] = grad / norm_term 435 | 436 | G = torch.stack(tuple(v for v in grads.values())) 437 | D = ( 438 | G[ 439 | 0, 440 | ] 441 | - G[ 442 | 1:, 443 | ] 444 | ) 445 | 446 | U = torch.stack(tuple(v for v in norm_grads.values())) 447 | U = ( 448 | U[ 449 | 0, 450 | ] 451 | - U[ 452 | 1:, 453 | ] 454 | ) 455 | first_element = torch.matmul( 456 | G[ 457 | 0, 458 | ], 459 | U.t(), 460 | ) 461 | try: 462 | second_element = torch.inverse(torch.matmul(D, U.t())) 463 | except: 464 | # workaround for cases where matrix is singular 465 | second_element = torch.inverse( 466 | torch.eye(self.n_tasks - 1, device=self.device) * 1e-8 467 | + torch.matmul(D, U.t()) 468 | ) 469 | 470 | alpha_ = torch.matmul(first_element, second_element) 471 | alpha = torch.cat( 472 | (torch.tensor(1 - alpha_.sum(), device=self.device).unsqueeze(-1), alpha_) 473 | ) 474 | 475 | loss = torch.sum(losses * alpha) 476 | 477 | return loss, dict(weights=alpha) 478 | 479 | 480 | class GradientWeightMethods: 481 | def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs): 482 | """ 483 | :param method: 484 | """ 485 | assert method in list(GRADIENT_METHODS.keys()), f"unknown method {method}." 
486 | 487 | self.method = GRADIENT_METHODS[method](n_tasks=n_tasks, device=device, **kwargs) 488 | 489 | def get_weighted_loss(self, losses, **kwargs): 490 | return self.method.get_weighted_loss(losses, **kwargs) 491 | 492 | def backward( 493 | self, losses, **kwargs 494 | ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]: 495 | return self.method.backward(losses, **kwargs) 496 | 497 | def __ceil__(self, losses, **kwargs): 498 | return self.backward(losses, **kwargs) 499 | 500 | def parameters(self): 501 | return self.method.parameters() 502 | 503 | 504 | GRADIENT_METHODS = dict( 505 | ls=LinearScalarization, 506 | pcgrad=PCGrad, 507 | mgda=MGDA, 508 | cagrad=CAGrad, 509 | nashmtl=NashMTL, 510 | imtl=IMTLG, 511 | ) 512 | -------------------------------------------------------------------------------- /experiments/nyuv2/models.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class _SegNet(nn.Module): 9 | """SegNet MTAN""" 10 | def __init__(self): 11 | super(_SegNet, self).__init__() 12 | # initialise network parameters 13 | filter = [64, 128, 256, 512, 512] 14 | self.class_nb = 13 15 | 16 | # define encoder decoder layers 17 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])]) 18 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 19 | for i in range(4): 20 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]])) 21 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]])) 22 | 23 | # define convolution layer 24 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 25 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 26 | for i in range(4): 27 | if i == 0: 28 | self.conv_block_enc.append( 29 | self.conv_layer([filter[i + 1], filter[i + 1]]) 30 | ) 31 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]])) 32 | else: 33 | self.conv_block_enc.append( 34 | nn.Sequential( 35 | self.conv_layer([filter[i + 1], filter[i + 1]]), 36 | self.conv_layer([filter[i + 1], filter[i + 1]]), 37 | ) 38 | ) 39 | self.conv_block_dec.append( 40 | nn.Sequential( 41 | self.conv_layer([filter[i], filter[i]]), 42 | self.conv_layer([filter[i], filter[i]]), 43 | ) 44 | ) 45 | 46 | # define task attention layers 47 | self.encoder_att = nn.ModuleList( 48 | [nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])] 49 | ) 50 | self.decoder_att = nn.ModuleList( 51 | [nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])] 52 | ) 53 | self.encoder_block_att = nn.ModuleList( 54 | [self.conv_layer([filter[0], filter[1]])] 55 | ) 56 | self.decoder_block_att = nn.ModuleList( 57 | [self.conv_layer([filter[0], filter[0]])] 58 | ) 59 | 60 | for j in range(3): 61 | if j < 2: 62 | self.encoder_att.append( 63 | nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])]) 64 | ) 65 | self.decoder_att.append( 66 | nn.ModuleList( 67 | [self.att_layer([2 * filter[0], filter[0], filter[0]])] 68 | ) 69 | ) 70 | for i in range(4): 71 | self.encoder_att[j].append( 72 | self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]]) 73 | ) 74 | self.decoder_att[j].append( 75 | self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]]) 76 | ) 77 | 78 | for i in range(4): 79 | if i < 3: 80 | self.encoder_block_att.append( 81 | self.conv_layer([filter[i + 1], filter[i + 2]]) 82 | ) 83 | 
self.decoder_block_att.append( 84 | self.conv_layer([filter[i + 1], filter[i]]) 85 | ) 86 | else: 87 | self.encoder_block_att.append( 88 | self.conv_layer([filter[i + 1], filter[i + 1]]) 89 | ) 90 | self.decoder_block_att.append( 91 | self.conv_layer([filter[i + 1], filter[i + 1]]) 92 | ) 93 | 94 | self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True) 95 | self.pred_task2 = self.conv_layer([filter[0], 1], pred=True) 96 | self.pred_task3 = self.conv_layer([filter[0], 3], pred=True) 97 | 98 | # define pooling and unpooling functions 99 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 100 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2) 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.xavier_normal_(m.weight) 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.BatchNorm2d): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, nn.Linear): 110 | nn.init.xavier_normal_(m.weight) 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def shared_modules(self): 114 | return [ 115 | self.encoder_block, 116 | self.decoder_block, 117 | self.conv_block_enc, 118 | self.conv_block_dec, 119 | # self.encoder_att, self.decoder_att, 120 | self.encoder_block_att, 121 | self.decoder_block_att, 122 | self.down_sampling, 123 | self.up_sampling, 124 | ] 125 | 126 | def zero_grad_shared_modules(self): 127 | for mm in self.shared_modules(): 128 | mm.zero_grad() 129 | 130 | def conv_layer(self, channel, pred=False): 131 | if not pred: 132 | conv_block = nn.Sequential( 133 | nn.Conv2d( 134 | in_channels=channel[0], 135 | out_channels=channel[1], 136 | kernel_size=3, 137 | padding=1, 138 | ), 139 | nn.BatchNorm2d(num_features=channel[1]), 140 | nn.ReLU(inplace=True), 141 | ) 142 | else: 143 | conv_block = nn.Sequential( 144 | nn.Conv2d( 145 | in_channels=channel[0], 146 | out_channels=channel[0], 147 | kernel_size=3, 148 | padding=1, 149 | ), 150 | nn.Conv2d( 151 | in_channels=channel[0], 152 | out_channels=channel[1], 153 | kernel_size=1, 154 | padding=0, 155 | ), 156 | ) 157 | return conv_block 158 | 159 | def att_layer(self, channel): 160 | att_block = nn.Sequential( 161 | nn.Conv2d( 162 | in_channels=channel[0], 163 | out_channels=channel[1], 164 | kernel_size=1, 165 | padding=0, 166 | ), 167 | nn.BatchNorm2d(channel[1]), 168 | nn.ReLU(inplace=True), 169 | nn.Conv2d( 170 | in_channels=channel[1], 171 | out_channels=channel[2], 172 | kernel_size=1, 173 | padding=0, 174 | ), 175 | nn.BatchNorm2d(channel[2]), 176 | nn.Sigmoid(), 177 | ) 178 | return att_block 179 | 180 | def forward(self, x): 181 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ( 182 | [0] * 5 for _ in range(5) 183 | ) 184 | for i in range(5): 185 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2)) 186 | 187 | # define attention list for tasks 188 | atten_encoder, atten_decoder = ([0] * 3 for _ in range(2)) 189 | for i in range(3): 190 | atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2)) 191 | for i in range(3): 192 | for j in range(5): 193 | atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2)) 194 | 195 | # define global shared network 196 | for i in range(5): 197 | if i == 0: 198 | g_encoder[i][0] = self.encoder_block[i](x) 199 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 200 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 201 | else: 202 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1]) 203 | g_encoder[i][1] = 
self.conv_block_enc[i](g_encoder[i][0]) 204 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 205 | 206 | for i in range(5): 207 | if i == 0: 208 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1]) 209 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 210 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 211 | else: 212 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1]) 213 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 214 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 215 | 216 | # define task dependent attention module 217 | for i in range(3): 218 | for j in range(5): 219 | if j == 0: 220 | atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0]) 221 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1] 222 | atten_encoder[i][j][2] = self.encoder_block_att[j]( 223 | atten_encoder[i][j][1] 224 | ) 225 | atten_encoder[i][j][2] = F.max_pool2d( 226 | atten_encoder[i][j][2], kernel_size=2, stride=2 227 | ) 228 | else: 229 | atten_encoder[i][j][0] = self.encoder_att[i][j]( 230 | torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1) 231 | ) 232 | atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1] 233 | atten_encoder[i][j][2] = self.encoder_block_att[j]( 234 | atten_encoder[i][j][1] 235 | ) 236 | atten_encoder[i][j][2] = F.max_pool2d( 237 | atten_encoder[i][j][2], kernel_size=2, stride=2 238 | ) 239 | 240 | for j in range(5): 241 | if j == 0: 242 | atten_decoder[i][j][0] = F.interpolate( 243 | atten_encoder[i][-1][-1], 244 | scale_factor=2, 245 | mode="bilinear", 246 | align_corners=True, 247 | ) 248 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1]( 249 | atten_decoder[i][j][0] 250 | ) 251 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1]( 252 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1) 253 | ) 254 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1] 255 | else: 256 | atten_decoder[i][j][0] = F.interpolate( 257 | atten_decoder[i][j - 1][2], 258 | scale_factor=2, 259 | mode="bilinear", 260 | align_corners=True, 261 | ) 262 | atten_decoder[i][j][0] = self.decoder_block_att[-j - 1]( 263 | atten_decoder[i][j][0] 264 | ) 265 | atten_decoder[i][j][1] = self.decoder_att[i][-j - 1]( 266 | torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1) 267 | ) 268 | atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1] 269 | 270 | # define task prediction layers 271 | t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1) 272 | t2_pred = self.pred_task2(atten_decoder[1][-1][-1]) 273 | t3_pred = self.pred_task3(atten_decoder[2][-1][-1]) 274 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True) 275 | 276 | return ( 277 | [t1_pred, t2_pred, t3_pred], 278 | ( 279 | atten_decoder[0][-1][-1], 280 | atten_decoder[1][-1][-1], 281 | atten_decoder[2][-1][-1], 282 | ), 283 | ) 284 | 285 | 286 | class SegNetMtan(nn.Module): 287 | def __init__(self): 288 | super().__init__() 289 | self.segnet = _SegNet() 290 | 291 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 292 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n) 293 | 294 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]: 295 | return (p for n, p in self.segnet.named_parameters() if "pred" in n) 296 | 297 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 298 | """Parameters of the last shared layer. 
299 | Returns 300 | ------- 301 | """ 302 | return [] 303 | 304 | def forward(self, x, return_representation=False): 305 | if return_representation: 306 | return self.segnet(x) 307 | else: 308 | pred, rep = self.segnet(x) 309 | return pred 310 | 311 | 312 | class SegNetSplit(nn.Module): 313 | def __init__(self, model_type="standard"): 314 | super(SegNetSplit, self).__init__() 315 | # initialise network parameters 316 | assert model_type in ["standard", "wide", "deep"] 317 | self.model_type = model_type 318 | if self.model_type == "wide": 319 | filter = [64, 128, 256, 512, 1024] 320 | else: 321 | filter = [64, 128, 256, 512, 512] 322 | 323 | self.class_nb = 13 324 | 325 | # define encoder decoder layers 326 | self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])]) 327 | self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 328 | for i in range(4): 329 | self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]])) 330 | self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]])) 331 | 332 | # define convolution layer 333 | self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 334 | self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])]) 335 | for i in range(4): 336 | if i == 0: 337 | self.conv_block_enc.append( 338 | self.conv_layer([filter[i + 1], filter[i + 1]]) 339 | ) 340 | self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]])) 341 | else: 342 | self.conv_block_enc.append( 343 | nn.Sequential( 344 | self.conv_layer([filter[i + 1], filter[i + 1]]), 345 | self.conv_layer([filter[i + 1], filter[i + 1]]), 346 | ) 347 | ) 348 | self.conv_block_dec.append( 349 | nn.Sequential( 350 | self.conv_layer([filter[i], filter[i]]), 351 | self.conv_layer([filter[i], filter[i]]), 352 | ) 353 | ) 354 | 355 | # define task specific layers 356 | self.pred_task1 = nn.Sequential( 357 | nn.Conv2d( 358 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1 359 | ), 360 | nn.Conv2d( 361 | in_channels=filter[0], 362 | out_channels=self.class_nb, 363 | kernel_size=1, 364 | padding=0, 365 | ), 366 | ) 367 | self.pred_task2 = nn.Sequential( 368 | nn.Conv2d( 369 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1 370 | ), 371 | nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0), 372 | ) 373 | self.pred_task3 = nn.Sequential( 374 | nn.Conv2d( 375 | in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1 376 | ), 377 | nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0), 378 | ) 379 | 380 | # define pooling and unpooling functions 381 | self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 382 | self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2) 383 | 384 | for m in self.modules(): 385 | if isinstance(m, nn.Conv2d): 386 | nn.init.xavier_normal_(m.weight) 387 | nn.init.constant_(m.bias, 0) 388 | elif isinstance(m, nn.BatchNorm2d): 389 | nn.init.constant_(m.weight, 1) 390 | nn.init.constant_(m.bias, 0) 391 | elif isinstance(m, nn.Linear): 392 | nn.init.xavier_normal_(m.weight) 393 | nn.init.constant_(m.bias, 0) 394 | 395 | # define convolutional block 396 | def conv_layer(self, channel): 397 | if self.model_type == "deep": 398 | conv_block = nn.Sequential( 399 | nn.Conv2d( 400 | in_channels=channel[0], 401 | out_channels=channel[1], 402 | kernel_size=3, 403 | padding=1, 404 | ), 405 | nn.BatchNorm2d(num_features=channel[1]), 406 | nn.ReLU(inplace=True), 407 | 
nn.Conv2d( 408 | in_channels=channel[1], 409 | out_channels=channel[1], 410 | kernel_size=3, 411 | padding=1, 412 | ), 413 | nn.BatchNorm2d(num_features=channel[1]), 414 | nn.ReLU(inplace=True), 415 | ) 416 | else: 417 | conv_block = nn.Sequential( 418 | nn.Conv2d( 419 | in_channels=channel[0], 420 | out_channels=channel[1], 421 | kernel_size=3, 422 | padding=1, 423 | ), 424 | nn.BatchNorm2d(num_features=channel[1]), 425 | nn.ReLU(inplace=True), 426 | ) 427 | return conv_block 428 | 429 | def forward(self, x): 430 | g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ( 431 | [0] * 5 for _ in range(5) 432 | ) 433 | for i in range(5): 434 | g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2)) 435 | 436 | # global shared encoder-decoder network 437 | for i in range(5): 438 | if i == 0: 439 | g_encoder[i][0] = self.encoder_block[i](x) 440 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 441 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 442 | else: 443 | g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1]) 444 | g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0]) 445 | g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1]) 446 | 447 | for i in range(5): 448 | if i == 0: 449 | g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1]) 450 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 451 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 452 | else: 453 | g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1]) 454 | g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i]) 455 | g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0]) 456 | 457 | # define task prediction layers 458 | t1_pred = F.log_softmax(self.pred_task1(g_decoder[i][1]), dim=1) 459 | t2_pred = self.pred_task2(g_decoder[i][1]) 460 | t3_pred = self.pred_task3(g_decoder[i][1]) 461 | t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True) 462 | 463 | return [t1_pred, t2_pred, t3_pred], g_decoder[i][ 464 | 1 465 | ] # NOTE: last element is representation 466 | 467 | 468 | class SegNet(nn.Module): 469 | def __init__(self): 470 | super().__init__() 471 | self.segnet = SegNetSplit() 472 | 473 | def shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 474 | return (p for n, p in self.segnet.named_parameters() if "pred" not in n) 475 | 476 | def task_specific_parameters(self) -> Iterator[nn.parameter.Parameter]: 477 | return (p for n, p in self.segnet.named_parameters() if "pred" in n) 478 | 479 | def last_shared_parameters(self) -> Iterator[nn.parameter.Parameter]: 480 | """Parameters of the last shared layer. 481 | Returns 482 | ------- 483 | """ 484 | return self.segnet.conv_block_dec[-5].parameters() 485 | 486 | def forward(self, x, return_representation=False): 487 | if return_representation: 488 | return self.segnet(x) 489 | else: 490 | pred, rep = self.segnet(x) 491 | return pred 492 | --------------------------------------------------------------------------------
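
Finally, to make the core IGB-v1 rule from methods/loss_weight_methods.py above concrete: each task is scored by its current loss divided by the loss recorded at the base epoch, the scores are turned into weights with a softmax scaled to sum to n_tasks, and the weights are applied to log-losses. The standalone snippet below is not part of the repository and uses made-up numbers; it simply replays that arithmetic for three tasks.

import torch
import torch.nn.functional as F

# Replays ImprovableGapBalancing_v1.get_weighted_loss with illustrative numbers
# (the class additionally detaches the current losses before the softmax).
n_tasks = 3
base_losses = torch.tensor([1.0, 2.0, 4.0])  # per-task average losses at the base epoch
losses = torch.tensor([0.5, 1.8, 1.0])       # current per-task losses

# Normalized losses: task 2 has improved least relative to its base (1.8 / 2.0 = 0.9),
# task 3 the most (1.0 / 4.0 = 0.25), so task 2 receives the largest weight.
weights = n_tasks * F.softmax(losses / base_losses, dim=-1)
weighted_loss = torch.sum(weights * torch.log(losses))

print(weights)  # approximately tensor([0.92, 1.37, 0.71])
# trainer.py actually consumes the per-task vector from get_weighted_losses(),
# i.e. weights * torch.log(losses), and passes it to the gradient-weighting stage.

ImprovableGapBalancing_v2 keeps the same log-loss weighting but asks the SAC agent in methods/SAC_Agent.py to choose the weights, rewarding it with the smallest per-task loss decrease between consecutive batches, normalized by the same base losses.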